@@ -122,6 +122,7 @@ def convert_flat_trace_to_idata(
122122 samples ,
123123 include_transformed = False ,
124124 postprocessing_backend = "cpu" ,
125+ inference_backend = "pymc" ,
125126 model = None ,
126127):
127128 model = modelcontext (model )
@@ -139,10 +140,21 @@ def convert_flat_trace_to_idata(
139140 var_names = model .unobserved_value_vars
140141 vars_to_sample = list (get_default_varnames (var_names , include_transformed = include_transformed ))
141142 print ("Transforming variables..." , file = sys .stdout )
142- jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
143- result = jax .vmap (jax .vmap (jax_fn ))(
144- * jax .device_put (list (trace .values ()), jax .devices (postprocessing_backend )[0 ])
145- )
143+
144+ if inference_backend == "pymc" :
145+ # TODO: we need to remove JAX dependency as win32 users can now use Pathfinder with inference_backend="pymc".
146+ jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
147+ result = jax .vmap (jax .vmap (jax_fn ))(
148+ * jax .device_put (list (trace .values ()), jax .devices (postprocessing_backend )[0 ])
149+ )
150+ elif inference_backend == "blackjax" :
151+ jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
152+ result = jax .vmap (jax .vmap (jax_fn ))(
153+ * jax .device_put (list (trace .values ()), jax .devices (postprocessing_backend )[0 ])
154+ )
155+ else :
156+ raise ValueError (f"Invalid inference_backend: { inference_backend } " )
157+
146158 trace = {v .name : r for v , r in zip (vars_to_sample , result )}
147159 coords , dims = coords_and_dims_for_inferencedata (model )
148160 idata = az .from_dict (trace , dims = dims , coords = coords )
@@ -742,7 +754,6 @@ def fit_pathfinder(
742754 random_seed = random_seed ,
743755 ** pathfinder_kwargs ,
744756 )
745-
746757 elif inference_backend == "blackjax" :
747758 jitter_seed , pathfinder_seed , sample_seed = _get_seeds_per_chain (random_seed , 3 )
748759 # TODO: extend initial points initialisation to blackjax
@@ -773,15 +784,15 @@ def fit_pathfinder(
773784 state = pathfinder_state ,
774785 num_samples = num_draws ,
775786 )
776-
777787 else :
778- raise ValueError (f"Inference backend { inference_backend } not supported " )
788+ raise ValueError (f"Invalid inference_backend: { inference_backend } " )
779789
780790 print ("Running pathfinder..." , file = sys .stdout )
781791
782792 idata = convert_flat_trace_to_idata (
783793 pathfinder_samples ,
784794 postprocessing_backend = postprocessing_backend ,
795+ inference_backend = inference_backend ,
785796 model = model ,
786797 )
787798 return idata
0 commit comments