Skip to content

Commit f0fbe9a

Browse files
committed
Split: Return disconnected gradient for split sizes
1 parent 6e3d33b commit f0fbe9a

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

pytensor/tensor/basic.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2254,18 +2254,19 @@ def infer_shape(self, fgraph, node, in_shapes):
22542254
out_shapes.append(temp)
22552255
return out_shapes
22562256

2257+
def connection_pattern(self, node):
2258+
n_out = len(node.outputs)
2259+
return [
2260+
[True] * n_out,
2261+
[True] * n_out,
2262+
[False] * n_out,
2263+
]
2264+
22572265
def L_op(self, inputs, outputs, g_outputs):
22582266
"""Join the gradients along the axis that was used to split x."""
2259-
_x, axis, n = inputs
2267+
_x, axis, _n = inputs
22602268

2261-
# If all the output gradients are disconnected, then so are the inputs
2262-
if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs):
2263-
return [
2264-
DisconnectedType()(),
2265-
grad_undefined(self, 1, axis),
2266-
grad_undefined(self, 2, n),
2267-
]
2268-
# Else, we have to make them zeros before joining them
2269+
# We have to convert disconnected outputs to zeros before joining them
22692270
new_g_outputs = []
22702271
for o, g in zip(outputs, g_outputs, strict=True):
22712272
if isinstance(g.type, DisconnectedType):
@@ -2276,7 +2277,7 @@ def L_op(self, inputs, outputs, g_outputs):
22762277
return [
22772278
join(axis, *new_g_outputs),
22782279
grad_undefined(self, 1, axis),
2279-
grad_undefined(self, 2, n),
2280+
DisconnectedType()(),
22802281
]
22812282

22822283
def R_op(self, inputs, eval_points):

tests/tensor/test_reshape.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pytensor import config, function
77
from pytensor import tensor as pt
88
from pytensor.graph import rewrite_graph, vectorize_graph
9+
from pytensor.graph.op import io_connection_pattern
910
from pytensor.tensor.reshape import (
1011
_analyze_axes_list,
1112
join_dims,
@@ -289,3 +290,12 @@ def test_pack_unpack_round_trip(self, axes):
289290

290291
for input_val, output_val in zip(input_dict.values(), output_vals, strict=True):
291292
np.testing.assert_allclose(input_val, output_val)
293+
294+
295+
def test_unpack_connection():
296+
x = pt.vector("x")
297+
d0 = pt.scalar("d0", dtype=int)
298+
d1 = pt.scalar("d1", dtype=int)
299+
x0, x1 = pt.unpack(x, axes=None, packed_shapes=[d0, d1])
300+
out = x0.sum() + x1.sum()
301+
assert io_connection_pattern([x, d0, d1], [out]) == [[True], [False], [False]]

0 commit comments

Comments
 (0)