Skip to content

Commit 976f2b8

Browse files
committed
add simple Feed-forward network (for ESM2->chebi task)
1 parent 7da8963 commit 976f2b8

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

chebai/models/ffn.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from typing import Dict, Any, Tuple
2+
3+
from chebai.models import ChebaiBaseNet
4+
import torch
5+
from torch import Tensor
6+
7+
class FFN(ChebaiBaseNet):
8+
9+
NAME = "FFN"
10+
11+
def __init__(self, input_size: int = 1000, num_hidden_layers: int = 3, hidden_size: int = 128, **kwargs):
12+
super().__init__(**kwargs)
13+
14+
self.layers = torch.nn.ModuleList()
15+
self.layers.append(torch.nn.Linear(input_size, hidden_size))
16+
for _ in range(num_hidden_layers):
17+
self.layers.append(torch.nn.Linear(hidden_size, hidden_size))
18+
self.layers.append(torch.nn.Linear(hidden_size, self.out_dim))
19+
20+
def _get_prediction_and_labels(self, data, labels, model_output):
21+
d = model_output["logits"]
22+
loss_kwargs = data.get("loss_kwargs", dict())
23+
if "non_null_labels" in loss_kwargs:
24+
n = loss_kwargs["non_null_labels"]
25+
d = data[n]
26+
return torch.sigmoid(d), labels.int() if labels is not None else None
27+
28+
def _process_for_loss(
29+
self,
30+
model_output: Dict[str, Tensor],
31+
labels: Tensor,
32+
loss_kwargs: Dict[str, Any],
33+
) -> Tuple[Tensor, Tensor, Dict[str, Any]]:
34+
"""
35+
Process the model output for calculating the loss.
36+
37+
Args:
38+
model_output (Dict[str, Tensor]): The output of the model.
39+
labels (Tensor): The target labels.
40+
loss_kwargs (Dict[str, Any]): Additional loss arguments.
41+
42+
Returns:
43+
tuple: A tuple containing the processed model output, labels, and loss arguments.
44+
"""
45+
kwargs_copy = dict(loss_kwargs)
46+
if labels is not None:
47+
labels = labels.float()
48+
return model_output["logits"], labels, kwargs_copy
49+
50+
def forward(self, data, **kwargs):
51+
x = data["features"]
52+
for layer in self.layers:
53+
x = torch.relu(layer(x))
54+
return {"logits": x}
55+
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData
2+
init_args:
3+
go_branch: "MF"
4+
max_sequence_length: 1000
5+
use_esm2_embeddings: True

configs/model/ffn.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class_path: chebai.models.ffn.FFN
2+
init_args:
3+
optimizer_kwargs:
4+
lr: 1e-3
5+
hidden_size: 128
6+
num_hidden_layers: 3
7+
input_size: 2560

0 commit comments

Comments
 (0)