Skip to content

Commit 39cdbba

Browse files
committed
Allow additional hyperparameter exclusion
1 parent 7480783 commit 39cdbba

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

chebai/models/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Dict, Optional, Union
2+
from typing import Any, Dict, Optional, Union, Iterable
33

44
import torch
55
from lightning.pytorch.core.module import LightningModule
@@ -41,12 +41,15 @@ def __init__(
4141
test_metrics: Optional[torch.nn.Module] = None,
4242
pass_loss_kwargs: bool = True,
4343
optimizer_kwargs: Optional[Dict[str, Any]] = None,
44+
exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
4445
**kwargs,
4546
):
4647
super().__init__()
48+
if exclude_hyperparameter_logging is None:
49+
exclude_hyperparameter_logging = tuple()
4750
self.criterion = criterion
4851
self.save_hyperparameters(
49-
ignore=["criterion", "train_metrics", "val_metrics", "test_metrics"]
52+
ignore=["criterion", "train_metrics", "val_metrics", "test_metrics", *exclude_hyperparameter_logging]
5053
)
5154
self.out_dim = out_dim
5255
if optimizer_kwargs:

0 commit comments

Comments
 (0)