Skip to content

Commit 66fd926

Browse files
committed
Modify fit saving functions to save entries always
Will just pick dataset size, weighted, if fit was not extended
1 parent a9fce7f commit 66fd926

File tree

1 file changed

+33
-7
lines changed

1 file changed

+33
-7
lines changed

src/dmu/stats/utilities.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -388,9 +388,13 @@ def save_fit(
388388
fit_dir: Directory where outputs are meant to go
389389
plt_cfg: Plotting configuration, as taken by ZfitPlotter
390390
d_const: Dictionary storing constraints
391+
392+
Returns
393+
--------------------
394+
Object holding fitting parameters if fit happened, or None if fit did not run
391395
'''
392396
if plt_cfg is None:
393-
return
397+
return None
394398

395399
if isinstance(plt_cfg, dict):
396400
plt_cfg = OmegaConf.create(plt_cfg)
@@ -401,18 +405,26 @@ def save_fit(
401405
fit_dir = str(fit_dir)
402406

403407
_save_fit_plot(data=data, model=model, cfg=plt_cfg, fit_dir=fit_dir)
404-
_save_result(fit_dir=fit_dir, res=res)
405408

409+
if model and not model.extended:
410+
# TODO: Update this when the issue with this property be fixed in zfit
411+
nentries = data.samplesize.numpy() # type: ignore
412+
else:
413+
nentries = None
414+
415+
pars = _save_result(fit_dir=fit_dir, res=res, nentries = nentries)
406416
df = data.to_pandas(weightsname=Data.weight_name)
407417
opath = f'{fit_dir}/data.json'
408418
log.debug(f'Saving data to: {opath}')
409419
df.to_json(opath, indent=2)
410420

411421
if model is None:
412-
return
422+
return None
413423

414424
print_pdf(model, txt_path=f'{fit_dir}/post_fit.txt', d_const=d_const)
415425
pdf_to_tex(path=f'{fit_dir}/post_fit.txt', d_par={'mu' : r'$\mu$'}, skip_fixed=True)
426+
427+
return pars
416428
# ----------------------
417429
def _save_fit_plot(
418430
data : zdata,
@@ -477,18 +489,27 @@ def _save_fit_plot(
477489
plt.savefig(fit_path_log)
478490
plt.close()
479491
#-------------------------------------------------------
480-
def _save_result(fit_dir : str, res : zres|None) -> None:
492+
def _save_result(
493+
nentries: float | None,
494+
fit_dir : str,
495+
res : zres | None) -> Measurement | None:
481496
'''
482497
Saves result as yaml, JSON, pkl
483498
484499
Parameters
485500
---------------
486-
fit_dir: Directory where fit result will go
487-
res : Zfit result object
501+
nentries: Number of entries or sum of weights, if not None
502+
fit_dir : Directory where fit result will go
503+
res : Zfit result object
504+
505+
Returns
506+
---------------
507+
Object holding values and errors of fitting parameters or None if result not found
508+
because fit did not run
488509
'''
489510
if res is None:
490511
log.info('No result object found, not saving parameters in pkl or JSON')
491-
return
512+
return None
492513

493514
# TODO: Remove this once there be a safer way to freeze
494515
# see https://github.com/zfit/zfit/issues/632
@@ -503,13 +524,18 @@ def _save_result(fit_dir : str, res : zres|None) -> None:
503524
d_par = _parameters_from_result(result=res)
504525
d_par = dict(sorted(d_par.items()))
505526

527+
if nentries:
528+
d_par['nentries'] = nentries, 0
529+
506530
opath = f'{fit_dir}/parameters.json'
507531
log.debug(f'Saving parameters to: {opath}')
508532
gut.dump_json(data = d_par, path = opath, exists_ok = True)
509533

510534
opath = f'{fit_dir}/parameters.yaml'
511535
log.debug(f'Saving parameters to: {opath}')
512536
gut.dump_json(data = d_par, path = opath, exists_ok = True)
537+
538+
return Measurement(data = d_par)
513539
#-------------------------------------------------------
514540
# Make latex table from text file
515541
#-------------------------------------------------------

0 commit comments

Comments
 (0)