
Commit a5d55e9
Author: sfluegel
Commit message: add hyperparameter saving to datamodule and trainer
1 parent: e9e3b71

File tree: 2 files changed, 31 additions and 1 deletion

chebai/preprocessing/datasets/base.py (1 addition, 0 deletions)

@@ -108,6 +108,7 @@ def __init__(
         if self.use_inner_cross_validation:
             os.makedirs(os.path.join(self.raw_dir, self.fold_dir), exist_ok=True)
             os.makedirs(os.path.join(self.processed_dir, self.fold_dir), exist_ok=True)
+        self.save_hyperparameters()
 
     @property
     def identifier(self):
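
For context, self.save_hyperparameters() is the Lightning mixin call that records every __init__ argument of the datamodule so it shows up in self.hparams and in checkpoint/logger metadata. A minimal sketch of that behaviour, using a hypothetical ToyDataModule and the lightning >= 2.0 import path (chebai itself may import from pytorch_lightning):

    from lightning.pytorch import LightningDataModule

    class ToyDataModule(LightningDataModule):
        def __init__(self, batch_size: int = 32, k_folds: int = 5):
            super().__init__()
            # Records batch_size and k_folds so they appear in self.hparams
            # and are persisted alongside checkpoints and logger metadata.
            self.save_hyperparameters()

    dm = ToyDataModule(batch_size=64)
    print(dm.hparams.batch_size)  # 64
    print(dm.hparams.k_folds)     # 5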

chebai/trainer/CustomTrainer.py (30 additions, 1 deletion)

@@ -9,6 +9,7 @@
 import torch
 
 from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader
+from chebai.loggers.custom import CustomLogger
 
 log = logging.getLogger(__name__)
 
@@ -20,7 +21,35 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         # instantiation custom logger connector
         self._logger_connector.on_trainer_init(self.logger, 1)
-        self.logger.log_hyperparams(self.init_kwargs)
+        # log additional hyperparameters to wandb
+        if isinstance(self.logger, CustomLogger):
+            custom_logger = self.logger
+            assert isinstance(custom_logger, CustomLogger)
+            if custom_logger.verbose_hyperparameters:
+                log_kwargs = {}
+                for key, value in self.init_kwargs.items():
+                    log_key, log_value = self._resolve_logging_argument(key, value)
+                    log_kwargs[log_key] = log_value
+                self.logger.log_hyperparams(log_kwargs)
+
+    def _resolve_logging_argument(self, key, value):
+        if isinstance(value, list):
+            key_value_pairs = [
+                self._resolve_logging_argument(f"{key}_{i}", v)
+                for i, v in enumerate(value)
+            ]
+            return key, {k: v for k, v in key_value_pairs}
+        if not (
+            isinstance(value, str)
+            or isinstance(value, float)
+            or isinstance(value, int)
+            or value is None
+        ):
+            params = {"class": value.__class__}
+            params.update(value.__dict__)
+            return key, params
+        else:
+            return key, value
 
     def predict_from_file(
         self,
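
The new _resolve_logging_argument recursively flattens the trainer's init kwargs into values the logger can display: lists are expanded into per-index keys, and non-scalar objects are replaced by their class plus their attribute dict. A standalone sketch of that flattening (resolve mirrors the committed method as a free function; EarlyStoppingStub is a hypothetical stand-in for a real callback):

    def resolve(key, value):
        # Lists become a dict keyed key_0, key_1, ... with resolved entries.
        if isinstance(value, list):
            pairs = [resolve(f"{key}_{i}", v) for i, v in enumerate(value)]
            return key, {k: v for k, v in pairs}
        # Anything that is not a plain scalar is logged as its class plus attributes.
        if not (isinstance(value, (str, float, int)) or value is None):
            params = {"class": value.__class__}
            params.update(value.__dict__)
            return key, params
        return key, value

    class EarlyStoppingStub:
        def __init__(self):
            self.monitor = "val_loss"
            self.patience = 3

    print(resolve("max_epochs", 100))
    # ('max_epochs', 100)
    print(resolve("callbacks", [EarlyStoppingStub()]))
    # ('callbacks', {'callbacks_0': {'class': <class '__main__.EarlyStoppingStub'>,
    #                                'monitor': 'val_loss', 'patience': 3}})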
