@@ -435,9 +435,9 @@ def compile_pymc_model(
435435 * ,
436436 backend : Literal ["numba" , "jax" ] = "numba" ,
437437 gradient_backend : Literal ["pytensor" , "jax" ] = "pytensor" ,
438- overrides : dict [Union ["Variable" , str ], np .ndarray | float | int ] | None = None ,
438+ initial_points : dict [Union ["Variable" , str ], np .ndarray | float | int ] | None = None ,
439439 jitter_rvs : set ["TensorVariable" ] | None = None ,
440- default_strategy : Literal ["support_point" , "prior" ] = "prior " ,
440+ default_initialization_strategy : Literal ["support_point" , "prior" ] = "support_point " ,
441441 ** kwargs ,
442442) -> CompiledModel :
443443 """Compile necessary functions for sampling a pymc model.
@@ -455,10 +455,10 @@ def compile_pymc_model(
455455 The set (or list or tuple) of random variables for which a U(-1, +1)
456456 jitter should be added to the initial value. Only available for
457457 variables that have a transform or real-valued support.
458- default_strategy : str
458+ default_initialization_strategy : str
459459 Which of { "support_point", "prior" } to prefer if the initval setting
460460 for an RV is None.
461- overrides : dict
461+ initial_points : dict
462462 Initial value (strategies) to use instead of what's specified in
463463 `Model.initial_values`.
464464 Returns
@@ -475,13 +475,13 @@ def compile_pymc_model(
475475 "and restart your kernel in case you are in an interactive session."
476476 )
477477
478- if default_strategy == "support_point" and jitter_rvs is None :
478+ if default_initialization_strategy == "support_point" and jitter_rvs is None :
479479 jitter_rvs = set (model .free_RVs )
480480
481481 initial_point_fn = make_initial_point_fn (
482482 model = model ,
483- overrides = overrides ,
484- default_strategy = default_strategy ,
483+ overrides = initial_points ,
484+ default_strategy = default_initialization_strategy ,
485485 jitter_rvs = jitter_rvs ,
486486 return_transformed = True ,
487487 )
0 commit comments