Skip to content

Commit aceae84

Browse files
[Pallas] Enable skipping of floating-point operations when interpreting Pallas TPU kernels on CPU.
PiperOrigin-RevId: 735527650
1 parent 802cb33 commit aceae84

File tree

3 files changed

+108
-26
lines changed

3 files changed

+108
-26
lines changed

jax/_src/pallas/mosaic/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ py_library(
159159
":core",
160160
":primitives",
161161
"//jax",
162+
"//jax:core",
163+
"//jax:source_info_util",
164+
"//jax:util",
162165
"//jax/_src/lib",
163166
"//jax/_src/pallas",
164167
] + py_deps("numpy"),

jax/_src/pallas/mosaic/interpret.py

Lines changed: 84 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -69,18 +69,24 @@ class TPUInterpretParams:
6969
7070
Attributes:
7171
dma_execution_mode: If "eager", DMAs are executed as soon as they are
72-
issued. If "on_wait", DMA reads or writes are only executed when a
73-
device is waiting on a DMA semaphore that will be signaled when the read
74-
or write is complete.
72+
issued. If "on_wait", DMA reads or writes are only executed when a device
73+
is waiting on a DMA semaphore that will be signaled when the read or write
74+
is complete.
7575
Default: "on_wait".
7676
detect_races: If True, a dynamic, happens-before race detector will be
7777
used to detect data races during kernel interpretation. If any races are
7878
detected, a message will be printed and `races.races_found` will be set
7979
to True.
8080
Default: False.
81+
skip_floating_point_ops: If True, operations that produce only floating
82+
point values will not be interpreted; instead, their results will be
83+
replaced with arrays filled with `jnp.inf`. Additionally, any floating point
84+
operands to any operation will be replaced with (arrays of) `jnp.inf`.
85+
Default: False.
8186
"""
8287
dma_execution_mode: Literal["eager", "on_wait"] = "on_wait"
8388
detect_races: bool = False
89+
skip_floating_point_ops: bool = False
8490

8591

8692
VectorClock = np.ndarray
@@ -954,16 +960,32 @@ def _is_any(memory_space):
954960
return ((memory_space == mosaic_core.TPUMemorySpace.ANY) or
955961
(memory_space == pallas_core.MemorySpace.ANY))
956962

963+
def _is_float(dtype):
964+
return jnp.issubdtype(dtype, jnp.floating)
965+
966+
_SENTINEL = jnp.inf
967+
968+
@dataclasses.dataclass(frozen=True)
969+
class Placeholder:
970+
"""Placeholder for use in `_interpret_jaxpr` below instead of putting a concrete value into `env`."""
971+
shape: tuple[int, ...]
972+
dtype: jnp.dtype
973+
957974
def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):
958975
env = {}
959976

960977
def read(var):
961978
if isinstance(var, jax_core.Literal):
962-
return var.val
979+
result = var.val
963980
else:
964-
return env[var]
981+
result = env[var]
982+
if isinstance(result, Placeholder):
983+
result = jax.lax.full(result.shape, _SENTINEL, result.dtype)
984+
return result
965985

966986
def write(var, value):
987+
if interpret_params.skip_floating_point_ops and _is_float(value.dtype):
988+
value = Placeholder(value.shape, value.dtype)
967989
env[var] = value
968990

969991
jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args)
@@ -987,11 +1009,16 @@ def write(var, value):
9871009
with source_info_util.user_context(
9881010
eqn.source_info.traceback, name_stack=eqn.source_info.name_stack):
9891011
prim = eqn.primitive
990-
invals = jax.util.safe_map(read, eqn.invars)
1012+
# We defer reading the values for `eqn.invars` into each of the branches
1013+
# of the if-elif-else statement below. This is because the else branch may
1014+
# not need to do any reads if `interpret_params.skip_floating_point_ops`
1015+
# is True. If this is the case, we want to avoid materializing the read
1016+
# array into the jaxpr when this function is traced.
1017+
deferred_invals = functools.partial(jax.util.safe_map, read, eqn.invars)
9911018

9921019
if prim is primitives.load_p:
9931020
(ref, transforms, mask, _) = jax.tree.unflatten(
994-
eqn.params['args_tree'], invals)
1021+
eqn.params['args_tree'], deferred_invals())
9951022
if mask is not None:
9961023
raise NotImplementedError('masked load_p')
9971024
out = callback.io_callback(
@@ -1005,7 +1032,7 @@ def write(var, value):
10051032

10061033
elif prim is primitives.swap_p:
10071034
(ref, transforms, val, mask) = jax.tree.unflatten(
1008-
eqn.params['args_tree'], invals)
1035+
eqn.params['args_tree'], deferred_invals())
10091036
out = callback.io_callback(
10101037
functools.partial(swap, source_info=eqn.source_info),
10111038
eqn.outvars[0].aval,
@@ -1023,6 +1050,7 @@ def write(var, value):
10231050
elif prim is lax.cond_p:
10241051
def _make_branch(jaxpr):
10251052
return lambda *args: _interpret(jaxpr, *args)
1053+
invals = deferred_invals()
10261054
out = lax.switch(
10271055
invals[0],
10281056
[_make_branch(branch_jaxpr.jaxpr)
@@ -1031,7 +1059,9 @@ def _make_branch(jaxpr):
10311059

10321060
elif prim is lax.scan_p:
10331061
consts, init_carry, xs = split_list(
1034-
invals, [eqn.params['num_consts'], eqn.params['num_carry']])
1062+
deferred_invals(),
1063+
[eqn.params['num_consts'], eqn.params['num_carry']],
1064+
)
10351065
def _scan_body(c, a):
10361066
return split_list(
10371067
_interpret(eqn.params['jaxpr'].jaxpr, *consts, *c, *a),
@@ -1041,8 +1071,10 @@ def _scan_body(c, a):
10411071
out = carry + out
10421072

10431073
elif prim is lax.while_p:
1044-
cond_consts, body_consts, init_vals = split_list(
1045-
invals, [eqn.params['cond_nconsts'], eqn.params['body_nconsts']])
1074+
cond_consts, body_consts, init_vals = split_list(
1075+
deferred_invals(),
1076+
[eqn.params['cond_nconsts'], eqn.params['body_nconsts']],
1077+
)
10461078
out = lax.while_loop(
10471079
lambda args: _interpret(
10481080
eqn.params['cond_jaxpr'].jaxpr, *cond_consts, *args)[0],
@@ -1056,6 +1088,7 @@ def _scan_body(c, a):
10561088
elif prim is pjit.pjit_p:
10571089
def f(*args, jaxpr):
10581090
return _interpret(jaxpr.jaxpr, *jaxpr.consts, *args)
1091+
invals = deferred_invals()
10591092
in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals)
10601093
new_jaxpr = _to_jaxpr(
10611094
lu.wrap_init(functools.partial(f, jaxpr=eqn.params['jaxpr']),
@@ -1084,7 +1117,7 @@ def f(*args, jaxpr):
10841117
primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
10851118
ordered=True))
10861119

1087-
out = _interpret(eqn.params['jaxpr'], *invals, *allocs)
1120+
out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs)
10881121

10891122
for a in allocs:
10901123
if isinstance(a, tuple):
@@ -1106,6 +1139,7 @@ def f(*args, jaxpr):
11061139
pass
11071140

11081141
elif prim is state_primitives.get_p:
1142+
invals = deferred_invals()
11091143
out = callback.io_callback(
11101144
functools.partial(get, source_info=eqn.source_info),
11111145
eqn.outvars[0].aval,
@@ -1116,6 +1150,7 @@ def f(*args, jaxpr):
11161150
ordered=True)
11171151

11181152
elif prim is state_primitives.swap_p:
1153+
invals = deferred_invals()
11191154
out = callback.io_callback(
11201155
functools.partial(swap, source_info=eqn.source_info),
11211156
eqn.outvars[0].aval,
@@ -1128,11 +1163,17 @@ def f(*args, jaxpr):
11281163
ordered=True)
11291164

11301165
elif prim is mosaic_primitives.dma_start_p:
1131-
(src, src_transforms,
1132-
dst, dst_transforms,
1133-
dst_sem, dst_sem_transforms,
1134-
src_sem, src_sem_transforms,
1135-
target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
1166+
(
1167+
src,
1168+
src_transforms,
1169+
dst,
1170+
dst_transforms,
1171+
dst_sem,
1172+
dst_sem_transforms,
1173+
src_sem,
1174+
src_sem_transforms,
1175+
target_device_id,
1176+
) = jax.tree.unflatten(eqn.params['tree'], deferred_invals())
11361177
target_device_id = _device_id_to_logical(
11371178
target_device_id, eqn.params['device_id_type'], axis_sizes)
11381179
(orig_src_ref, _, orig_dst_ref, *_
@@ -1152,11 +1193,17 @@ def f(*args, jaxpr):
11521193
out = []
11531194

11541195
elif prim is mosaic_primitives.dma_wait_p:
1155-
(src, src_transforms,
1156-
dst, dst_transforms,
1157-
dst_sem, dst_sem_transforms,
1158-
src_sem, src_sem_transforms,
1159-
target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals)
1196+
(
1197+
src,
1198+
src_transforms,
1199+
dst,
1200+
dst_transforms,
1201+
dst_sem,
1202+
dst_sem_transforms,
1203+
src_sem,
1204+
src_sem_transforms,
1205+
target_device_id,
1206+
) = jax.tree.unflatten(eqn.params['tree'], deferred_invals())
11601207
read_shape, read_dtype = _compute_transformed_shape_and_dtype(
11611208
eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms)
11621209
callback.io_callback(
@@ -1178,7 +1225,7 @@ def f(*args, jaxpr):
11781225

11791226
elif prim is mosaic_primitives.semaphore_signal_p:
11801227
sem, sem_transforms, inc, target_device_id, core_index = (
1181-
jax.tree.unflatten(eqn.params['args_tree'], invals))
1228+
jax.tree.unflatten(eqn.params['args_tree'], deferred_invals()))
11821229
target_device_id = _device_id_to_logical(
11831230
target_device_id, eqn.params['device_id_type'], axis_sizes)
11841231
callback.io_callback(
@@ -1194,7 +1241,7 @@ def f(*args, jaxpr):
11941241

11951242
elif prim is mosaic_primitives.semaphore_wait_p:
11961243
sem, sem_transforms, value = (
1197-
jax.tree.unflatten(eqn.params['args_tree'], invals))
1244+
jax.tree.unflatten(eqn.params['args_tree'], deferred_invals()))
11981245
callback.io_callback(
11991246
semaphore_wait,
12001247
(),
@@ -1211,8 +1258,19 @@ def f(*args, jaxpr):
12111258
raise NotImplementedError('atomic_cas_p')
12121259

12131260
else:
1214-
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
1215-
out = prim.bind(*subfuns, *invals, **bind_params)
1261+
if interpret_params.skip_floating_point_ops and all(
1262+
_is_float(ovar.aval.dtype) for ovar in eqn.outvars
1263+
):
1264+
# Skip `prim.bind` since `prim` only produces floating-point values.
1265+
# It is safe to populate `out` with avals since mapping `write` over
1266+
# `out` below only relies on the shape and dtype (for writing
1267+
# `Placeholder`s).
1268+
out = [ovar.aval for ovar in eqn.outvars]
1269+
if not prim.multiple_results:
1270+
out = out[0]
1271+
else:
1272+
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
1273+
out = prim.bind(*subfuns, *deferred_invals(), **bind_params)
12161274

12171275
out = out if prim.multiple_results else [out]
12181276
jax.util.safe_map(write, eqn.outvars, out)

tests/pallas/tpu_pallas_interpret_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,27 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem):
134134
)(x).block_until_ready()
135135
self.assertTrue(mosaic_interpret.races.races_found)
136136

137+
def test_skip_floating_point_ops(self):
138+
def matmul_kernel(x_ref, y_ref, z_ref):
139+
z_ref[...] = x_ref[...] @ y_ref[...]
140+
141+
def matmul(x: jax.Array, y: jax.Array):
142+
return pl.pallas_call(
143+
matmul_kernel,
144+
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
145+
interpret=mosaic_interpret.TPUInterpretParams(
146+
skip_floating_point_ops=True
147+
),
148+
)(x, y)
149+
150+
k1, k2 = jax.random.split(jax.random.key(0))
151+
x = jax.random.normal(k1, (1024, 1024))
152+
y = jax.random.normal(k2, (1024, 1024))
153+
z = jax.jit(matmul)(x, y)
154+
np.testing.assert_array_equal(z, jnp.full_like(z, jnp.inf))
155+
156+
lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo")
157+
self.assertNotIn("dot_general", lowered)
137158

138159
if __name__ == "__main__":
139160
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)