@@ -435,9 +435,12 @@ def compile_pymc_model(
435435 * ,
436436 backend : Literal ["numba" , "jax" ] = "numba" ,
437437 gradient_backend : Literal ["pytensor" , "jax" ] = "pytensor" ,
438- overrides : dict [Union ["Variable" , str ], np .ndarray | float | int ] | None = None ,
438+ initial_points : dict [Union ["Variable" , str ], np .ndarray | float | int ]
439+ | None = None ,
439440 jitter_rvs : set ["TensorVariable" ] | None = None ,
440- default_strategy : Literal ["support_point" , "prior" ] = "prior" ,
441+ default_initialization_strategy : Literal [
442+ "support_point" , "prior"
443+ ] = "support_point" ,
441444 ** kwargs ,
442445) -> CompiledModel :
443446 """Compile necessary functions for sampling a pymc model.
@@ -455,10 +458,10 @@ def compile_pymc_model(
455458 The set (or list or tuple) of random variables for which a U(-1, +1)
456459 jitter should be added to the initial value. Only available for
457460 variables that have a transform or real-valued support.
458- default_strategy : str
461+ default_initialization_strategy : str
459462 Which of { "support_point", "prior" } to prefer if the initval setting
460463 for an RV is None.
461- overrides : dict
464+ initial_points : dict
462465 Initial value (strategies) to use instead of what's specified in
463466 `Model.initial_values`.
464467 Returns
@@ -475,13 +478,13 @@ def compile_pymc_model(
475478 "and restart your kernel in case you are in an interactive session."
476479 )
477480
478- if default_strategy == "support_point" and jitter_rvs is None :
481+ if default_initialization_strategy == "support_point" and jitter_rvs is None :
479482 jitter_rvs = set (model .free_RVs )
480483
481484 initial_point_fn = make_initial_point_fn (
482485 model = model ,
483- overrides = overrides ,
484- default_strategy = default_strategy ,
486+ overrides = initial_points ,
487+ default_strategy = default_initialization_strategy ,
485488 jitter_rvs = jitter_rvs ,
486489 return_transformed = True ,
487490 )
0 commit comments