Skip to content

Commit fa29f18

Browse files
committed
[pallas] align interpreter load/store for masked OOB slicing
1 parent 341e63b commit fa29f18

File tree

3 files changed

+97
-9
lines changed

3 files changed

+97
-9
lines changed

jax/_src/pallas/pallas_call.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from jax._src.interpreters import partial_eval as pe
3737
from jax._src.interpreters import xla
3838
from jax._src.pallas import core as pallas_core
39+
from jax._src.pallas.primitives import uninitialized_value
3940
from jax._src.state import discharge as state_discharge
4041
from jax._src.state import primitives as sp
4142
from jax._src.util import (
@@ -104,17 +105,10 @@ def _pad_values_to_block_dimension(value,
104105
)
105106
if padded_shape != value.shape:
106107
pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
107-
pad_value = _uninitialized_value(shape=(), dtype=value.dtype)
108+
pad_value = uninitialized_value(shape=(), dtype=value.dtype)
108109
value = jnp.pad(value, pad_width, constant_values=pad_value)
109110
return value
110111

111-
def _uninitialized_value(shape, dtype):
112-
if jnp.issubdtype(dtype, jnp.floating):
113-
return jnp.full(shape, jnp.nan, dtype)
114-
elif jnp.issubdtype(dtype, jnp.integer):
115-
return jnp.full(shape, jnp.iinfo(dtype).min, dtype)
116-
raise NotImplementedError(dtype)
117-
118112
def _get_next_indices(grid, indices):
119113
next_indices = []
120114
carry = True
@@ -170,7 +164,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
170164
hasattr(a, "shape") and hasattr(a, "dtype") for a in scratch_avals
171165
):
172166
raise NotImplementedError(f"Cannot initialize scratch: {scratch_avals}")
173-
scratch_values = [_uninitialized_value(a.shape, a.dtype)
167+
scratch_values = [uninitialized_value(a.shape, a.dtype)
174168
for a in scratch_avals]
175169

176170
carry = []

jax/_src/pallas/primitives.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,41 @@ def _load_jvp(primals, tangents, args_tree, **params):
298298

299299
ad.primitive_jvps[load_p] = _load_jvp
300300

301+
def uninitialized_value(shape, dtype):
302+
if jnp.issubdtype(dtype, jnp.floating):
303+
return jnp.full(shape, jnp.nan, dtype)
304+
elif jnp.issubdtype(dtype, jnp.integer):
305+
return jnp.full(shape, jnp.iinfo(dtype).min, dtype)
306+
elif jnp.issubdtype(dtype, jnp.bool):
307+
return jnp.full(shape, False, dtype)
308+
raise NotImplementedError(dtype)
309+
310+
def _pad_values_to_avoid_dynamic_slice_oob_shift(value,
311+
slice_sizes, unpad=False):
312+
"""
313+
DynamicSlice and DynamicUpdateSlice adjust the start index in cases where the
314+
requested slice overruns the bounds of the array. This pads the array with
315+
uninitialised values such that the requested slice will never overrun.
316+
317+
For example, if arr is [1.,2.,3.,4.] and a slice of size 4, start index 2 is
318+
requested then the result will be [3.,4.,NaN,NaN] after padding, rather than
319+
[1.,2.,3.,4.] from the unpadded array.
320+
321+
unpad=True performs the inverse operation, stripping the padding again.
322+
"""
323+
324+
padding_config = tuple((0, slice_size, 0) for slice_size in slice_sizes)
325+
if unpad:
326+
padding_config = tuple((-low, -high, -interior)
327+
for (low, high, interior) in padding_config)
328+
padding_value = uninitialized_value(shape=(), dtype=value.dtype)
329+
value = lax.pad(value,
330+
padding_config=padding_config,
331+
padding_value=padding_value)
332+
return value
333+
334+
_unpad_values_to_avoid_dynamic_slice_oob_shift = partial(
335+
_pad_values_to_avoid_dynamic_slice_oob_shift, unpad=True)
301336

302337
def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
303338
del out_avals # Unused.
@@ -315,6 +350,10 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
315350
scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices]
316351
slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
317352
slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
353+
# fixes an inconsistency with lax.dynamic_slice where if the slice goes out
354+
# of bounds, it will instead move the start_index backwards so the slice
355+
# will fit in memory.
356+
ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes)
318357
out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
319358
out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
320359
out = out_ones[out_indexer]
@@ -424,6 +463,10 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
424463
]
425464
slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
426465
slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
466+
# fixes an inconsistency with lax.dynamic_update_slice where if the slice
467+
# goes out of bounds, it will instead move the start_index backwards so the
468+
# slice will fit in memory.
469+
ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes)
427470
out = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
428471
out = jnp.squeeze(out, scalar_dims)
429472
if mask is not None:
@@ -432,6 +475,7 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
432475
val = jnp.where(mask, val, out_)
433476
val = jnp.expand_dims(val, scalar_dims)
434477
x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts)
478+
x_new = _unpad_values_to_avoid_dynamic_slice_oob_shift(x_new, slice_sizes)
435479
elif all(not isinstance(s, Slice) for s in idx.indices):
436480
out = ref[idx.indices]
437481
if mask is not None:

tests/pallas/pallas_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,6 +1459,27 @@ def kernel(x_ref, o_ref):
14591459
x = random.normal(key, (size,))
14601460
np.testing.assert_allclose(kernel(x), x + 1.0, atol=1e-5, rtol=1e-5)
14611461

1462+
def test_masked_oob_load_store_slice(self):
1463+
n = 16
1464+
1465+
@functools.partial(
1466+
self.pallas_call,
1467+
out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32)),
1468+
grid=1,
1469+
)
1470+
def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref):
1471+
x = pl.load(x_ref, (pl.dslice(start_idx_ref[()], n)),
1472+
mask=mask_ref[:], other=-1.)
1473+
pl.store(o_ref, (pl.dslice(None),), x)
1474+
1475+
x = random.normal(random.key(0), (n,))
1476+
slice_start = random.randint(random.key(2), (), 1, n)
1477+
indices = jnp.arange(n) + slice_start
1478+
mask = indices < n
1479+
out = masked_oob_load_store_slice(x, mask, slice_start)
1480+
o_new = jnp.where(mask, x[indices], jnp.full_like(x, -1.))
1481+
np.testing.assert_array_equal(out, o_new)
1482+
14621483
def test_strided_load(self):
14631484
if self.INTERPRET:
14641485
# TODO(b/329733289): Remove this once the bug is fixed.
@@ -1559,6 +1580,35 @@ def masked_swap(_, _2, mask_ref, x_ref, y_ref):
15591580
np.testing.assert_array_equal(out[0], jnp.where(mask, y, x))
15601581
np.testing.assert_array_equal(out[1], jnp.where(mask, x, y))
15611582

1583+
def test_masked_oob_swap_slice(self):
1584+
m, n = 32, 16
1585+
1586+
@functools.partial(
1587+
self.pallas_call,
1588+
out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32),
1589+
jax.ShapeDtypeStruct((m,), jnp.float32)),
1590+
grid=1,
1591+
input_output_aliases={0: 0, 1: 1},
1592+
)
1593+
def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref):
1594+
x, mask = x_ref[:], mask_ref[:]
1595+
y = pl.swap(y_ref, (pl.dslice(start_idx_ref[()], n)), x, mask=mask)
1596+
x_ref[:] = y
1597+
1598+
x = random.normal(random.key(0), (n,))
1599+
y = random.normal(random.key(1), (m,))
1600+
slice_start = random.randint(random.key(2), (), m-n+1, m)
1601+
indices = jnp.arange(n) + slice_start
1602+
mask = indices < m
1603+
out = masked_oob_swap_slice(x, y, mask, slice_start)
1604+
1605+
# the unjittable masked indexing equivalent
1606+
unmasked_idx = indices[mask]
1607+
x_new = x.at[mask].set(y[unmasked_idx])
1608+
y_new = y.at[unmasked_idx].set(x[mask])
1609+
np.testing.assert_array_equal(out[0], x_new)
1610+
np.testing.assert_array_equal(out[1], y_new)
1611+
15621612
@parameterized.named_parameters(
15631613
("add_i32", pl.atomic_add, np.array([1, 2, 3, 4], np.int32), np.sum),
15641614
("max_i", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max),

0 commit comments

Comments
 (0)