Skip to content

Commit 74efc96

Browse files
aseyboldtricardoV94
authored andcommitted
fix(jax): Specify shape should ignore None axes
1 parent 790b46f commit 74efc96

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

pytensor/link/jax/dispatch/shape.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,11 @@ def shape_i(x):
9696
def jax_funcify_SpecifyShape(op, node, **kwargs):
9797
def specifyshape(x, *shape):
9898
assert x.ndim == len(shape)
99-
assert x.shape == tuple(shape), (
100-
"got shape",
101-
x.shape,
102-
"expected",
103-
shape,
104-
)
99+
for actual, expected in zip(x.shape, shape):
100+
if expected is None:
101+
continue
102+
if actual != expected:
103+
raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}")
105104
return x
106105

107106
return specifyshape

tests/link/jax/test_shape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_jax_shape_ops():
2525

2626
def test_jax_specify_shape():
2727
in_at = at.matrix("in")
28-
x = at.specify_shape(in_at, (4, 5))
28+
x = at.specify_shape(in_at, (4, None))
2929
x_fg = FunctionGraph([in_at], [x])
3030
compare_jax_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])
3131

0 commit comments

Comments
 (0)