Skip to content

Commit 732660a

Browse files
Merge pull request #27 from AstraZeneca/deepsynergy
DeepSynergy and EPGCN-DS
2 parents c6e2f04 + 4a1d701 commit 732660a

File tree

6 files changed

+180
-33
lines changed

6 files changed

+180
-33
lines changed

README.md

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Our framework solves the so called [drug pair scoring task](https://arxiv.org/ab
4242

4343
**Case Study Tutorials**
4444

45-
We provide in-depth case study tutorials in the [Documentation](https://chemicalx.readthedocs.io/en/latest/), each covers an aspect of ChemicalX’s functionality.
45+
We provide in-depth case study like tutorials in the [Documentation](https://chemicalx.readthedocs.io/en/latest/), each covers an aspect of ChemicalX’s functionality.
4646

4747
--------------------------------------------------------------------------------
4848

@@ -59,18 +59,9 @@ If you find *ChemicalX* and the new datasets useful in your research, please con
5959
}
6060
```
6161

62-
--------------------------------------------------------------------------------
63-
64-
**A simple example**
65-
66-
```python
67-
68-
```
69-
--------------------------------------------------------------------------------
70-
7162
**Methods Included**
7263

73-
In detail, the following temporal graph neural networks were implemented.
64+
In detail, the following drug pair scoring models were implemented.
7465

7566
**2017**
7667

@@ -112,13 +103,6 @@ In detail, the following temporal graph neural networks were implemented.
112103

113104
--------------------------------------------------------------------------------
114105

115-
**Auxiliary Layers**
116-
117-
118-
119-
--------------------------------------------------------------------------------
120-
121-
122106
Head over to our [documentation](https://chemicalx.readthedocs.io) to find out more about installation, creation of datasets and a full list of implemented methods and available datasets.
123107
For a quick start, check out the [examples](https://chemicalx.readthedocs.io) in the `examples/` directory.
124108

chemicalx/models/deepsynergy.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ def __init__(
1919
self,
2020
context_channels: int,
2121
drug_channels: int,
22-
input_hidden_channels: int,
23-
middle_hidden_channels: int,
24-
final_hidden_channels: int,
25-
dropout_rate: float,
22+
input_hidden_channels: int = 32,
23+
middle_hidden_channels: int = 32,
24+
final_hidden_channels: int = 32,
25+
dropout_rate: float = 0.5,
2626
):
2727
super(DeepSynergy, self).__init__()
2828
self.encoder = torch.nn.Linear(drug_channels + drug_channels + context_channels, input_hidden_channels)
@@ -37,7 +37,16 @@ def forward(
3737
drug_features_left: torch.FloatTensor,
3838
drug_features_right: torch.FloatTensor,
3939
) -> torch.FloatTensor:
40+
"""
41+
A forward pass of the DeepSynergy model.
4042
43+
Args:
44+
context_features (torch.FloatTensor): A matrix of biological context features.
45+
drug_features_left (torch.FloatTensor): A matrix of head drug features.
46+
drug_features_right (torch.FloatTensor): A matrix of tail drug features.
47+
Returns:
48+
hidden (torch.FloatTensor): A column vector of predicted synergy scores.
49+
"""
4150
hidden = torch.cat([context_features, drug_features_left, drug_features_right], dim=1)
4251
hidden = self.encoder(hidden)
4352
hidden = F.relu(hidden)

chemicalx/models/epgcnds.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,49 @@
1-
from .base import Model
1+
import torch
2+
import torch.nn.functional as F
3+
from torchdrug.data import PackedGraph
4+
from torchdrug.layers import MeanReadout
5+
from torchdrug.models import GraphConvolutionalNetwork
26

3-
__all__ = [
4-
"EPGCNDS",
5-
]
67

8+
class EPGCNDS(torch.nn.Module):
9+
r"""The EPGCN-DS model from the `"Structure-Based Drug-Drug Interaction Detection
10+
via Expressive Graph Convolutional Networks and Deep Sets " <https://ojs.aaai.org/index.php/AAAI/article/view/7236>`_ paper.
711
8-
class EPGCNDS(Model):
9-
"""An implementation of the EPGCNDS model.
10-
11-
.. seealso:: https://github.com/AstraZeneca/chemicalx/issues/22
12+
Args:
13+
in_channels (int): The number of molecular features.
14+
hidden_channels (int): The number of graph convolutional filters.
15+
out_channels (int): The number of hidden layer neurons in the last layer.
1216
"""
17+
18+
def __init__(self, in_channels: int, hidden_channels: int = 32, out_channels: int = 16):
19+
super(EPGCNDS, self).__init__()
20+
self.graph_convolution_in = GraphConvolutionalNetwork(in_channels, hidden_channels)
21+
self.graph_convolution_out = GraphConvolutionalNetwork(hidden_channels, out_channels)
22+
self.mean_readout = MeanReadout()
23+
self.final = torch.nn.Linear(out_channels, 1)
24+
25+
def forward(self, molecules_left: PackedGraph, molecules_right: PackedGraph) -> torch.FloatTensor:
26+
"""
27+
A forward pass of the EPGCN-DS model.
28+
29+
Args:
30+
molecules_left (torch.FloatTensor): Batched molecules for the left side drugs.
31+
molecules_right (torch.FloatTensor): Batched molecules for the right side drugs.
32+
Returns:
33+
hidden (torch.FloatTensor): A column vector of predicted synergy scores.
34+
"""
35+
features_left = self.graph_convolution_in(molecules_left, molecules_left.data_dict["node_feature"])[
36+
"node_feature"
37+
]
38+
features_right = self.graph_convolution_in(molecules_right, molecules_right.data_dict["node_feature"])[
39+
"node_feature"
40+
]
41+
42+
features_left = self.graph_convolution_out(molecules_left, features_left)["node_feature"]
43+
features_right = self.graph_convolution_out(molecules_right, features_right)["node_feature"]
44+
45+
features_left = self.mean_readout(molecules_left, features_left)
46+
features_right = self.mean_readout(molecules_right, features_right)
47+
hidden = features_left + features_right
48+
hidden = torch.sigmoid(self.final(hidden))
49+
return hidden

examples/deepsynergy_examples.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import torch
2+
import pandas as pd
3+
from tqdm import tqdm
4+
from sklearn.metrics import roc_auc_score
5+
from chemicalx.models import DeepSynergy
6+
from chemicalx.data import DatasetLoader, BatchGenerator
7+
8+
loader = DatasetLoader("drugcombdb")
9+
10+
drug_feature_set = loader.get_drug_features()
11+
context_feature_set = loader.get_context_features()
12+
labeled_triples = loader.get_labeled_triples()
13+
14+
train_triples, test_triples = labeled_triples.train_test_split()
15+
16+
generator = BatchGenerator(
17+
batch_size=5120, context_features=True, drug_features=True, drug_molecules=False, labels=True
18+
)
19+
20+
generator.set_data(context_feature_set, drug_feature_set, train_triples)
21+
22+
model = DeepSynergy(context_channels=112, drug_channels=256)
23+
24+
optimizer = torch.optim.Adam(model.parameters())
25+
26+
model.train()
27+
28+
loss = torch.nn.BCELoss()
29+
30+
for epoch in tqdm(range(100)):
31+
for batch in generator:
32+
optimizer.zero_grad()
33+
34+
prediction = model(batch.context_features, batch.drug_features_left, batch.drug_features_right)
35+
36+
loss_value = loss(prediction, batch.labels)
37+
loss_value.backward()
38+
optimizer.step()
39+
40+
model.eval()
41+
42+
generator.set_labeled_triples(test_triples)
43+
44+
predictions = []
45+
for batch in generator:
46+
prediction = model(batch.context_features, batch.drug_features_left, batch.drug_features_right)
47+
prediction = prediction.detach().cpu().numpy()
48+
identifiers = batch.identifiers
49+
identifiers["prediction"] = prediction
50+
predictions.append(identifiers)
51+
52+
predictions = pd.concat(predictions)
53+
au_roc = roc_auc_score(predictions["label"], predictions["prediction"])
54+
print(f"AUROC : {au_roc:.4f}")

examples/epgcnds_examples.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import torch
2+
import pandas as pd
3+
from tqdm import tqdm
4+
from sklearn.metrics import roc_auc_score
5+
from chemicalx.data import DatasetLoader, BatchGenerator
6+
from chemicalx.models import EPGCNDS
7+
8+
loader = DatasetLoader("drugcombdb")
9+
10+
drug_feature_set = loader.get_drug_features()
11+
context_feature_set = loader.get_context_features()
12+
labeled_triples = loader.get_labeled_triples()
13+
14+
15+
generator = BatchGenerator(batch_size=1024, context_features=True, drug_features=True, drug_molecules=True, labels=True)
16+
17+
train_triples, test_triples = labeled_triples.train_test_split()
18+
19+
generator.set_data(context_feature_set, drug_feature_set, train_triples)
20+
21+
22+
model = EPGCNDS(69)
23+
24+
model.train()
25+
26+
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=10 ** -7)
27+
28+
loss = torch.nn.BCELoss()
29+
30+
for epoch in range(20):
31+
for batch in tqdm(generator):
32+
optimizer.zero_grad()
33+
prediction = model(batch.drug_molecules_left, batch.drug_molecules_right)
34+
output = loss(prediction, batch.labels)
35+
output.backward()
36+
optimizer.step()
37+
38+
model.eval()
39+
generator.set_labeled_triples(test_triples)
40+
41+
predictions = []
42+
for batch in tqdm(generator):
43+
prediction = model(batch.drug_molecules_left, batch.drug_molecules_right)
44+
prediction = prediction.detach().cpu().numpy()
45+
identifiers = batch.identifiers
46+
identifiers["prediction"] = prediction
47+
predictions.append(identifiers)
48+
49+
predictions = pd.concat(predictions)
50+
au_roc = roc_auc_score(predictions["label"], predictions["prediction"])
51+
print(f"AUROC : {au_roc:.4f}")

tests/unit/test_models.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ def setUp(self):
2626
drug_feature_set = loader.get_drug_features()
2727
context_feature_set = loader.get_context_features()
2828
labeled_triples = loader.get_labeled_triples()
29+
labeled_triples, _ = labeled_triples.train_test_split(train_size=0.005)
2930
self.generator = BatchGenerator(
30-
batch_size=5120, context_features=True, drug_features=True, drug_molecules=True, labels=True
31+
batch_size=32, context_features=True, drug_features=True, drug_molecules=True, labels=True
3132
)
3233
self.generator.set_data(context_feature_set, drug_feature_set, labeled_triples)
3334

@@ -40,8 +41,19 @@ def test_DPDDI(self):
4041
assert model.x == 2
4142

4243
def test_EPGCNDS(self):
43-
model = EPGCNDS(x=2)
44-
assert model.x == 2
44+
45+
model = EPGCNDS(in_channels=69)
46+
47+
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
48+
model.train()
49+
loss = torch.nn.BCELoss()
50+
for batch in self.generator:
51+
optimizer.zero_grad()
52+
prediction = model(batch.drug_molecules_left, batch.drug_molecules_right)
53+
output = loss(prediction, batch.labels)
54+
output.backward()
55+
optimizer.step()
56+
assert prediction.shape[0] == batch.labels.shape[0]
4557

4658
def test_GCNBMP(self):
4759
model = GCNBMP(x=2)

0 commit comments

Comments
 (0)