
Commit 191c979

ffn: update as per deepgo2 mlp architecture
1 parent 6370572 commit 191c979

File tree: 2 files changed (+106, -16 lines)

chebai/models/ffn.py

Lines changed: 106 additions & 14 deletions
@@ -1,28 +1,36 @@
-from typing import Dict, Any, Tuple
+from typing import Any, Dict, List, Optional, Tuple

-from chebai.models import ChebaiBaseNet
 import torch
-from torch import Tensor
+from torch import Tensor, nn
+
+from chebai.models import ChebaiBaseNet


 class FFN(ChebaiBaseNet):
+    # Reference: https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/models.py#L121-L139

     NAME = "FFN"

     def __init__(
         self,
-        input_size: int = 1000,
-        num_hidden_layers: int = 3,
-        hidden_size: int = 128,
+        input_size: int,
+        hidden_layers: List[int] = [
+            1024,
+        ],
         **kwargs
     ):
         super().__init__(**kwargs)

-        self.layers = torch.nn.ModuleList()
-        self.layers.append(torch.nn.Linear(input_size, hidden_size))
-        for _ in range(num_hidden_layers):
-            self.layers.append(torch.nn.Linear(hidden_size, hidden_size))
-        self.layers.append(torch.nn.Linear(hidden_size, self.out_dim))
+        layers = []
+        current_layer_input_size = input_size
+        for hidden_dim in hidden_layers:
+            layers.append(MLPBlock(current_layer_input_size, hidden_dim))
+            layers.append(Residual(MLPBlock(hidden_dim, hidden_dim)))
+            current_layer_input_size = hidden_dim
+
+        layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim))
+        layers.append(nn.Sigmoid())
+        self.model = nn.Sequential(*layers)

     def _get_prediction_and_labels(self, data, labels, model_output):
         d = model_output["logits"]
@@ -56,6 +64,90 @@ def _process_for_loss(

     def forward(self, data, **kwargs):
         x = data["features"]
-        for layer in self.layers:
-            x = torch.relu(layer(x))
-        return {"logits": x}
+        return {"logits": self.model(x)}
+
+
+class Residual(nn.Module):
+    """
+    A residual layer that adds the output of a function to its input.
+
+    Args:
+        fn (nn.Module): The function to be applied to the input.
+
+    References:
+        https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/base.py#L6-L35
+    """
+
+    def __init__(self, fn):
+        """
+        Initialize the Residual layer with a given function.
+
+        Args:
+            fn (nn.Module): The function to be applied to the input.
+        """
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x):
+        """
+        Forward pass of the Residual layer.
+
+        Args:
+            x: Input tensor.
+
+        Returns:
+            torch.Tensor: The input tensor added to the result of applying the function `fn` to it.
+        """
+        return x + self.fn(x)
+
+
+class MLPBlock(nn.Module):
+    """
+    A basic Multi-Layer Perceptron (MLP) block with one fully connected layer.
+
+    Args:
+        in_features (int): The number of input features.
+        out_features (int): The number of output features.
+        bias (bool): Whether to add a bias term to the linear layer.
+        layer_norm (bool): Whether to apply layer normalization.
+        dropout (float): The dropout probability.
+        activation (nn.Module): The activation function applied after the fully connected layer.
+
+    References:
+        https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/base.py#L38-L73
+
+    Example:
+        ```python
+        # Create an MLP block with ReLU activation
+        mlp_block = MLPBlock(in_features=64, out_features=10, activation=nn.ReLU)
+
+        # Apply the MLP block to an input tensor
+        input_tensor = torch.randn(32, 64)
+        output = mlp_block(input_tensor)
+        ```
+    """
+
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        bias=True,
+        layer_norm=True,
+        dropout=0.1,
+        activation=nn.ReLU,
+    ):
+        super().__init__()
+        self.linear = nn.Linear(in_features, out_features, bias)
+        self.activation = activation()
+        self.layer_norm: Optional[nn.LayerNorm] = (
+            nn.LayerNorm(out_features) if layer_norm else None
+        )
+        self.dropout: Optional[nn.Dropout] = nn.Dropout(dropout) if dropout else None
+
+    def forward(self, x):
+        x = self.activation(self.linear(x))
+        if self.layer_norm:
+            x = self.layer_norm(x)
+        if self.dropout:
+            x = self.dropout(x)
+        return x
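
As a quick sanity check, the stack built in the new `__init__` can be reproduced standalone. This is a minimal sketch, assuming `chebai` is importable; the batch size and the output dimension of 1500 are illustrative only (in the real model, `out_dim` comes from `ChebaiBaseNet`):

```python
import torch
from torch import nn

from chebai.models.ffn import MLPBlock, Residual  # classes added in this commit

# Illustrative sizes: input_size matches configs/model/ffn.yml, hidden_layers
# matches the new default; out_dim = 1500 is an assumed label count.
input_size, hidden_layers, out_dim = 2560, [1024], 1500

layers = []
current_layer_input_size = input_size
for hidden_dim in hidden_layers:
    # Projection block followed by a residual block, DeepGO2-style.
    layers.append(MLPBlock(current_layer_input_size, hidden_dim))
    layers.append(Residual(MLPBlock(hidden_dim, hidden_dim)))
    current_layer_input_size = hidden_dim

layers.append(nn.Linear(current_layer_input_size, out_dim))
layers.append(nn.Sigmoid())
model = nn.Sequential(*layers)

x = torch.randn(8, input_size)  # batch of 8 feature vectors
print(model(x).shape)           # torch.Size([8, 1500]), values in (0, 1)
```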

configs/model/ffn.yml

Lines changed: 0 additions & 2 deletions
@@ -2,6 +2,4 @@ class_path: chebai.models.ffn.FFN
 init_args:
   optimizer_kwargs:
     lr: 1e-3
-  hidden_size: 128
-  num_hidden_layers: 3
   input_size: 2560
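
For reference, the trimmed config corresponds roughly to the constructor call below. This is a hedged sketch: `out_dim` is assumed here to be a `ChebaiBaseNet` keyword argument, and 1500 is an illustrative value, not something defined in this commit:

```python
from chebai.models.ffn import FFN

# Roughly what configs/model/ffn.yml now expresses: hidden_size and
# num_hidden_layers are gone; the hidden stack is controlled by the
# hidden_layers list (default [1024]).
model = FFN(
    input_size=2560,
    hidden_layers=[1024],           # e.g. [1024, 512] for a deeper stack
    optimizer_kwargs={"lr": 1e-3},
    out_dim=1500,                   # assumed ChebaiBaseNet kwarg, value is illustrative
)
```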
