@@ -298,6 +298,41 @@ def _load_jvp(primals, tangents, args_tree, **params):
298
298
299
299
ad .primitive_jvps [load_p ] = _load_jvp
300
300
301
def uninitialized_value(shape, dtype):
  """Return an array of `shape`/`dtype` filled with an "uninitialized" marker.

  Floating dtypes are filled with NaN, integer dtypes with their minimum
  representable value, and booleans with False, so that any accidental read
  of padding is as conspicuous as possible.

  Args:
    shape: shape of the array to create (may be `()` for a scalar).
    dtype: a jax/numpy dtype (floating, integer, or bool).

  Raises:
    NotImplementedError: for dtypes with no obvious fill value (e.g. complex).
  """
  if jnp.issubdtype(dtype, jnp.floating):
    return jnp.full(shape, jnp.nan, dtype)
  elif jnp.issubdtype(dtype, jnp.integer):
    return jnp.full(shape, jnp.iinfo(dtype).min, dtype)
  elif jnp.issubdtype(dtype, jnp.bool_):
    # Fix: `jnp.bool` is not present in many jax releases (AttributeError);
    # `jnp.bool_` is the canonical bool dtype, mirroring numpy's `np.bool_`.
    return jnp.full(shape, False, dtype)
  raise NotImplementedError(dtype)
309
+
310
def _pad_values_to_avoid_dynamic_slice_oob_shift(value,
                                                 slice_sizes, unpad=False):
  """Pad `value` so a dynamic slice of `slice_sizes` can never overrun it.

  lax.dynamic_slice and lax.dynamic_update_slice clamp the start index
  whenever the requested slice would run past the end of the array. Appending
  `slice_sizes` worth of uninitialized values on the high side of every axis
  guarantees that any requested slice fits, so no clamping ever happens.

  For example, if arr is [1.,2.,3.,4.] and a slice of size 4 with start
  index 2 is requested, the result after padding is [3.,4.,NaN,NaN] rather
  than the [1.,2.,3.,4.] the clamped, unpadded slice would produce.

  unpad=True performs the inverse operation (strips the padding again).
  """
  # One (low, high, interior) triple per axis; negating it undoes the pad.
  sign = -1 if unpad else 1
  padding_config = tuple(
      (0, sign * slice_size, 0) for slice_size in slice_sizes)
  fill = uninitialized_value(shape=(), dtype=value.dtype)
  return lax.pad(value, padding_value=fill, padding_config=padding_config)
333
+
334
def _unpad_values_to_avoid_dynamic_slice_oob_shift(value, slice_sizes):
  """Strip the padding added by `_pad_values_to_avoid_dynamic_slice_oob_shift`."""
  return _pad_values_to_avoid_dynamic_slice_oob_shift(
      value, slice_sizes, unpad=True)
301
336
302
337
def _load_discharge_rule (in_avals , out_avals , * args_flat , args_tree , ** _ ):
303
338
del out_avals # Unused.
@@ -315,6 +350,10 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
315
350
scalar_dims = [not isinstance (s , Slice ) and not s .shape for s in indices ]
316
351
slice_starts = [s .start if isinstance (s , Slice ) else s for s in indices ]
317
352
slice_sizes = tuple (s .size if isinstance (s , Slice ) else 1 for s in indices )
353
+ # fixes an inconsistency with lax.dynamic_slice where if the slice goes out
354
+ # of bounds, it will instead move the start_index backwards so the slice
355
+ # will fit in memory.
356
+ ref = _pad_values_to_avoid_dynamic_slice_oob_shift (ref , slice_sizes )
318
357
out_ones = lax .dynamic_slice (ref , slice_starts , slice_sizes = slice_sizes )
319
358
out_indexer = tuple (0 if scalar else slice (None ) for scalar in scalar_dims )
320
359
out = out_ones [out_indexer ]
@@ -424,6 +463,10 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
424
463
]
425
464
slice_starts = [s .start if isinstance (s , Slice ) else s for s in indices ]
426
465
slice_sizes = tuple (s .size if isinstance (s , Slice ) else 1 for s in indices )
466
+ # fixes an inconsistency with lax.dynamic_update_slice where if the slice
467
+ # goes out of bounds, it will instead move the start_index backwards so the
468
+ # slice will fit in memory.
469
+ ref = _pad_values_to_avoid_dynamic_slice_oob_shift (ref , slice_sizes )
427
470
out = lax .dynamic_slice (ref , slice_starts , slice_sizes = slice_sizes )
428
471
out = jnp .squeeze (out , scalar_dims )
429
472
if mask is not None :
@@ -432,6 +475,7 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
432
475
val = jnp .where (mask , val , out_ )
433
476
val = jnp .expand_dims (val , scalar_dims )
434
477
x_new = lax .dynamic_update_slice (ref , val , start_indices = slice_starts )
478
+ x_new = _unpad_values_to_avoid_dynamic_slice_oob_shift (x_new , slice_sizes )
435
479
elif all (not isinstance (s , Slice ) for s in idx .indices ):
436
480
out = ref [idx .indices ]
437
481
if mask is not None :
0 commit comments