@@ -182,16 +182,17 @@ def test_jax_split_not_supported(self):
182
182
UserWarning , match = "Split node does not have constant split positions."
183
183
):
184
184
fn = pytensor .function ([a ], a_splits , mode = "JAX" )
185
- # It raises an informative ConcretizationTypeError, but there's an AttributeError that surpasses it
186
- with pytest .raises (AttributeError ):
185
+ # This test used to raise AttributeError in previous versions of JAX.
186
+ # Now it raises `TracerIntegerConversionError`.
187
+ # We accept both errors for backwards compatibility.
188
+ with pytest .raises ((AttributeError , errors .TracerIntegerConversionError )):
187
189
fn (np .zeros ((6 , 4 ), dtype = pytensor .config .floatX ))
188
190
189
191
split_axis = iscalar ("split_axis" )
190
192
a_splits = ptb .split (a , splits_size = [2 , 4 ], n_splits = 2 , axis = split_axis )
191
193
with pytest .warns (UserWarning , match = "Split node does not have constant axis." ):
192
194
fn = pytensor .function ([a , split_axis ], a_splits , mode = "JAX" )
193
- # Same as above, an AttributeError surpasses the `TracerIntegerConversionError`
194
- # Both errors are included for backwards compatibility
195
+ # Same reasoning as above to accept both errors.
195
196
with pytest .raises ((AttributeError , errors .TracerIntegerConversionError )):
196
197
fn (np .zeros ((6 , 6 ), dtype = pytensor .config .floatX ), 0 )
197
198
0 commit comments