Commit d5ead57

apaszke authored and Google-ML-Automation committed
[Mosaic TPU] Add support for modeling loads/stores and fix minor issues in model extraction
PiperOrigin-RevId: 703102072
1 parent 569c2a3 commit d5ead57

File tree: 2 files changed (+125, -2 lines)

  jax/_src/pallas/mosaic/verification.py
  tests/pallas/tpu_pallas_distributed_test.py

jax/_src/pallas/mosaic/verification.py

Lines changed: 66 additions & 2 deletions
@@ -145,6 +145,15 @@ def block(self, begin: str, end: str):
     self.level -= 1
     self.locals.append(self._indent(end) + "\n")
 
+  @contextlib.contextmanager
+  def comment_if_emitted(self, comment):
+    self.comment(comment)
+    yield
+    self.comment(comment)
+    if self.locals[-1] == self.locals[-2]:
+      self.locals.pop()
+      self.locals.pop()
+
   def get(self, value: ir.Value, default: Any = _UNSPECIFIED):
     if default is _UNSPECIFIED:
       return self.env[value]
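Note: the comment_if_emitted helper added above brackets whatever gets emitted under it with a pair of identical comments, and drops both markers when nothing was emitted in between. A minimal standalone sketch of that behavior (a toy emitter with the same list-of-lines buffer, not the real verification printer class):

import contextlib

class _TinyEmitter:
  def __init__(self):
    self.locals = []  # buffered output lines, like the printer's `locals`

  def comment(self, text):
    self.locals.append(f"// {text}\n")

  def emit(self, text):
    self.locals.append(text + "\n")

  @contextlib.contextmanager
  def comment_if_emitted(self, comment):
    self.comment(comment)
    yield
    self.comment(comment)
    # If the last two buffered lines are the same comment, nothing was emitted
    # in between, so both markers are redundant and get dropped.
    if self.locals[-1] == self.locals[-2]:
      self.locals.pop()
      self.locals.pop()

e = _TinyEmitter()
with e.comment_if_emitted("arith.addi"):
  pass                          # unmodeled op: no output, both comments vanish
with e.comment_if_emitted("arith.cmpi"):
  e.emit("bool t0 = (a == b)")
print("".join(e.locals))        # only the arith.cmpi block is bracketed by comments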
@@ -358,6 +367,17 @@ def _print_op(ctx, op):
       return bin_op(ctx, "int", "%", *op.operands)
     case "arith.divsi":
       return bin_op(ctx, "int", "/", *op.operands)
+    case "arith.andi":
+      return bin_op(ctx, _model_type(op.result.type), "&", *op.operands)
+    case "arith.select":
+      cond, if_true, if_false = map(lambda o: ctx.get(o, None), op.operands)
+      if cond is None or if_true is None or if_false is None:
+        return NotImplemented
+      result_ty = _model_type(op.result.type)
+      return ctx.emit(result_ty, f"({cond} -> {if_true} : {if_false})")
+    case "arith.index_cast":
+      model = ctx.get(op.operands[0], None)
+      return ctx.emit("int", model) if model is not None else NotImplemented
     case "arith.cmpi":
       match op.predicate.value:
         case arith.CmpIPredicate.eq:
@@ -386,12 +406,44 @@ def _print_op(ctx, op):
         read_refs.append(model)
       with ctx.block("d_step {", "}"):  # Start reading
         for r in read_refs:
+          for loc in r.written_at(None):
+            ctx.emit(None, f"assert(!{loc})")
           for loc in r.readers_at(None):
             ctx.emit(None, f"{loc}++")
       with ctx.block("d_step {", "}"):  # Stop reading
         for r in read_refs:
           for loc in r.readers_at(None):
             ctx.emit(None, f"{loc}--")
+    case "vector.load":
+      ref = ctx.get(op.operands[0])
+      assert isinstance(ref, GlobalRefModel)
+      if (first_idx := ctx.get(op.operands[1], None)) is not None:
+        leading_load_len = ir.VectorType(op.result.type).shape[0]
+        ref = GlobalRefModel(f"{ref.base} + {first_idx}", leading_load_len)
+      with ctx.block("d_step {", "}"):  # Start reading
+        for loc in ref.written_at(None):
+          ctx.emit(None, f"assert(!{loc})")
+        for loc in ref.readers_at(None):
+          ctx.emit(None, f"{loc}++")
+      with ctx.block("d_step {", "}"):  # Stop reading
+        for loc in ref.readers_at(None):
+          ctx.emit(None, f"{loc}--")
+      return NotImplemented  # We don't model the result of the load.
+    case "vector.store":
+      ref = ctx.get(op.operands[1])  # Stored value goes first
+      assert isinstance(ref, GlobalRefModel)
+      if (first_idx := ctx.get(op.operands[2], None)) is not None:
+        leading_store_len = ir.VectorType(op.operands[0].type).shape[0]
+        ref = GlobalRefModel(f"{ref.base} + {first_idx}", leading_store_len)
+      with ctx.block("d_step {", "}"):  # Start writing
+        for loc in ref.readers_at(None):
+          ctx.emit(None, f"assert(!{loc})")
+        for loc in ref.written_at(None):
+          ctx.emit(None, f"assert(!{loc})")
+          ctx.emit(None, f"{loc} = 1")
+      with ctx.block("d_step {", "}"):  # Stop reading
+        for loc in ref.written_at(None):
+          ctx.emit(None, f"{loc} = 0")
     case "scf.for":
       carrys = [
           ctx.emit("int", ctx.get(arg))
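The vector.load and vector.store cases above encode a simple reader/writer protocol over each modeled location: a load asserts the location is not being written and bumps a reader counter for the duration of the access, while a store asserts there are neither readers nor another writer and raises a written flag. A toy Python rendering of that invariant (this illustrates what the generated d_step blocks check, it is not the emitted Promela):

class _Location:
  # One modeled memory location: a reader count plus a written flag.
  def __init__(self):
    self.readers = 0
    self.written = False

  # vector.load: the "Start reading" / "Stop reading" d_step pair.
  def start_read(self):
    assert not self.written, "read overlaps an in-flight write"
    self.readers += 1

  def stop_read(self):
    self.readers -= 1

  # vector.store: the "Start writing" d_step and the final one clearing the flag.
  def start_write(self):
    assert self.readers == 0, "write overlaps in-flight reads"
    assert not self.written, "write overlaps another write"
    self.written = True

  def stop_write(self):
    self.written = False

loc = _Location()
loc.start_read(); loc.stop_read()   # a well-formed load
loc.start_write()                   # a store begins...
try:
  loc.start_read()                  # ...so an overlapping load is a race
except AssertionError as e:
  print("race detected:", e)
loc.stop_write()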
@@ -419,6 +471,7 @@ def _print_op(ctx, op):
           ctx.emit(None, f"{c} = {ctx.get(new)}")
         ctx.emit(None, f"{induction_var} = {induction_var} + {step}")
       ctx.emit(None, ":: else -> break")
+      ctx.emit(None, "skip")  # To avoid "Jump into d_step sequence" errors
       if len(carrys) == 1:
         return carrys[0]
       else:
@@ -450,16 +503,27 @@ def bin_op(ctx, result_ty, op, lhs, rhs):
   return ctx.emit(result_ty, f"{lhs} {op} {rhs}")
 
 
+def _model_type(ty):
+  if ir.IntegerType.isinstance(ty):
+    if ir.IntegerType(ty).width == 1:
+      return "bool"
+    else:
+      return "int"
+  else:
+    raise NotImplementedError(ty)
+
+
 def _print_block(ctx, block):
   for op in block:
     try:
-      results = _print_op(ctx, op)
+      with ctx.comment_if_emitted(op.OPERATION_NAME):
+        results = _print_op(ctx, op)
     except Exception as e:
       raise RuntimeError(f"Failed to print op: {op}") from e
     if results is NotImplemented:
       continue
     if not op.results:
-      assert results is None
+      assert results is None or results == ()
     elif len(op.results) > 1:
       raise NotImplementedError(op)
     else:
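The new _model_type helper maps MLIR integer types onto the two scalar types the model uses: i1 becomes bool, any other integer width becomes int, and everything else is rejected. A small sketch of that mapping, assuming this commit is applied and that the in-tree MLIR bindings are importable as jax._src.lib.mlir.ir (the import paths are private and may differ across JAX versions):

from jax._src.lib.mlir import ir
from jax._src.pallas.mosaic import verification

with ir.Context():
  print(verification._model_type(ir.IntegerType.get_signless(1)))   # "bool"
  print(verification._model_type(ir.IntegerType.get_signless(32)))  # "int"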

tests/pallas/tpu_pallas_distributed_test.py

Lines changed: 59 additions & 0 deletions
@@ -15,6 +15,8 @@
 """Tests for distributed pallas TPU operations."""
 
 import functools
+import os
+import tempfile
 from absl.testing import absltest
 from absl.testing import parameterized
 import jax
@@ -513,5 +515,62 @@ def _():
         atol=1e-5,
         rtol=1e-3)
 
+
+class VerificationTest(jtu.JaxTestCase):
+
+  def test_verification(self):
+    if (num_devices := jax.local_device_count()) <= 1:
+      self.skipTest('Test requires multiple devices.')
+    if not jtu.is_device_tpu_at_least(4) or jax.devices()[0].num_cores > 1:
+      self.skipTest('Test requires a new single-core TPU.')
+    def kernel_body(in_ref, out_ref, scratch_ref, send_sem, recv_sem, capacity_sem):
+      my_id = lax.axis_index('x')
+      dst_id = jnp.where(my_id == num_devices - 1, 0, my_id + 1)
+      src_id = jnp.where(my_id == 0, num_devices - 1, my_id - 1)
+      pltpu.semaphore_signal(capacity_sem, 1, device_id=src_id)
+      out_ref[...] = jnp.zeros_like(out_ref)
+      scratch_ref[0] = in_ref[0]
+
+      @functools.partial(lax.fori_loop, 0, num_devices - 1, init_val=None)
+      def _(i, _):
+        slot = i % 2
+        next_slot = 1 - slot
+        pltpu.semaphore_wait(capacity_sem, 1)
+        copy = pltpu.async_remote_copy(
+            scratch_ref.at[slot],
+            scratch_ref.at[next_slot],
+            send_sem,
+            recv_sem,
+            device_id=dst_id,
+        )
+        out_ref[...] += scratch_ref[slot]
+        copy.wait()
+        pltpu.semaphore_signal(capacity_sem, 1, device_id=src_id)
+      out_ref[...] += scratch_ref[(num_devices - 1) % 2]
+      pltpu.semaphore_wait(capacity_sem, 1)
+
+    kernel = pl.pallas_call(
+        kernel_body,
+        out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32),
+        scratch_shapes=[
+            pltpu.VMEM((2, 128, 128), jnp.float32),
+            pltpu.SemaphoreType.DMA,
+            pltpu.SemaphoreType.DMA,
+            pltpu.SemaphoreType.REGULAR,
+        ],
+    )
+    devices = mesh_utils.create_device_mesh((num_devices,))
+    mesh = jax.sharding.Mesh(devices, ['x'])
+    # This is just a smoke test to ensure that the verification does not crash.
+    with tempfile.TemporaryDirectory() as tmpdir:
+      previous_config = jax.config.read('jax_pallas_dump_promela_to')
+      jax.config.update('jax_pallas_dump_promela_to', tmpdir)
+      shard_map.shard_map(
+          kernel, mesh=mesh, in_specs=P('x'), out_specs=P(None), check_rep=False
+      )(jnp.ones((8, 128, 128), jnp.float32))
+      jax.config.update('jax_pallas_dump_promela_to', previous_config)
+      self.assertNotEmpty(os.listdir(tmpdir))
+
+
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
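The test drives verification purely through the jax_pallas_dump_promela_to config entry, pointing it at a temporary directory and restoring the previous value afterwards. The same pattern can be wrapped in a scoped helper outside of tests; a hedged sketch (the directory name is illustrative, and the names of the dumped model files are whatever the verifier chooses):

import contextlib
import os
import jax

@contextlib.contextmanager
def dump_promela_to(path):
  # Temporarily point the Pallas verifier's Promela dump at `path`.
  previous = jax.config.read('jax_pallas_dump_promela_to')
  jax.config.update('jax_pallas_dump_promela_to', path)
  try:
    yield
  finally:
    jax.config.update('jax_pallas_dump_promela_to', previous)

# with dump_promela_to('/tmp/pallas_promela'):
#   ...run the shard_mapped pallas_call from the test above...
# print(os.listdir('/tmp/pallas_promela'))  # the dumped model files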
