We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0d75490 commit 3fbe9a9Copy full SHA for 3fbe9a9
pymc/sampling/jax.py
@@ -34,7 +34,6 @@
34
from pytensor.raise_op import Assert
35
from pytensor.tensor import TensorVariable
36
from pytensor.tensor.random.type import RandomType
37
-from pytensor.tensor.shape import SpecifyShape
38
39
from pymc import Model, modelcontext
40
from pymc.backends.arviz import find_constants, find_observations
@@ -62,7 +61,6 @@
62
61
63
@jax_funcify.register(Assert)
64
@jax_funcify.register(CheckParameterValue)
65
-@jax_funcify.register(SpecifyShape)
66
def jax_funcify_Assert(op, **kwargs):
67
# Jax does not allow assert whose values aren't known during JIT compilation
68
# within it's JIT-ed code. Hence we need to make a simple pass through
0 commit comments