from typing import Any, Dict, Tuple

import torch
from torch import Tensor

from chebai.models import ChebaiBaseNet

class FFN(ChebaiBaseNet):
    """A simple feed-forward baseline for multi-label classification."""

    NAME = "FFN"

    def __init__(
        self,
        input_size: int = 1000,
        num_hidden_layers: int = 3,
        hidden_size: int = 128,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Input projection, `num_hidden_layers` equally sized hidden layers,
        # and an output layer mapping to `self.out_dim` (set by the base class).
        self.layers = torch.nn.ModuleList()
        self.layers.append(torch.nn.Linear(input_size, hidden_size))
        for _ in range(num_hidden_layers):
            self.layers.append(torch.nn.Linear(hidden_size, hidden_size))
        self.layers.append(torch.nn.Linear(hidden_size, self.out_dim))

    def _get_prediction_and_labels(self, data, labels, model_output):
        """Turn raw logits into sigmoid probabilities for metric computation."""
        d = model_output["logits"]
        loss_kwargs = data.get("loss_kwargs", dict())
        if "non_null_labels" in loss_kwargs:
            n = loss_kwargs["non_null_labels"]
            # Restrict predictions to the entries that actually carry labels;
            # indexing the batch dict (`data[n]`) here would be a bug.
            d = d[n]
        return torch.sigmoid(d), labels.int() if labels is not None else None

    def _process_for_loss(
        self,
        model_output: Dict[str, Tensor],
        labels: Tensor,
        loss_kwargs: Dict[str, Any],
    ) -> Tuple[Tensor, Tensor, Dict[str, Any]]:
        """
        Process the model output for calculating the loss.

        Args:
            model_output (Dict[str, Tensor]): The output of the model.
            labels (Tensor): The target labels.
            loss_kwargs (Dict[str, Any]): Additional loss arguments.

        Returns:
            tuple: A tuple containing the processed model output, labels, and loss arguments.
        """
        kwargs_copy = dict(loss_kwargs)
        if labels is not None:
            # BCE-style losses expect float targets.
            labels = labels.float()
        return model_output["logits"], labels, kwargs_copy

    def forward(self, data, **kwargs):
        x = data["features"]
        # Apply ReLU after every layer except the last: the output layer must
        # return raw logits, otherwise sigmoid(logits) would always be >= 0.5.
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        return {"logits": self.layers[-1](x)}
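
# Usage sketch (illustrative only): `out_dim` is assumed to be consumed by
# ChebaiBaseNet via **kwargs, since __init__ above reads `self.out_dim`;
# the class count 854 is a hypothetical placeholder.
if __name__ == "__main__":
    model = FFN(input_size=1000, num_hidden_layers=3, hidden_size=128, out_dim=854)
    batch = {"features": torch.rand(4, 1000)}  # four dummy input vectors
    out = model(batch)
    probs, _ = model._get_prediction_and_labels(batch, None, out)
    print(out["logits"].shape, probs.shape)  # torch.Size([4, 854]) for both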