Skip to content

Commit 3fbe9a9

Browse files
juanitorduzricardoV94
authored andcommitted
Remove SpecifyShape from Assert JAX
1 parent 0d75490 commit 3fbe9a9

File tree

1 file changed

+0
-2
lines changed

1 file changed

+0
-2
lines changed

pymc/sampling/jax.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from pytensor.raise_op import Assert
3535
from pytensor.tensor import TensorVariable
3636
from pytensor.tensor.random.type import RandomType
37-
from pytensor.tensor.shape import SpecifyShape
3837

3938
from pymc import Model, modelcontext
4039
from pymc.backends.arviz import find_constants, find_observations
@@ -62,7 +61,6 @@
6261

6362
@jax_funcify.register(Assert)
6463
@jax_funcify.register(CheckParameterValue)
65-
@jax_funcify.register(SpecifyShape)
6664
def jax_funcify_Assert(op, **kwargs):
6765
# Jax does not allow assert whose values aren't known during JIT compilation
6866
# within it's JIT-ed code. Hence we need to make a simple pass through

0 commit comments

Comments
 (0)