|
1 | | -from .base import Model |
| 1 | +import torch |
| 2 | +import torch.nn.functional as F |
| 3 | +from torchdrug.data import PackedGraph |
| 4 | +from torchdrug.layers import MeanReadout |
| 5 | +from torchdrug.models import GraphConvolutionalNetwork |
2 | 6 |
|
3 | | -__all__ = [ |
4 | | - "EPGCNDS", |
5 | | -] |
6 | 7 |
|
| 8 | +class EPGCNDS(torch.nn.Module): |
| 9 | + r"""The EPGCN-DS model from the `"Structure-Based Drug-Drug Interaction Detection |
| 10 | + via Expressive Graph Convolutional Networks and Deep Sets " <https://ojs.aaai.org/index.php/AAAI/article/view/7236>`_ paper. |
7 | 11 |
|
8 | | -class EPGCNDS(Model): |
9 | | - """An implementation of the EPGCNDS model. |
10 | | -
|
11 | | - .. seealso:: https://github.com/AstraZeneca/chemicalx/issues/22 |
| 12 | + Args: |
| 13 | + in_channels (int): The number of molecular features. |
| 14 | + hidden_channels (int): The number of graph convolutional filters. |
| 15 | + out_channels (int): The number of hidden layer neurons in the last layer. |
12 | 16 | """ |
| 17 | + |
| 18 | + def __init__(self, in_channels: int, hidden_channels: int = 32, out_channels: int = 16): |
| 19 | + super(EPGCNDS, self).__init__() |
| 20 | + self.graph_convolution_in = GraphConvolutionalNetwork(in_channels, hidden_channels) |
| 21 | + self.graph_convolution_out = GraphConvolutionalNetwork(hidden_channels, out_channels) |
| 22 | + self.mean_readout = MeanReadout() |
| 23 | + self.final = torch.nn.Linear(out_channels, 1) |
| 24 | + |
| 25 | + def forward(self, molecules_left: PackedGraph, molecules_right: PackedGraph) -> torch.FloatTensor: |
| 26 | + """ |
| 27 | + A forward pass of the EPGCN-DS model. |
| 28 | +
|
| 29 | + Args: |
| 30 | + molecules_left (torch.FloatTensor): Batched molecules for the left side drugs. |
| 31 | + molecules_right (torch.FloatTensor): Batched molecules for the right side drugs. |
| 32 | + Returns: |
| 33 | + hidden (torch.FloatTensor): A column vector of predicted synergy scores. |
| 34 | + """ |
| 35 | + features_left = self.graph_convolution_in(molecules_left, molecules_left.data_dict["node_feature"])[ |
| 36 | + "node_feature" |
| 37 | + ] |
| 38 | + features_right = self.graph_convolution_in(molecules_right, molecules_right.data_dict["node_feature"])[ |
| 39 | + "node_feature" |
| 40 | + ] |
| 41 | + |
| 42 | + features_left = self.graph_convolution_out(molecules_left, features_left)["node_feature"] |
| 43 | + features_right = self.graph_convolution_out(molecules_right, features_right)["node_feature"] |
| 44 | + |
| 45 | + features_left = self.mean_readout(molecules_left, features_left) |
| 46 | + features_right = self.mean_readout(molecules_right, features_right) |
| 47 | + hidden = features_left + features_right |
| 48 | + hidden = torch.sigmoid(self.final(hidden)) |
| 49 | + return hidden |
0 commit comments