diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py
index 194dffe9..18e9df4d 100644
--- a/chebai/models/ffn.py
+++ b/chebai/models/ffn.py
@@ -14,10 +14,14 @@ def __init__(
         hidden_layers: List[int] = [
             1024,
         ],
+        use_adam_optimizer: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
 
+        self.use_adam_optimizer: bool = bool(use_adam_optimizer)
+        print(f"Using Adam optimizer: {self.use_adam_optimizer}")
+
         layers = []
         current_layer_input_size = self.input_dim
         for hidden_dim in hidden_layers:
@@ -26,7 +30,6 @@ def __init__(
             current_layer_input_size = hidden_dim
 
         layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim))
-        layers.append(nn.Sigmoid())
         self.model = nn.Sequential(*layers)
 
     def _get_prediction_and_labels(self, data, labels, model_output):
@@ -63,6 +66,21 @@ def forward(self, data, **kwargs):
         x = data["features"]
         return {"logits": self.model(x)}
 
+    def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer:
+        """
+        Configures the optimizers.
+
+        Args:
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            torch.optim.Optimizer: The optimizer.
+        """
+        if self.use_adam_optimizer:
+            return torch.optim.Adam(self.parameters(), **self.optimizer_kwargs)
+
+        return torch.optim.Adamax(self.parameters(), **self.optimizer_kwargs)
+
 
 class Residual(nn.Module):
     """
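
A minimal usage sketch follows (not part of the patch). It assumes the class defined in ffn.py is named FFN and that its base class consumes the input_dim, out_dim, and optimizer_kwargs keyword arguments referenced as attributes in the diff; those constructor details are illustrative assumptions, not confirmed by the patch itself.

    import torch
    from chebai.models.ffn import FFN  # class name assumed from the file path

    # Constructor keywords below are assumed to be handled by the base model class.
    model = FFN(
        input_dim=128,                 # illustrative sizes
        out_dim=854,
        hidden_layers=[1024],
        use_adam_optimizer=True,       # new flag introduced by this patch
        optimizer_kwargs={"lr": 1e-3},
    )

    opt = model.configure_optimizers()
    assert isinstance(opt, torch.optim.Adam)  # Adamax is returned when the flag is False

    # With the trailing nn.Sigmoid() removed, forward() now returns raw logits,
    # so a loss that applies the sigmoid internally pairs naturally with it:
    criterion = torch.nn.BCEWithLogitsLoss()
    batch = {"features": torch.randn(4, 128)}
    loss = criterion(model(batch)["logits"], torch.randint(0, 2, (4, 854)).float())

Dropping the in-model sigmoid and training against logits with BCEWithLogitsLoss is the usual pattern for multi-label outputs, since the combined sigmoid-plus-BCE computation is more numerically stable than applying a sigmoid inside the network and a plain BCE loss afterwards.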