220220
221221
222222class WriterWrapper :
223- def __init__ (self , replica_numbers , pdf_objects , stopping_object , all_chi2s , theory , timings ):
223+ def __init__ (self , replica_numbers , pdf_objects , stopping_object , all_chi2s , theory , timings , trials ):
224224 """
225225 Initializes the writer for all replicas.
226226
@@ -241,13 +241,16 @@ def __init__(self, replica_numbers, pdf_objects, stopping_object, all_chi2s, the
241241 theory information of the fit
242242 `timings`
243243 dictionary of the timing of the different events that happened
244+ `trials`
245+ dictionary with best trials settings
244246 """
245247 self .replica_numbers = replica_numbers
246248 self .pdf_objects = pdf_objects
247249 self .stopping_object = stopping_object
248250 self .theory = theory
249251 self .timings = timings
250252 self .tr_chi2 , self .vl_chi2 , self .true_chi2 = all_chi2s
253+ self .trials = trials
251254
252255 def write_data (self , save_path , fitname , weights_name ):
253256 """
@@ -273,13 +276,12 @@ def write_data(self, save_path, fitname, weights_name):
273276 self .integrability_numbers .append (
274277 vpinterface .integrability_numbers (pdf_object ).tolist ()
275278 )
276-
277279 for i , rn in enumerate (self .replica_numbers ):
278280 replica_path = save_path / f"replica_{ rn } "
279281 replica_path .mkdir (exist_ok = True , parents = True )
280282
281283 self ._write_chi2s (replica_path / "chi2exps.log" )
282- self ._write_metadata_json (i , replica_path / f"{ fitname } .json" )
284+ self ._write_metadata_json (i , rn , replica_path / f"{ fitname } .json" )
283285 self ._export_pdf_grid (i , replica_path / f"{ fitname } .exportgrid" )
284286 if weights_name :
285287 self ._write_weights (i , replica_path / f"{ weights_name } " )
@@ -290,7 +292,26 @@ def _write_chi2s(self, out_path):
290292 with open (out_path , "w" , encoding = "utf-8" ) as fs :
291293 json .dump (chi2_log , fs , indent = 2 , cls = SuperEncoder )
292294
293- def _write_metadata_json (self , i , out_path ):
295+ def _hyperparam_settings (self , replica_number ):
296+ """Collect replica hyperparameter settings"""
297+ trials_number = self .trials ["number_of_trials" ]
298+ idx_trial = replica_number % trials_number
299+ hyperparam_info = {}
300+ hyperparam_info ["optimizer" ]= self .trials ["optimizer" ][idx_trial ]
301+ hyperparam_info ["learning_rate" ]= self .trials ["learning_rate" ][idx_trial ]
302+ hyperparam_info ["clipnorm" ]= self .trials ["clipnorm" ][idx_trial ]
303+ hyperparam_info ["epochs" ]= self .trials ["epochs" ][idx_trial ]
304+ hyperparam_info ["stopping_patience" ]= self .trials ["stopping_patience" ][idx_trial ]
305+ hyperparam_info ["initial" ]= self .trials ["initial" ][idx_trial ]
306+ hyperparam_info ["nodes_per_layer" ]= self .trials ["nodes_per_layer" ][idx_trial ]
307+ hyperparam_info ["number_of_layers" ]= self .trials ["number_of_layers" ][idx_trial ]
308+ hyperparam_info ["activation" ]= self .trials ["activation_per_layer" ][idx_trial ]
309+ hyperparam_info ["layer_type" ]= self .trials ["layer_type" ][idx_trial ]
310+ hyperparam_info ["initializer" ]= self .trials ["initializer" ][idx_trial ]
311+ hyperparam_info ["dropout" ]= self .trials ["dropout" ][idx_trial ]
312+ return hyperparam_info
313+
314+ def _write_metadata_json (self , i , replica_number , out_path ):
294315 json_dict = jsonfit (
295316 best_epoch = self .stopping_object .e_best_chi2 [i ],
296317 positivity_status = self .stopping_object .positivity_statuses [i ],
@@ -300,6 +321,7 @@ def _write_metadata_json(self, i, out_path):
300321 tr_chi2 = self .tr_chi2 [i ],
301322 vl_chi2 = self .vl_chi2 [i ],
302323 true_chi2 = self .true_chi2 [i ],
324+ hyperparam_info = self ._hyperparam_settings (replica_number ),
303325 # Note: the 2 arguments below are the same for all replicas, unless run separately
304326 timing = self .timings ,
305327 stop_epoch = self .stopping_object .stop_epoch ,
@@ -347,6 +369,7 @@ def jsonfit(
347369 true_chi2 ,
348370 stop_epoch ,
349371 timing ,
372+ hyperparam_info ,
350373):
351374 """Generates a dictionary containing all relevant metadata for the fit
352375
@@ -372,6 +395,8 @@ def jsonfit(
372395 epoch at which the stopping stopped (not the one for the best fit!)
373396 timing: dict
374397 dictionary of the timing of the different events that happened
398+ hyperparam_info: dict
399+ dictionary of hyperparameter settings
375400 """
376401 all_info = {}
377402 # Generate preprocessing information
@@ -386,6 +411,7 @@ def jsonfit(
386411 all_info ["arc_lengths" ] = arc_lengths
387412 all_info ["integrability" ] = integrability_numbers
388413 all_info ["timing" ] = timing
414+ all_info ["hyperparameters" ] = hyperparam_info
389415 # Versioning info
390416 all_info ["version" ] = version ()
391417 return all_info
0 commit comments