@@ -19,70 +19,150 @@ move_choices = {
1919 'gaussianmove' : emcee .moves .GaussianMove , 'kdemove' : emcee .moves .KDEMove ,
2020}
2121
22+ bound_choices = {'none' , 'single' , 'multi' , 'balls' , 'cubes' }
23+ sample_choices = {'auto' , 'unif' , 'rwalk' , 'rstagger' ,
24+ 'slice' , 'rslice' , 'hslice' }
25+
26+
27+ def pos_int (arg ):
28+ '''ensure arg is a positive integer, for use as `type` in ArgumentParser'''
29+
30+ if not arg .isdigit ():
31+ mssg = f"invalid positive int value: '{ arg } '"
32+ raise argparse .ArgumentTypeError (mssg )
33+
34+ return int (arg )
35+
2236
2337if __name__ == '__main__' :
2438
39+ # ----------------------------------------------------------------------
40+ # Command line argument parsing
41+ # ----------------------------------------------------------------------
42+
2543 parser = argparse .ArgumentParser (description = 'fit some GCs' )
2644
2745 parser .add_argument ('cluster' , help = 'Common name of the cluster to model' )
2846
29- parser .add_argument ('--savedir' , default = default_dir ,
30- help = 'location of saved sampling runs' )
31- parser .add_argument ('-i' , '--initials' ,
32- help = 'alternative JSON file with different intials' )
33- parser .add_argument ('-p' , '--priors' , dest = 'param_priors' ,
34- help = 'alternative JSON file with different priors' )
35-
36- parser .add_argument ('-N' , '--Niters' , default = 2000 , type = int ,
37- help = 'Number of sampling iterations' )
38- parser .add_argument ('--Nwalkers' , default = 150 , type = int ,
39- help = 'Number of walkers for MCMC sampler' )
40-
41- parser .add_argument ('--continue' , dest = 'cont_run' , action = 'store_true' ,
42- help = 'Continue from previous saved run' )
43- parser .add_argument ('--backup' , action = 'store_true' ,
44- help = 'Create continuous backups during run' )
47+ # ----------------------------------------------------------------------
48+ # Common arguments to all samplers
49+ # ----------------------------------------------------------------------
4550
46- parser .add_argument ('--verbose' , action = 'store_true' )
47- parser .add_argument ('--debug' , action = 'store_true' )
51+ shared_parser = argparse .ArgumentParser (add_help = False )
4852
49- parallel_group = parser .add_mutually_exclusive_group ()
50- parallel_group .add_argument ("--Ncpu" , default = 2 , type = int ,
53+ parallel_group = shared_parser .add_mutually_exclusive_group ()
54+ parallel_group .add_argument ("--Ncpu" , default = 2 , type = pos_int ,
5155 help = "Number of `multiprocessing` processes" )
5256 parallel_group .add_argument ("--mpi" , action = "store_true" ,
5357 help = "Run with MPI rather than multiprocessing" )
5458
55- parser .add_argument ('--fix' , dest = 'fixed_params' , nargs = '*' ,
56- help = 'Parameters to fix, not estimate from the MCMC' )
59+ shared_parser .add_argument ('--savedir' , default = default_dir ,
60+ help = 'location of saved sampling runs' )
61+ shared_parser .add_argument ('-i' , '--initials' ,
62+ help = 'alternative JSON file '
63+ 'with different intials' )
64+ shared_parser .add_argument ('-p' , '--priors' , dest = 'param_priors' ,
65+ help = 'alternative JSON file '
66+ 'with different priors' )
5767
58- parser .add_argument ('--exclude' , dest = 'excluded_likelihoods' , nargs = '*' ,
59- help = 'Likelihood components to exclude from posteriors' )
68+ shared_parser .add_argument ('--fix' , dest = 'fixed_params' , nargs = '*' ,
69+ help = 'Parameters to fix, '
70+ 'not estimate from the sampler' )
6071
61- parser .add_argument ('--no-hyperparams' , dest = 'hyperparams' ,
62- action = 'store_false' ,
63- help = "Don't use Bayesian hyperparams" )
72+ shared_parser .add_argument ('--exclude' , nargs = '*' ,
73+ dest = 'excluded_likelihoods' ,
74+ help = 'Likelihood components to '
75+ 'exclude from posteriors' )
6476
65- parser .add_argument ('--strict' , nargs = '+' ,
66- metavar = ('[STRICT]' , 'LIKELIHOOD' ),
67- help = "A (numeric) strictness factor to be applied "
68- "to each specified likelihood component" )
77+ shared_parser .add_argument ('--no-hyperparams' , dest = 'hyperparams' ,
78+ action = 'store_false' ,
79+ help = "Don't use Bayesian hyperparams" )
6980
70- parser .add_argument ('--moves ' , type = str . lower , nargs = '* ' ,
71- default = [ 'stretchmove' ], choices = move_choices . keys ( ),
72- help = "Alternative MCMC move proposal algorithm to use. "
73- "Multiple moves will be given equal random weight " )
81+ shared_parser .add_argument ('--strict ' , nargs = '+ ' ,
82+ metavar = ( '[STRICT]' , 'LIKELIHOOD' ),
83+ help = "A (numeric) strictness factor to be "
84+ "applied to each specified likelihood " )
7485
75- parser .add_argument ('--show-progress' , action = 'store_true' , dest = 'progress' ,
76- help = "Display emcee's progress bar" )
86+ shared_parser .add_argument ('--verbose' , action = 'store_true' )
87+ shared_parser .add_argument ('--debug' , action = 'store_true' )
88+
89+ # ----------------------------------------------------------------------
90+ # Subparsers for each sampler
91+ # ----------------------------------------------------------------------
92+
93+ subparsers = parser .add_subparsers (title = "Sampler" ,
94+ dest = "sampler" , required = True ,
95+ help = "Which Sampler algorithm to use in "
96+ "fitting the cluster" )
97+
98+ # ----------------------------------------------------------------------
99+ # MCMC sampling with emcee
100+ # ----------------------------------------------------------------------
101+
102+ parser_MCMC = subparsers .add_parser ('MCMC' , parents = [shared_parser ])
103+
104+ parser_MCMC .add_argument ('-N' , '--Niters' , default = 2000 , type = pos_int ,
105+ help = 'Number of sampling iterations' )
106+ parser_MCMC .add_argument ('--Nwalkers' , default = 150 , type = pos_int ,
107+ help = 'Number of walkers for MCMC sampler' )
108+
109+ parser_MCMC .add_argument ('--moves' , type = str .lower , nargs = '*' ,
110+ default = ['stretchmove' ],
111+ choices = move_choices .keys (),
112+ help = "Alternative MCMC move proposal algorithm to "
113+ "use. Multiple moves will be given equal "
114+ "random weight" )
115+
116+ parser_MCMC .add_argument ('--continue' , dest = 'cont_run' , action = 'store_true' ,
117+ help = 'Continue from previous saved run' )
118+ parser_MCMC .add_argument ('--backup' , action = 'store_true' ,
119+ help = 'Create continuous backups during run' )
120+
121+ parser_MCMC .add_argument ('--show-progress' , action = 'store_true' ,
122+ dest = 'progress' , help = "Display progress bar" )
123+
124+ parser_MCMC .set_defaults (fit_func = fitter .MCMC_fit )
125+
126+ # ----------------------------------------------------------------------
127+ # Nested Sampling with dynesty
128+ # ----------------------------------------------------------------------
129+ # TODO make the "current_batch" storage optional
130+
131+ parser_nest = subparsers .add_parser ('nested' , parents = [shared_parser ])
132+
133+ parser_nest .add_argument ('--pfrac' , default = 1.0 , type = float ,
134+ help = 'Posterior weighting fraction f_p' )
135+ parser_nest .add_argument ('--dlogz' , default = 0.25 , type = float ,
136+ help = 'Δln(Z) tolerance initial stopping condition' )
137+ parser_nest .add_argument ('--maxiter' , default = None , type = pos_int ,
138+ help = 'Maximum number of iterations allowed. May '
139+ 'end sampling before the stopping conditions '
140+ 'are met' )
141+ parser_nest .add_argument ('--init-maxiter' , default = None , type = pos_int ,
142+ help = 'Maximum number of iterations allowed in the '
143+ 'baseline run' )
144+ parser_nest .add_argument ('--N-per-batch' , default = 100 , type = pos_int ,
145+ dest = 'Nlive_per_batch' ,
146+ help = 'Number of live points to add each batch' )
147+ parser_nest .add_argument ('--bound-type' , default = 'balls' ,
148+ choices = bound_choices ,
149+ help = 'Method used to bound sampling on the prior' )
150+ parser_nest .add_argument ('--sample-type' , default = 'auto' ,
151+ choices = sample_choices ,
152+ help = 'Method used to sample uniformly within the '
153+ 'likelihood, based on the provided bounds' )
154+
155+ parser_nest .set_defaults (fit_func = fitter .nested_fit )
77156
78157 args = parser .parse_args ()
79158
80159 # ----------------------------------------------------------------------
81- # Do any args preprocessing necessary for calling fitter
160+ # Args preprocessing
82161 # ----------------------------------------------------------------------
83162
84- if args .cont_run :
85- raise NotImplementedError
163+ # ----------------------------------------------------------------------
164+ # Common arguments
165+ # ----------------------------------------------------------------------
86166
87167 if args .initials :
88168
@@ -104,13 +184,6 @@ if __name__ == '__main__':
104184 else :
105185 parser .error (f"Cannot access '{ bnd_file } ': No such file" )
106186
107- pathlib .Path (args .savedir ).mkdir (exist_ok = True )
108-
109- if debug := args .debug :
110- args .verbose = True
111-
112- del args .debug
113-
114187 # TODO could also be a way here for setting `err_on_fail` in the priors
115188 if args .strict is not None :
116189 try :
@@ -121,12 +194,51 @@ if __name__ == '__main__':
121194 if len (args .strict ) == 1 :
122195 args .strict .append ('*' )
123196
124- args .moves = [move_choices [mv ]() for mv in args .moves ]
197+ pathlib .Path (args .savedir ).mkdir (exist_ok = True )
198+
199+ # ----------------------------------------------------------------------
200+ # MCMC specific arguments
201+ # ----------------------------------------------------------------------
202+
203+ if args .sampler == 'MCMC' :
204+
205+ if args .cont_run :
206+ raise NotImplementedError
207+
208+ args .moves = [move_choices [mv ]() for mv in args .moves ]
209+
210+ # ----------------------------------------------------------------------
211+ # Nested Sampling specific arguments
212+ # ----------------------------------------------------------------------
213+
214+ elif args .sampler == 'nested' :
215+
216+ # TODO add more of these options
217+ args .initial_kwargs = {
218+ 'maxiter' : args .init_maxiter or float ('inf' ),
219+ 'nlive' : args .Nlive_per_batch ,
220+ 'dlogz' : args .dlogz
221+ }
222+
223+ args .batch_kwargs = {
224+ 'maxiter' : args .maxiter or float ('inf' ),
225+ 'nlive_new' : args .Nlive_per_batch
226+ }
227+
228+ del args .dlogz
229+ del args .maxiter
230+ del args .init_maxiter
231+ del args .Nlive_per_batch
125232
126233 # ----------------------------------------------------------------------
127234 # Setup logging
128235 # ----------------------------------------------------------------------
129236
237+ if debug := args .debug :
238+ args .verbose = True
239+
240+ del args .debug
241+
130242 config = {
131243 'level' : logging .DEBUG if debug else logging .INFO ,
132244 'format' : ('%(process)s|%(asctime)s|'
@@ -145,6 +257,11 @@ if __name__ == '__main__':
145257 # Call fitter
146258 # ----------------------------------------------------------------------
147259
148- print ('args:' , vars (args ))
260+ del args .sampler
261+
262+ fit_func = args .fit_func
263+ del args .fit_func
264+
265+ logging .debug (f"{ args = } " )
149266
150- fitter . fit (** vars (args ))
267+ fit_func (** vars (args ))
0 commit comments