Skip to content

Commit d72a289

Browse files
authored
Fix failing JAX Split test (#1646)
1 parent 934306f commit d72a289

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,20 +119,22 @@ def jax_funcify_Split(op: Split, node, **kwargs):
119119
def split(x, axis, splits):
120120
if constant_axis is not None:
121121
axis = constant_axis
122+
if len(splits) != op.len_splits:
123+
raise ValueError("Length of splits is not equal to n_splits")
124+
122125
if constant_splits is not None:
123126
splits = constant_splits
124127
cumsum_splits = np.cumsum(splits[:-1])
128+
if (splits < 0).any():
129+
raise ValueError("Split sizes cannot be negative")
125130
else:
126131
cumsum_splits = jnp.cumsum(splits[:-1])
127132

128-
if len(splits) != op.len_splits:
129-
raise ValueError("Length of splits is not equal to n_splits")
130-
if np.sum(splits) != x.shape[axis]:
131-
raise ValueError(
132-
f"Split sizes do not sum up to input length along axis: {x.shape[axis]}"
133-
)
134-
if np.any(splits < 0):
135-
raise ValueError("Split sizes cannot be negative")
133+
if constant_axis is not None and constant_splits is not None:
134+
if splits.sum() != x.shape[axis]:
135+
raise ValueError(
136+
f"Split sizes do not sum up to input length along axis: {x.shape[axis]}"
137+
)
136138

137139
return jnp.split(x, cumsum_splits, axis=axis)
138140

tests/link/jax/test_tensor_basic.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,16 +182,17 @@ def test_jax_split_not_supported(self):
182182
UserWarning, match="Split node does not have constant split positions."
183183
):
184184
fn = pytensor.function([a], a_splits, mode="JAX")
185-
# It raises an informative ConcretizationTypeError, but there's an AttributeError that surpasses it
186-
with pytest.raises(AttributeError):
185+
# This test used to raise AttributeError in previous versions of JAX.
186+
# Now it raises `TracerIntegerConversionError`.
187+
# We accept both errors for backwards compatibility.
188+
with pytest.raises((AttributeError, errors.TracerIntegerConversionError)):
187189
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))
188190

189191
split_axis = iscalar("split_axis")
190192
a_splits = ptb.split(a, splits_size=[2, 4], n_splits=2, axis=split_axis)
191193
with pytest.warns(UserWarning, match="Split node does not have constant axis."):
192194
fn = pytensor.function([a, split_axis], a_splits, mode="JAX")
193-
# Same as above, an AttributeError surpasses the `TracerIntegerConversionError`
194-
# Both errors are included for backwards compatibility
195+
# Same reasoning as above to accept both errors.
195196
with pytest.raises((AttributeError, errors.TracerIntegerConversionError)):
196197
fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0)
197198

0 commit comments

Comments
 (0)