Skip to content

Commit 0b2dcbe

Browse files
committed
Add error message in Numba implementation of SpecifyShape
1 parent 12afbf2 commit 0b2dcbe

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -545,11 +545,11 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
545545
shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
546546

547547
func_conditions = [
548-
f"assert x.shape[{i}] == {shape_input_names}"
549-
for i, (shape_input, shape_input_names) in enumerate(
548+
f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'"
549+
for i, (node_dim_input, eval_dim_name) in enumerate(
550550
zip(shape_inputs, shape_input_names, strict=True)
551551
)
552-
if shape_input is not NoneConst
552+
if node_dim_input is not NoneConst
553553
]
554554

555555
func = dedent(

0 commit comments

Comments
 (0)