Commit e9ce8fb

Merge pull request jax-ml#27227 from jburnim:jburnim_pallas_interpret_mode4
PiperOrigin-RevId: 738235363
2 parents: f3b7c5c + 47e8eff

File tree

2 files changed: +65 -10 lines changed

jax/_src/pallas/mosaic/interpret.py
tests/pallas/tpu_pallas_interpret_test.py


jax/_src/pallas/mosaic/interpret.py

Lines changed: 34 additions & 10 deletions
@@ -83,10 +83,15 @@ class TPUInterpretParams:
       replaced with arrays all of `jnp.inf`. Additionally any floating point
       operands to any operation will be replaced with (arrays of) `jnp.inf`.
       Default: False.
+    uninitialized_memory: If "nan", allocated buffers are initialized to
+      contain all NaNs (or to their maximum possible value for integers).
+      If "zero", allocated buffers are initialized to all zeros.
+      Default: "nan".
   """
   dma_execution_mode: Literal["eager", "on_wait"] = "on_wait"
   detect_races: bool = False
   skip_floating_point_ops: bool = False
+  uninitialized_memory: Literal["nan", "zero"] = "nan"
 
 
 VectorClock = np.ndarray
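For context, a minimal usage sketch of the new option (not part of the diff; the import path and pallas_call pattern follow the test added below, and copy_kernel is a hypothetical kernel):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax._src.pallas.mosaic import interpret as mosaic_interpret

# Hypothetical kernel: copies its input. Only the interpret= argument
# below exercises the new uninitialized_memory setting.
def copy_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...]

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
out = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    interpret=mosaic_interpret.TPUInterpretParams(
        uninitialized_memory="zero"),  # fill fresh buffers with zeros
)(x)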
@@ -1114,7 +1119,8 @@ def f(*args, jaxpr):
           jax.ShapeDtypeStruct((), jnp.int16),
           device_id,
           TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
-          primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
+          _uninitialized_value(
+              v.aval.shape, v.aval.dtype, interpret_params),
           ordered=True))
 
       out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs)
@@ -1279,16 +1285,19 @@ def f(*args, jaxpr):
 
 def _initialize_output_vals(
     block_mappings_output: Iterable[BlockMapping],
-    input_args, input_output_aliases) -> Sequence[jax.Array]:
+    input_args, input_output_aliases,
+    interpret_params: TPUInterpretParams,
+) -> Sequence[jax.Array]:
   oi_map = {v: k for k, v in input_output_aliases}
   output_vals = []
   for i, bm in enumerate(block_mappings_output):
     if i in oi_map:
       output_vals.append(input_args[oi_map[i]])
     else:
-      output_vals.append(primitives.uninitialized_value(
+      output_vals.append(_uninitialized_value(
           bm.array_shape_dtype.shape,
-          bm.array_shape_dtype.dtype))
+          bm.array_shape_dtype.dtype,
+          interpret_params))
   return output_vals
 
 def _compute_start_indices(block_mapping, loop_idx, *args):
@@ -1319,7 +1328,20 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
                             dtype=np.bool_)])
   return lax.squeeze(output, squeeze_dims)
 
-def _pad_to_block_dimension(value, block_shape):
+def _uninitialized_value(shape, dtype, interpret_params):
+  if interpret_params.uninitialized_memory == 'nan':
+    if jnp.issubdtype(dtype, jnp.floating):
+      return jnp.full(shape, jnp.nan, dtype)
+    elif jnp.issubdtype(dtype, jnp.integer):
+      return jnp.full(shape, jnp.iinfo(dtype).max, dtype)
+    elif jnp.issubdtype(dtype, jnp.bool):
+      return jnp.full(shape, False, dtype)
+  if interpret_params.uninitialized_memory == 'zero':
+    return jnp.full(shape, 0, dtype)
+  raise NotImplementedError(
+      interpret_params.uninitialized_memory + ' + ' + str(dtype))
+
+def _pad_to_block_dimension(value, block_shape, interpret_params):
   """Pads values so the shape evenly divides into block dimensions.
 
   For example, if values has a shape of (33, 2, 5) with a block_shape of
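As a quick illustration (not part of the diff), these are the fill values the new helper produces, assuming the logic in the hunk above:

import jax.numpy as jnp

# Under "nan" mode: floats are filled with NaN, integers with their maximum
# representable value, booleans with False. Under "zero" mode everything is
# filled with zeros.
print(jnp.full((2,), jnp.nan, jnp.float32))                 # [nan nan]
print(jnp.full((2,), jnp.iinfo(jnp.int16).max, jnp.int16))  # [32767 32767]
print(jnp.full((2,), False, jnp.bool_))                     # [False False]
print(jnp.full((2,), 0, jnp.float32))                       # [0. 0.]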
@@ -1338,7 +1360,7 @@ def _pad_to_block_dimension(value, block_shape):
   )
   if padded_shape != value.shape:
     pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
-    pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype)
+    pad_value = _uninitialized_value((), value.dtype, interpret_params)
     value = jnp.pad(value, pad_width, constant_values=pad_value)
   return value
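A rough standalone sketch of the rounding rule _pad_to_block_dimension applies (the block_shape of (32, 2, 2) is an assumption for illustration, since the diff truncates the docstring's example):

import jax.numpy as jnp

# Each dimension is rounded up to the next multiple of its block size, and
# the padding is filled with the same "uninitialized" value as fresh buffers
# (NaN here, for a float dtype under the default "nan" mode).
value = jnp.ones((33, 2, 5), jnp.float32)
block_shape = (32, 2, 2)  # assumed for illustration
padded_shape = tuple(((v + b - 1) // b) * b
                     for v, b in zip(value.shape, block_shape))
pad_width = tuple((0, p - v) for p, v in zip(padded_shape, value.shape))
padded = jnp.pad(value, pad_width, constant_values=jnp.nan)
print(padded.shape)  # (64, 2, 6)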

@@ -1397,7 +1419,7 @@ def interpret_pallas_call(
   ]
   num_inputs = grid_mapping.num_inputs
   input_args = [
-      _pad_to_block_dimension(a, bs)
+      _pad_to_block_dimension(a, bs, interpret_params)
       for a, bs in zip(input_args, block_shapes[:num_inputs])
   ]

@@ -1407,11 +1429,12 @@ def interpret_pallas_call(
   output_vals = _initialize_output_vals(
       grid_mapping.block_mappings_output,
       scalars + input_args,
-      input_output_aliases)
+      input_output_aliases,
+      interpret_params)
   num_outputs = grid_mapping.num_outputs
   output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs]
   for out_val, bs in zip(output_vals, output_block_shapes):
-    padded_val = _pad_to_block_dimension(out_val, bs)
+    padded_val = _pad_to_block_dimension(out_val, bs, interpret_params)
     output_buffer_shapes.append(padded_val.shape)
     output_buffer_ids.append(callback.io_callback(
         _allocate_buffer,
@@ -1466,7 +1489,8 @@ def interpret_pallas_call(
           jax.ShapeDtypeStruct((), jnp.int16),
           device_id,
           TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
-          primitives.uninitialized_value(var.aval.shape, var.aval.dtype),
+          _uninitialized_value(
+              var.aval.shape, var.aval.dtype, interpret_params),
           ordered=True))
 
   _, input_ids, kernel_output_ids, _ = split_list(

tests/pallas/tpu_pallas_interpret_test.py

Lines changed: 31 additions & 0 deletions
@@ -156,5 +156,36 @@ def matmul(x: jax.Array, y: jax.Array):
   lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo")
   self.assertNotIn("dot_general", lowered)
 
+  @parameterized.parameters('nan', 'zero')
+  def test_uninitialized_memory(self, uninitialized_memory):
+    def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref):
+      o1_ref[...] = t1_ref[...]
+      o2_ref[...] = t2_ref[...]
+
+    x, y, z = pl.pallas_call(
+        kernel,
+        out_shape=[
+            jax.ShapeDtypeStruct((8, 128), jnp.bfloat16),
+            jax.ShapeDtypeStruct((8, 128), jnp.int16),
+            jax.ShapeDtypeStruct((8, 128), jnp.float32),
+        ],
+        in_specs=[],
+        scratch_shapes=[
+            pltpu.VMEM((8, 128), jnp.bfloat16),
+            pltpu.VMEM((8, 128), jnp.int16),
+        ],
+        interpret=mosaic_interpret.TPUInterpretParams(
+            uninitialized_memory=uninitialized_memory),
+    )()
+    if uninitialized_memory == 'nan':
+      self.assertTrue(jnp.isnan(x).all())
+      np.testing.assert_equal(np.array(y), 32767)
+      self.assertTrue(jnp.isnan(z).all())
+    if uninitialized_memory == 'zero':
+      np.testing.assert_equal(np.array(x), 0)
+      np.testing.assert_equal(np.array(y), 0)
+      np.testing.assert_equal(np.array(z), 0)
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
