Skip to content

Commit fd19028

Browse files
ChemicalX
1 parent 91b3fa0 commit fd19028

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

chemicalx/models/deepsynergy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def forward(
3838
drug_features_right: torch.FloatTensor,
3939
) -> torch.FloatTensor:
4040

41-
hidden = torch.cat([drug_features_left, drug_features_right, context_features], dim=1)
41+
hidden = torch.cat([context_features, drug_features_left, drug_features_right], dim=1)
4242
hidden = self.encoder(hidden)
4343
hidden = F.relu(hidden)
4444
hidden = self.hidden_first(hidden)

tests/unit/test_models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,24 @@ def test_DeepDrug(self):
7878
def test_DeepSynergy(self):
7979

8080
model = DeepSynergy(
81-
context_channels=64,
82-
drug_channels=32,
81+
context_channels=112,
82+
drug_channels=256,
8383
input_hidden_channels=32,
8484
middle_hidden_channels=16,
8585
final_hidden_channels=16,
8686
dropout_rate=0.5,
8787
)
8888

8989
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
90+
model.train()
9091
loss = torch.nn.BCELoss()
9192
for batch in self.generator:
9293
optimizer.zero_grad()
9394
prediction = model(batch.context_features, batch.drug_features_left, batch.drug_features_right)
9495
output = loss(prediction, batch.labels)
9596
output.backward()
9697
optimizer.step()
98+
assert prediction.shape[0] == batch.labels.shape[0]
9799

98100
def test_DeepDDS(self):
99101
model = DeepDDS(x=2)

0 commit comments

Comments
 (0)