Skip to content

Commit 7cef064

Browse files
use continue on rewrite failures when checking clients
1 parent 17fbeb3 commit 7cef064

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,13 +193,13 @@ def check_for_block_diag(x):
193193
# Check that the BlockDiagonal is an input to a Dot node:
194194
for client in get_clients_at_depth(fgraph, node, depth=1):
195195
if not isinstance(client.op, Dot):
196-
return
196+
continue
197197

198198
op = client.op
199199
x, y = client.inputs
200200

201201
if not (check_for_block_diag(x) or check_for_block_diag(y)):
202-
return None
202+
continue
203203

204204
# Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
205205
# non-block diagonal, and return a new block diagonal
@@ -235,7 +235,7 @@ def check_for_block_diag(x):
235235
else:
236236
# TODO: If shapes are statically known and all components have equal shapes, we could rewrite
237237
# this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
238-
return None
238+
continue
239239

240240
copy_stack_trace(node.outputs[0], new_output)
241241
return {client.outputs[0]: new_output}

0 commit comments

Comments
 (0)