Skip to content

Commit a11134d

Browse files
committed
add hyperparam specs at the end of each replica json
1 parent c42c628 commit a11134d

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

n3fit/src/n3fit/io/writer.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@
220220

221221

222222
class WriterWrapper:
223-
def __init__(self, replica_numbers, pdf_objects, stopping_object, all_chi2s, theory, timings):
223+
def __init__(self, replica_numbers, pdf_objects, stopping_object, all_chi2s, theory, timings, trials):
224224
"""
225225
Initializes the writer for all replicas.
226226
@@ -241,13 +241,16 @@ def __init__(self, replica_numbers, pdf_objects, stopping_object, all_chi2s, the
241241
theory information of the fit
242242
`timings`
243243
dictionary of the timing of the different events that happened
244+
`trials`
245+
dictionary with best trials settings
244246
"""
245247
self.replica_numbers = replica_numbers
246248
self.pdf_objects = pdf_objects
247249
self.stopping_object = stopping_object
248250
self.theory = theory
249251
self.timings = timings
250252
self.tr_chi2, self.vl_chi2, self.true_chi2 = all_chi2s
253+
self.trials = trials
251254

252255
def write_data(self, save_path, fitname, weights_name):
253256
"""
@@ -273,13 +276,12 @@ def write_data(self, save_path, fitname, weights_name):
273276
self.integrability_numbers.append(
274277
vpinterface.integrability_numbers(pdf_object).tolist()
275278
)
276-
277279
for i, rn in enumerate(self.replica_numbers):
278280
replica_path = save_path / f"replica_{rn}"
279281
replica_path.mkdir(exist_ok=True, parents=True)
280282

281283
self._write_chi2s(replica_path / "chi2exps.log")
282-
self._write_metadata_json(i, replica_path / f"{fitname}.json")
284+
self._write_metadata_json(i, rn, replica_path / f"{fitname}.json")
283285
self._export_pdf_grid(i, replica_path / f"{fitname}.exportgrid")
284286
if weights_name:
285287
self._write_weights(i, replica_path / f"{weights_name}")
@@ -290,7 +292,26 @@ def _write_chi2s(self, out_path):
290292
with open(out_path, "w", encoding="utf-8") as fs:
291293
json.dump(chi2_log, fs, indent=2, cls=SuperEncoder)
292294

293-
def _write_metadata_json(self, i, out_path):
295+
def _hyperparam_settings(self, replica_number):
296+
"""Collect replica hyperparameter settings"""
297+
trials_number = self.trials["number_of_trials"]
298+
idx_trial = replica_number % trials_number
299+
hyperparam_info = {}
300+
hyperparam_info["optimizer"]=self.trials["optimizer"][idx_trial]
301+
hyperparam_info["learning_rate"]=self.trials["learning_rate"][idx_trial]
302+
hyperparam_info["clipnorm"]=self.trials["clipnorm"][idx_trial]
303+
hyperparam_info["epochs"]=self.trials["epochs"][idx_trial]
304+
hyperparam_info["stopping_patience"]=self.trials["stopping_patience"][idx_trial]
305+
hyperparam_info["initial"]=self.trials["initial"][idx_trial]
306+
hyperparam_info["nodes_per_layer"]=self.trials["nodes_per_layer"][idx_trial]
307+
hyperparam_info["number_of_layers"]=self.trials["number_of_layers"][idx_trial]
308+
hyperparam_info["activation"]=self.trials["activation_per_layer"][idx_trial]
309+
hyperparam_info["layer_type"]=self.trials["layer_type"][idx_trial]
310+
hyperparam_info["initializer"]=self.trials["initializer"][idx_trial]
311+
hyperparam_info["dropout"]=self.trials["dropout"][idx_trial]
312+
return hyperparam_info
313+
314+
def _write_metadata_json(self, i, replica_number, out_path):
294315
json_dict = jsonfit(
295316
best_epoch=self.stopping_object.e_best_chi2[i],
296317
positivity_status=self.stopping_object.positivity_statuses[i],
@@ -300,6 +321,7 @@ def _write_metadata_json(self, i, out_path):
300321
tr_chi2=self.tr_chi2[i],
301322
vl_chi2=self.vl_chi2[i],
302323
true_chi2=self.true_chi2[i],
324+
hyperparam_info=self._hyperparam_settings(replica_number),
303325
# Note: the 2 arguments below are the same for all replicas, unless run separately
304326
timing=self.timings,
305327
stop_epoch=self.stopping_object.stop_epoch,
@@ -347,6 +369,7 @@ def jsonfit(
347369
true_chi2,
348370
stop_epoch,
349371
timing,
372+
hyperparam_info,
350373
):
351374
"""Generates a dictionary containing all relevant metadata for the fit
352375
@@ -372,6 +395,8 @@ def jsonfit(
372395
epoch at which the stopping stopped (not the one for the best fit!)
373396
timing: dict
374397
dictionary of the timing of the different events that happened
398+
hyperparam_info: dict
399+
dictionary of hyperparameter settings
375400
"""
376401
all_info = {}
377402
# Generate preprocessing information
@@ -386,6 +411,7 @@ def jsonfit(
386411
all_info["arc_lengths"] = arc_lengths
387412
all_info["integrability"] = integrability_numbers
388413
all_info["timing"] = timing
414+
all_info["hyperparameters"] = hyperparam_info
389415
# Versioning info
390416
all_info["version"] = version()
391417
return all_info

0 commit comments

Comments
 (0)