Skip to content

BUG: Forward sampling with dims fails when mode="JAX" #7286

@jessegrabowski

Description

@jessegrabowski

Describe the issue:

Shapes aren't being correctly set on variables when using coords in JAX. I guess this is a consequence of coords being mutable by default, and could be addressed by using freeze_dims_and_data as in #7263. If this is the case, perhaps we should check for the mode='JAX' compile_kwarg in forward samplers and raise early with a more informative error?

Reproducible code example:

import pymc as pm

# Fails
with pm.Model(coords={'a':['1']}) as m:
    x = pm.Normal('x', dims=['a'])
    pm.sample_prior_predictive(compile_kwargs={'mode':'JAX'})

# Works
with pm.Model() as m:
    x = pm.Normal('x', shape=(1,))
    pm.sample_prior_predictive(compile_kwargs={'mode':'JAX'})

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:196, in streamline.<locals>.streamline_default_f()
    193 for thunk, node, old_storage in zip(
    194     thunks, order, post_thunk_old_storage
    195 ):
--> 196     thunk()
    197     for old_s in old_storage:

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    654 def thunk(
    655     fgraph=self.fgraph,
    656     fgraph_jit=fgraph_jit,
    657     thunk_inputs=thunk_inputs,
    658     thunk_outputs=thunk_outputs,
    659 ):
--> 660     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    662     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):

    [... skipping hidden 11 frame]

File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:5, in jax_funcified_fgraph(random_generator_shared_variable, a)
      4 # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x29A781FC0>), JAXShapeTuple.0, 11, 0, 1.0)
----> 5 variable, x = sample_fn(random_generator_shared_variable, tensor_variable, tensor_constant, tensor_constant_1, tensor_constant_2)
      6 return x, variable

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:106, in jax_funcify_RandomVariable.<locals>.sample_fn(rng, size, dtype, *parameters)
    105 def sample_fn(rng, size, dtype, *parameters):
--> 106     return jax_sample_fn(op)(rng, size, out_dtype, *parameters)

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:164, in jax_sample_fn_loc_scale.<locals>.sample_fn(rng, size, dtype, *parameters)
    163 loc, scale = parameters
--> 164 sample = loc + jax_op(sampling_key, size, dtype) * scale
    165 rng["jax_state"] = rng_key

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/random.py:710, in normal(key, shape, dtype)
    709 dtype = dtypes.canonicalize_dtype(dtype)
--> 710 shape = core.as_named_shape(shape)
    711 return _normal(key, shape, dtype)

    [... skipping hidden 2 frame]

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/core.py:2142, in canonicalize_shape(shape, context)
   2141   pass
-> 2142 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function jax_funcified_fgraph at /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:1 for jit. This concrete value was not available in Python because it depends on the value of the argument a.

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[19], line 1
----> 1 pm.draw(x, mode='JAX')

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/sampling/forward.py:314, in draw(vars, draws, random_seed, **kwargs)
    311 draw_fn = compile_pymc(inputs=[], outputs=vars, random_seed=random_seed, **kwargs)
    313 if draws == 1:
--> 314     return draw_fn()
    316 # Single variable output
    317 if not isinstance(vars, list | tuple):

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    967 t0_fn = time.perf_counter()
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:
    975     restore_defaults()

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:200, in streamline.<locals>.streamline_default_f()
    198             old_s[0] = None
    199 except Exception:
--> 200     raise_with_op(fgraph, node, thunk)

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:523, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    518     warnings.warn(
    519         f"{exc_type} error does not allow us to add an extra error message"
    520     )
    521     # Some exception need extra parameter in inputs. So forget the
    522     # extra long error message in that case.
--> 523 raise exc_value.with_traceback(exc_trace)

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:196, in streamline.<locals>.streamline_default_f()
    192 try:
    193     for thunk, node, old_storage in zip(
    194         thunks, order, post_thunk_old_storage
    195     ):
--> 196         thunk()
    197         for old_s in old_storage:
    198             old_s[0] = None

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    654 def thunk(
    655     fgraph=self.fgraph,
    656     fgraph_jit=fgraph_jit,
    657     thunk_inputs=thunk_inputs,
    658     thunk_outputs=thunk_outputs,
    659 ):
--> 660     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    662     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
    663         compute_map[o_var][0] = True

    [... skipping hidden 11 frame]

File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:5, in jax_funcified_fgraph(random_generator_shared_variable, a)
      3 tensor_variable = shape_tuple_fn(a)
      4 # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x29A781FC0>), JAXShapeTuple.0, 11, 0, 1.0)
----> 5 variable, x = sample_fn(random_generator_shared_variable, tensor_variable, tensor_constant, tensor_constant_1, tensor_constant_2)
      6 return x, variable

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:106, in jax_funcify_RandomVariable.<locals>.sample_fn(rng, size, dtype, *parameters)
    105 def sample_fn(rng, size, dtype, *parameters):
--> 106     return jax_sample_fn(op)(rng, size, out_dtype, *parameters)

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:164, in jax_sample_fn_loc_scale.<locals>.sample_fn(rng, size, dtype, *parameters)
    162 rng_key, sampling_key = jax.random.split(rng_key, 2)
    163 loc, scale = parameters
--> 164 sample = loc + jax_op(sampling_key, size, dtype) * scale
    165 rng["jax_state"] = rng_key
    166 return (rng, sample)

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/random.py:710, in normal(key, shape, dtype)
    707   raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
    708                    f"got {dtype}")
    709 dtype = dtypes.canonicalize_dtype(dtype)
--> 710 shape = core.as_named_shape(shape)
    711 return _normal(key, shape, dtype)

    [... skipping hidden 2 frame]

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/core.py:2142, in canonicalize_shape(shape, context)
   2140 except TypeError:
   2141   pass
-> 2142 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function jax_funcified_fgraph at /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:1 for jit. This concrete value was not available in Python because it depends on the value of the argument a.
Apply node that caused the error: normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x29A781FC0>), JAXShapeTuple.0, 11, 0, 1.0)
Toposort index: 1
Inputs types: [RandomGeneratorType, TensorType(int64, shape=(1,)), TensorType(int64, shape=()), TensorType(int8, shape=()), TensorType(float32, shape=())]
Inputs shapes: ['No shapes', ()]
Inputs strides: ['No strides', ()]
Inputs values: [{'bit_generator': 1, 'state': {'state': 5504079417979030970, 'inc': 4407794720271215875}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([1281518353, 2620247482], dtype=uint32)}, array(1)]
Outputs clients: [['output'], ['output']]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_14410/452414321.py", line 2, in <module>
    x = pm.Normal('x', dims=['a'])
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/distributions/distribution.py", line 554, in __new__
    rv_out = cls.dist(*args, **kwargs)
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/distributions/continuous.py", line 511, in dist
    return super().dist([mu, sigma], **kwargs)
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/distributions/distribution.py", line 633, in dist
    rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

PyMC version information:

pymc: 5.13.1

Context for the issue:

No response

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions