Skip to content

Commit 1a05947

Browse files
committed
Merge branch 'develop'
2 parents fdeca6a + dd4fb93 commit 1a05947

File tree

12 files changed

+2972
-1403
lines changed

12 files changed

+2972
-1403
lines changed

bin/GCfitter

Lines changed: 167 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -19,70 +19,150 @@ move_choices = {
1919
'gaussianmove': emcee.moves.GaussianMove, 'kdemove': emcee.moves.KDEMove,
2020
}
2121

22+
bound_choices = {'none', 'single', 'multi', 'balls', 'cubes'}
23+
sample_choices = {'auto', 'unif', 'rwalk', 'rstagger',
24+
'slice', 'rslice', 'hslice'}
25+
26+
27+
def pos_int(arg):
28+
'''ensure arg is a positive integer, for use as `type` in ArgumentParser'''
29+
30+
if not arg.isdigit():
31+
mssg = f"invalid positive int value: '{arg}'"
32+
raise argparse.ArgumentTypeError(mssg)
33+
34+
return int(arg)
35+
2236

2337
if __name__ == '__main__':
2438

39+
# ----------------------------------------------------------------------
40+
# Command line argument parsing
41+
# ----------------------------------------------------------------------
42+
2543
parser = argparse.ArgumentParser(description='fit some GCs')
2644

2745
parser.add_argument('cluster', help='Common name of the cluster to model')
2846

29-
parser.add_argument('--savedir', default=default_dir,
30-
help='location of saved sampling runs')
31-
parser.add_argument('-i', '--initials',
32-
help='alternative JSON file with different initials')
33-
parser.add_argument('-p', '--priors', dest='param_priors',
34-
help='alternative JSON file with different priors')
35-
36-
parser.add_argument('-N', '--Niters', default=2000, type=int,
37-
help='Number of sampling iterations')
38-
parser.add_argument('--Nwalkers', default=150, type=int,
39-
help='Number of walkers for MCMC sampler')
40-
41-
parser.add_argument('--continue', dest='cont_run', action='store_true',
42-
help='Continue from previous saved run')
43-
parser.add_argument('--backup', action='store_true',
44-
help='Create continuous backups during run')
47+
# ----------------------------------------------------------------------
48+
# Common arguments to all samplers
49+
# ----------------------------------------------------------------------
4550

46-
parser.add_argument('--verbose', action='store_true')
47-
parser.add_argument('--debug', action='store_true')
51+
shared_parser = argparse.ArgumentParser(add_help=False)
4852

49-
parallel_group = parser.add_mutually_exclusive_group()
50-
parallel_group.add_argument("--Ncpu", default=2, type=int,
53+
parallel_group = shared_parser.add_mutually_exclusive_group()
54+
parallel_group.add_argument("--Ncpu", default=2, type=pos_int,
5155
help="Number of `multiprocessing` processes")
5256
parallel_group.add_argument("--mpi", action="store_true",
5357
help="Run with MPI rather than multiprocessing")
5458

55-
parser.add_argument('--fix', dest='fixed_params', nargs='*',
56-
help='Parameters to fix, not estimate from the MCMC')
59+
shared_parser.add_argument('--savedir', default=default_dir,
60+
help='location of saved sampling runs')
61+
shared_parser.add_argument('-i', '--initials',
62+
help='alternative JSON file '
63+
'with different initials')
64+
shared_parser.add_argument('-p', '--priors', dest='param_priors',
65+
help='alternative JSON file '
66+
'with different priors')
5767

58-
parser.add_argument('--exclude', dest='excluded_likelihoods', nargs='*',
59-
help='Likelihood components to exclude from posteriors')
68+
shared_parser.add_argument('--fix', dest='fixed_params', nargs='*',
69+
help='Parameters to fix, '
70+
'not estimate from the sampler')
6071

61-
parser.add_argument('--no-hyperparams', dest='hyperparams',
62-
action='store_false',
63-
help="Don't use Bayesian hyperparams")
72+
shared_parser.add_argument('--exclude', nargs='*',
73+
dest='excluded_likelihoods',
74+
help='Likelihood components to '
75+
'exclude from posteriors')
6476

65-
parser.add_argument('--strict', nargs='+',
66-
metavar=('[STRICT]', 'LIKELIHOOD'),
67-
help="A (numeric) strictness factor to be applied "
68-
"to each specified likelihood component")
77+
shared_parser.add_argument('--no-hyperparams', dest='hyperparams',
78+
action='store_false',
79+
help="Don't use Bayesian hyperparams")
6980

70-
parser.add_argument('--moves', type=str.lower, nargs='*',
71-
default=['stretchmove'], choices=move_choices.keys(),
72-
help="Alternative MCMC move proposal algorithm to use. "
73-
"Multiple moves will be given equal random weight")
81+
shared_parser.add_argument('--strict', nargs='+',
82+
metavar=('[STRICT]', 'LIKELIHOOD'),
83+
help="A (numeric) strictness factor to be "
84+
"applied to each specified likelihood")
7485

75-
parser.add_argument('--show-progress', action='store_true', dest='progress',
76-
help="Display emcee's progress bar")
86+
shared_parser.add_argument('--verbose', action='store_true')
87+
shared_parser.add_argument('--debug', action='store_true')
88+
89+
# ----------------------------------------------------------------------
90+
# Subparsers for each sampler
91+
# ----------------------------------------------------------------------
92+
93+
subparsers = parser.add_subparsers(title="Sampler",
94+
dest="sampler", required=True,
95+
help="Which Sampler algorithm to use in "
96+
"fitting the cluster")
97+
98+
# ----------------------------------------------------------------------
99+
# MCMC sampling with emcee
100+
# ----------------------------------------------------------------------
101+
102+
parser_MCMC = subparsers.add_parser('MCMC', parents=[shared_parser])
103+
104+
parser_MCMC.add_argument('-N', '--Niters', default=2000, type=pos_int,
105+
help='Number of sampling iterations')
106+
parser_MCMC.add_argument('--Nwalkers', default=150, type=pos_int,
107+
help='Number of walkers for MCMC sampler')
108+
109+
parser_MCMC.add_argument('--moves', type=str.lower, nargs='*',
110+
default=['stretchmove'],
111+
choices=move_choices.keys(),
112+
help="Alternative MCMC move proposal algorithm to "
113+
"use. Multiple moves will be given equal "
114+
"random weight")
115+
116+
parser_MCMC.add_argument('--continue', dest='cont_run', action='store_true',
117+
help='Continue from previous saved run')
118+
parser_MCMC.add_argument('--backup', action='store_true',
119+
help='Create continuous backups during run')
120+
121+
parser_MCMC.add_argument('--show-progress', action='store_true',
122+
dest='progress', help="Display progress bar")
123+
124+
parser_MCMC.set_defaults(fit_func=fitter.MCMC_fit)
125+
126+
# ----------------------------------------------------------------------
127+
# Nested Sampling with dynesty
128+
# ----------------------------------------------------------------------
129+
# TODO make the "current_batch" storage optional
130+
131+
parser_nest = subparsers.add_parser('nested', parents=[shared_parser])
132+
133+
parser_nest.add_argument('--pfrac', default=1.0, type=float,
134+
help='Posterior weighting fraction f_p')
135+
parser_nest.add_argument('--dlogz', default=0.25, type=float,
136+
help='Δln(Z) tolerance initial stopping condition')
137+
parser_nest.add_argument('--maxiter', default=None, type=pos_int,
138+
help='Maximum number of iterations allowed. May '
139+
'end sampling before the stopping conditions '
140+
'are met')
141+
parser_nest.add_argument('--init-maxiter', default=None, type=pos_int,
142+
help='Maximum number of iterations allowed in the '
143+
'baseline run')
144+
parser_nest.add_argument('--N-per-batch', default=100, type=pos_int,
145+
dest='Nlive_per_batch',
146+
help='Number of live points to add each batch')
147+
parser_nest.add_argument('--bound-type', default='balls',
148+
choices=bound_choices,
149+
help='Method used to bound sampling on the prior')
150+
parser_nest.add_argument('--sample-type', default='auto',
151+
choices=sample_choices,
152+
help='Method used to sample uniformly within the '
153+
'likelihood, based on the provided bounds')
154+
155+
parser_nest.set_defaults(fit_func=fitter.nested_fit)
77156

78157
args = parser.parse_args()
79158

80159
# ----------------------------------------------------------------------
81-
# Do any args preprocessing necessary for calling fitter
160+
# Args preprocessing
82161
# ----------------------------------------------------------------------
83162

84-
if args.cont_run:
85-
raise NotImplementedError
163+
# ----------------------------------------------------------------------
164+
# Common arguments
165+
# ----------------------------------------------------------------------
86166

87167
if args.initials:
88168

@@ -104,13 +184,6 @@ if __name__ == '__main__':
104184
else:
105185
parser.error(f"Cannot access '{bnd_file}': No such file")
106186

107-
pathlib.Path(args.savedir).mkdir(exist_ok=True)
108-
109-
if debug := args.debug:
110-
args.verbose = True
111-
112-
del args.debug
113-
114187
# TODO could also be a way here for setting `err_on_fail` in the priors
115188
if args.strict is not None:
116189
try:
@@ -121,12 +194,51 @@ if __name__ == '__main__':
121194
if len(args.strict) == 1:
122195
args.strict.append('*')
123196

124-
args.moves = [move_choices[mv]() for mv in args.moves]
197+
pathlib.Path(args.savedir).mkdir(exist_ok=True)
198+
199+
# ----------------------------------------------------------------------
200+
# MCMC specific arguments
201+
# ----------------------------------------------------------------------
202+
203+
if args.sampler == 'MCMC':
204+
205+
if args.cont_run:
206+
raise NotImplementedError
207+
208+
args.moves = [move_choices[mv]() for mv in args.moves]
209+
210+
# ----------------------------------------------------------------------
211+
# Nested Sampling specific arguments
212+
# ----------------------------------------------------------------------
213+
214+
elif args.sampler == 'nested':
215+
216+
# TODO add more of these options
217+
args.initial_kwargs = {
218+
'maxiter': args.init_maxiter or float('inf'),
219+
'nlive': args.Nlive_per_batch,
220+
'dlogz': args.dlogz
221+
}
222+
223+
args.batch_kwargs = {
224+
'maxiter': args.maxiter or float('inf'),
225+
'nlive_new': args.Nlive_per_batch
226+
}
227+
228+
del args.dlogz
229+
del args.maxiter
230+
del args.init_maxiter
231+
del args.Nlive_per_batch
125232

126233
# ----------------------------------------------------------------------
127234
# Setup logging
128235
# ----------------------------------------------------------------------
129236

237+
if debug := args.debug:
238+
args.verbose = True
239+
240+
del args.debug
241+
130242
config = {
131243
'level': logging.DEBUG if debug else logging.INFO,
132244
'format': ('%(process)s|%(asctime)s|'
@@ -145,6 +257,11 @@ if __name__ == '__main__':
145257
# Call fitter
146258
# ----------------------------------------------------------------------
147259

148-
print('args:', vars(args))
260+
del args.sampler
261+
262+
fit_func = args.fit_func
263+
del args.fit_func
264+
265+
logging.debug(f"{args=}")
149266

150-
fitter.fit(**vars(args))
267+
fit_func(**vars(args))

fitter/core/data.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,20 @@ def __repr__(self):
152152
def __str__(self):
153153
return f'{self._name} Dataset'
154154

155+
_citation = None
156+
157+
def __citation__(self):
158+
if self._citation is not None:
159+
return self._citation
160+
else:
161+
try:
162+
bibcodes = self.mdata['source'].split(';')
163+
self._citation = util.bibcode2cite(bibcodes)
164+
return self._citation
165+
166+
except KeyError:
167+
return None
168+
155169
def __contains__(self, key):
156170
return key in self._dict_variables
157171

@@ -182,6 +196,9 @@ def __init__(self, group):
182196
def variables(self):
183197
return self._dict_variables
184198

199+
def cite(self):
200+
return self.__citation__()
201+
185202
def build_err(self, varname, model_r, model_val, strict=True):
186203
'''
187204
varname is the variable we want to get the error for
@@ -274,6 +291,7 @@ def __getitem__(self, key):
274291
mssg = f"Dataset '{key}' does not exist in {self}"
275292
raise KeyError(mssg) from err
276293

294+
# TODO a filter method for finding all datasets matching a pattern
277295
@property
278296
def datasets(self):
279297
return self._dict_datasets
@@ -311,6 +329,17 @@ def _walker(key, obj):
311329

312330
return groups
313331

332+
def filter_datasets(self, pattern, valid_only=True):
333+
# TODO maybe `datasets` and this should only return ds list not dict?
334+
# if that's the case, make `datasets._name` public
335+
336+
if valid_only:
337+
datasets = {key for (key, *_) in self.valid_likelihoods}
338+
else:
339+
datasets = self.datasets.keys
340+
341+
return {key: self[key] for key in fnmatch.filter(datasets, pattern)}
342+
314343
def filter_likelihoods(self, patterns, exclude=False, keys_only=False):
315344
'''filter the valid likelihoods based on list of patterns, matching
316345
either the dataset name or likelihood function name.
@@ -340,6 +369,7 @@ def get_sources(self, fmt='bibtex'):
340369
341370
fmt : 'bibtex', 'bibcode', 'citep'
342371
'''
372+
# TODO make this use dataset __citation__'s so it doesn't pull each time
343373

344374
res = {}
345375

@@ -522,17 +552,10 @@ def _determine_likelihoods(self):
522552
# --------------------------------------------------------------------------
523553

524554
# TODO The units are *quite* incomplete in Model (10)
555+
# TODO would be cool to get this to work with limepy's `sampling`
525556

526557
class Model(lp.limepy):
527558

528-
def __getattr__(self, key):
529-
'''If `key` is not defined in the limepy model, try to get it from θ'''
530-
try:
531-
return self._theta[key]
532-
except KeyError as err:
533-
msg = f"'{self.__class__.__name__}' object has no attribute '{key}'"
534-
raise AttributeError(msg) from err
535-
536559
def _init_mf(self):
537560

538561
m123 = [0.1, 0.5, 1.0, 100] # Slope breakpoints for imf
@@ -649,6 +672,9 @@ def __init__(self, theta, observations=None, *, verbose=False):
649672

650673
self._theta = theta
651674

675+
for key, val in self._theta.items():
676+
setattr(self, key, val)
677+
652678
# ------------------------------------------------------------------
653679
# Get mass function
654680
# ------------------------------------------------------------------
@@ -680,22 +706,24 @@ def __init__(self, theta, observations=None, *, verbose=False):
680706
# TODO still don't entirely understand when this is to be used
681707
# mj is middle of mass bins, mes are edges, widths are sizes of bins
682708
# self.mbin_widths = np.diff(self._mf.mes[-1]) ??
709+
# What's the difference between `mes` and `me`?
683710
# TODO is this supposed to have units? I think so
684711
self.mes_widths = np.diff(self._mf.mes[-1])
685712

686713
# append tracer mass bins (must be appended to end to not affect nms)
687714
if observations is not None:
688715

689-
# TODO should only append tracer masses for valid likelihood dsets
716+
# TODO should only append tracer masses for valid likelihood dsets?
690717
tracer_mj = np.unique([
691718
dataset.mdata['m'] for dataset in observations.datasets.values()
692719
if 'm' in dataset.mdata
693720
])
694721

695-
# TODO shouldn't append multiple of same tracer mass
696722
mj = np.concatenate((mj, tracer_mj))
697723
Mj = np.concatenate((Mj, 0.1 * np.ones_like(tracer_mj)))
698724

725+
self._tracer_bins = slice(self.nms + self.nmr, None)
726+
699727
else:
700728
logging.warning("No `Observations` given, no tracer masses added")
701729

0 commit comments

Comments
 (0)