Skip to content

Commit 07e6113

Browse files
author
Ian Schweer
committed
Fix plumbing in blockwise
1 parent c6f1dcb commit 07e6113

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

pytensor/link/pytorch/dispatch/blockwise.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
1010
batched_dims = op.batch_ndim(node)
1111
core_node = op._create_dummy_core_node(node.inputs)
1212
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
13-
inner_func = pytorch_funcify(core_fgraph, squeeze_output=len(node.outputs) == 1)
13+
inner_func = pytorch_funcify(
14+
core_fgraph, squeeze_output=len(node.outputs) == 1, **kwargs
15+
)
1416

1517
for _ in range(batched_dims):
1618
inner_func = torch.vmap(inner_func)

tests/link/pytorch/test_blockwise.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
basic = pytest.importorskip("pytensor.link.pytorch.dispatch.basic")
1313

1414

15-
class TestOp(Op):
15+
class BatchedTestOp(Op):
1616
gufunc_signature = "(m,n),(n,p)->(m,p)"
1717

1818
def __init__(self, final_shape):
@@ -27,7 +27,7 @@ def perform(self, *_):
2727
raise RuntimeError("In perform")
2828

2929

30-
@basic.pytorch_funcify.register(TestOp)
30+
@basic.pytorch_funcify.register(BatchedTestOp)
3131
def evaluate_test_op(op, **_):
3232
def func(a, b):
3333
op.call_shapes.extend(map(torch.Tensor.size, [a, b]))
@@ -42,7 +42,7 @@ def test_blockwise_broadcast():
4242

4343
x = pt.tensor4("x", shape=(5, 1, 2, 3))
4444
y = pt.tensor3("y", shape=(3, 3, 2))
45-
op = TestOp((2, 2))
45+
op = BatchedTestOp((2, 2))
4646
z = Blockwise(op)(x, y)
4747

4848
f = pytensor.function([x, y], z, mode="PYTORCH")

0 commit comments

Comments
 (0)