-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Closed
Description
Describe the issue:
Shapes aren't being set correctly on variables that use coords when compiling with the JAX backend. I suspect this is a consequence of coords being mutable by default, and that it could be addressed by using `freeze_dims_and_data` as in #7263. If that is the case, perhaps the forward samplers should check for `mode='JAX'` in `compile_kwargs` and raise early with a more informative error?
Reproducible code example:
import pymc as pm
# Fails
with pm.Model(coords={'a':['1']}) as m:
    x = pm.Normal('x', dims=['a'])
    pm.sample_prior_predictive(compile_kwargs={'mode':'JAX'})
# Works
with pm.Model() as m:
    x = pm.Normal('x', shape=(1,))
    pm.sample_prior_predictive(compile_kwargs={'mode':'JAX'})
Error message (the traceback below was produced by the equivalent call `pm.draw(x, mode='JAX')` on the first, `dims=['a']` model):
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:196, in streamline.<locals>.streamline_default_f()
193 for thunk, node, old_storage in zip(
194 thunks, order, post_thunk_old_storage
195 ):
--> 196 thunk()
197 for old_s in old_storage:
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
654 def thunk(
655 fgraph=self.fgraph,
656 fgraph_jit=fgraph_jit,
657 thunk_inputs=thunk_inputs,
658 thunk_outputs=thunk_outputs,
659 ):
--> 660 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
662 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
[... skipping hidden 11 frame]
File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:5, in jax_funcified_fgraph(random_generator_shared_variable, a)
4 # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x29A781FC0>), JAXShapeTuple.0, 11, 0, 1.0)
----> 5 variable, x = sample_fn(random_generator_shared_variable, tensor_variable, tensor_constant, tensor_constant_1, tensor_constant_2)
6 return x, variable
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:106, in jax_funcify_RandomVariable.<locals>.sample_fn(rng, size, dtype, *parameters)
105 def sample_fn(rng, size, dtype, *parameters):
--> 106 return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:164, in jax_sample_fn_loc_scale.<locals>.sample_fn(rng, size, dtype, *parameters)
163 loc, scale = parameters
--> 164 sample = loc + jax_op(sampling_key, size, dtype) * scale
165 rng["jax_state"] = rng_key
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/random.py:710, in normal(key, shape, dtype)
709 dtype = dtypes.canonicalize_dtype(dtype)
--> 710 shape = core.as_named_shape(shape)
711 return _normal(key, shape, dtype)
[... skipping hidden 2 frame]
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/core.py:2142, in canonicalize_shape(shape, context)
2141 pass
-> 2142 raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function jax_funcified_fgraph at /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:1 for jit. This concrete value was not available in Python because it depends on the value of the argument a.
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
Cell In[19], line 1
----> 1 pm.draw(x, mode='JAX')
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/sampling/forward.py:314, in draw(vars, draws, random_seed, **kwargs)
311 draw_fn = compile_pymc(inputs=[], outputs=vars, random_seed=random_seed, **kwargs)
313 if draws == 1:
--> 314 return draw_fn()
316 # Single variable output
317 if not isinstance(vars, list | tuple):
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
967 t0_fn = time.perf_counter()
968 try:
969 outputs = (
--> 970 self.vm()
971 if output_subset is None
972 else self.vm(output_subset=output_subset)
973 )
974 except Exception:
975 restore_defaults()
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:200, in streamline.<locals>.streamline_default_f()
198 old_s[0] = None
199 except Exception:
--> 200 raise_with_op(fgraph, node, thunk)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:523, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
518 warnings.warn(
519 f"{exc_type} error does not allow us to add an extra error message"
520 )
521 # Some exception need extra parameter in inputs. So forget the
522 # extra long error message in that case.
--> 523 raise exc_value.with_traceback(exc_trace)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:196, in streamline.<locals>.streamline_default_f()
192 try:
193 for thunk, node, old_storage in zip(
194 thunks, order, post_thunk_old_storage
195 ):
--> 196 thunk()
197 for old_s in old_storage:
198 old_s[0] = None
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
654 def thunk(
655 fgraph=self.fgraph,
656 fgraph_jit=fgraph_jit,
657 thunk_inputs=thunk_inputs,
658 thunk_outputs=thunk_outputs,
659 ):
--> 660 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
662 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
663 compute_map[o_var][0] = True
[... skipping hidden 11 frame]
File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:5, in jax_funcified_fgraph(random_generator_shared_variable, a)
3 tensor_variable = shape_tuple_fn(a)
4 # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x29A781FC0>), JAXShapeTuple.0, 11, 0, 1.0)
----> 5 variable, x = sample_fn(random_generator_shared_variable, tensor_variable, tensor_constant, tensor_constant_1, tensor_constant_2)
6 return x, variable
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:106, in jax_funcify_RandomVariable.<locals>.sample_fn(rng, size, dtype, *parameters)
105 def sample_fn(rng, size, dtype, *parameters):
--> 106 return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:164, in jax_sample_fn_loc_scale.<locals>.sample_fn(rng, size, dtype, *parameters)
162 rng_key, sampling_key = jax.random.split(rng_key, 2)
163 loc, scale = parameters
--> 164 sample = loc + jax_op(sampling_key, size, dtype) * scale
165 rng["jax_state"] = rng_key
166 return (rng, sample)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/random.py:710, in normal(key, shape, dtype)
707 raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
708 f"got {dtype}")
709 dtype = dtypes.canonicalize_dtype(dtype)
--> 710 shape = core.as_named_shape(shape)
711 return _normal(key, shape, dtype)
[... skipping hidden 2 frame]
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/core.py:2142, in canonicalize_shape(shape, context)
2140 except TypeError:
2141 pass
-> 2142 raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function jax_funcified_fgraph at /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:1 for jit. This concrete value was not available in Python because it depends on the value of the argument a.
Apply node that caused the error: normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x29A781FC0>), JAXShapeTuple.0, 11, 0, 1.0)
Toposort index: 1
Inputs types: [RandomGeneratorType, TensorType(int64, shape=(1,)), TensorType(int64, shape=()), TensorType(int8, shape=()), TensorType(float32, shape=())]
Inputs shapes: ['No shapes', ()]
Inputs strides: ['No strides', ()]
Inputs values: [{'bit_generator': 1, 'state': {'state': 5504079417979030970, 'inc': 4407794720271215875}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([1281518353, 2620247482], dtype=uint32)}, array(1)]
Outputs clients: [['output'], ['output']]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
coro.send(None)
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_14410/452414321.py", line 2, in <module>
x = pm.Normal('x', dims=['a'])
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/distributions/distribution.py", line 554, in __new__
rv_out = cls.dist(*args, **kwargs)
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/distributions/continuous.py", line 511, in dist
return super().dist([mu, sigma], **kwargs)
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/distributions/distribution.py", line 633, in dist
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
PyMC version information:
pymc: 5.13.1
Context for the issue:
No response