99import torch
1010
1111from chebai .preprocessing .reader import CLS_TOKEN , ChemDataReader
12+ from chebai .loggers .custom import CustomLogger
1213
1314log = logging .getLogger (__name__ )
1415
@@ -20,7 +21,35 @@ def __init__(self, *args, **kwargs):
2021 super ().__init__ (* args , ** kwargs )
2122 # instantiation custom logger connector
2223 self ._logger_connector .on_trainer_init (self .logger , 1 )
23- self .logger .log_hyperparams (self .init_kwargs )
24+ # log additional hyperparameters to wandb
25+ if isinstance (self .logger , CustomLogger ):
26+ custom_logger = self .logger
27+ assert isinstance (custom_logger , CustomLogger )
28+ if custom_logger .verbose_hyperparameters :
29+ log_kwargs = {}
30+ for key , value in self .init_kwargs .items ():
31+ log_key , log_value = self ._resolve_logging_argument (key , value )
32+ log_kwargs [log_key ] = log_value
33+ self .logger .log_hyperparams (log_kwargs )
34+
35+ def _resolve_logging_argument (self , key , value ):
36+ if isinstance (value , list ):
37+ key_value_pairs = [
38+ self ._resolve_logging_argument (f"{ key } _{ i } " , v )
39+ for i , v in enumerate (value )
40+ ]
41+ return key , {k : v for k , v in key_value_pairs }
42+ if not (
43+ isinstance (value , str )
44+ or isinstance (value , float )
45+ or isinstance (value , int )
46+ or value is None
47+ ):
48+ params = {"class" : value .__class__ }
49+ params .update (value .__dict__ )
50+ return key , params
51+ else :
52+ return key , value
2453
2554 def predict_from_file (
2655 self ,
0 commit comments