File tree Expand file tree Collapse file tree 2 files changed +6
-7
lines changed
pytensor/link/jax/dispatch Expand file tree Collapse file tree 2 files changed +6
-7
lines changed Original file line number Diff line number Diff line change @@ -96,12 +96,11 @@ def shape_i(x):
96
96
def jax_funcify_SpecifyShape (op , node , ** kwargs ):
97
97
def specifyshape (x , * shape ):
98
98
assert x .ndim == len (shape )
99
- assert x .shape == tuple (shape ), (
100
- "got shape" ,
101
- x .shape ,
102
- "expected" ,
103
- shape ,
104
- )
99
+ for actual , expected in zip (x .shape , shape ):
100
+ if expected is None :
101
+ continue
102
+ if actual != expected :
103
+ raise ValueError (f"Invalid shape: Expected { shape } but got { x .shape } " )
105
104
return x
106
105
107
106
return specifyshape
Original file line number Diff line number Diff line change @@ -25,7 +25,7 @@ def test_jax_shape_ops():
25
25
26
26
def test_jax_specify_shape ():
27
27
in_at = at .matrix ("in" )
28
- x = at .specify_shape (in_at , (4 , 5 ))
28
+ x = at .specify_shape (in_at , (4 , None ))
29
29
x_fg = FunctionGraph ([in_at ], [x ])
30
30
compare_jax_and_py (x_fg , [np .ones ((4 , 5 )).astype (config .floatX )])
31
31
You can’t perform that action at this time.
0 commit comments