
Commit d4d25a2

option to use Adam optim as in deepgo
1 parent 5cec15e commit d4d25a2

File tree

1 file changed (+19, -0 lines)


chebai/models/ffn.py

Lines changed: 19 additions & 0 deletions
@@ -14,10 +14,14 @@ def __init__(
         hidden_layers: List[int] = [
             1024,
         ],
+        use_adam_optimizer: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
 
+        self.use_adam_optimizer: bool = bool(use_adam_optimizer)
+        print(f"Using Adam optimizer: {self.use_adam_optimizer}")
+
         layers = []
         current_layer_input_size = self.input_dim
         for hidden_dim in hidden_layers:
@@ -62,6 +66,21 @@ def forward(self, data, **kwargs):
         x = data["features"]
         return {"logits": self.model(x)}
 
+    def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer:
+        """
+        Configures the optimizers.
+
+        Args:
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            torch.optim.Optimizer: The optimizer.
+        """
+        if self.use_adam_optimizer:
+            return torch.optim.Adam(self.parameters(), **self.optimizer_kwargs)
+
+        return torch.optim.Adamax(self.parameters(), **self.optimizer_kwargs)
+
 
 class Residual(nn.Module):
     """
