Commit 0995bc2

chr1sj0nes authored and Google-ML-Automation committed
[pallas:triton] Simplify lowering code. BlockInfo is now always present for memory refs.
PiperOrigin-RevId: 695414469
1 parent 242ac2b commit 0995bc2

File tree

1 file changed: +24 -38 lines changed


jax/_src/pallas/triton/lowering.py

Lines changed: 24 additions & 38 deletions
@@ -87,15 +87,15 @@ class ModuleContext:
 class BlockInfo:
   full_shape_dtype: jax.ShapeDtypeStruct
   start_indices: Sequence[Any]
-  block_shape: tuple[int, ...]  # TODO(necula): can this contain "mapped"?
+  block_shape: tuple[int | pallas_core.Mapped, ...]


 @dataclasses.dataclass
 class LoweringRuleContext:
   context: ModuleContext
   avals_in: Sequence[jax_core.ShapedArray]
   avals_out: Sequence[jax_core.ShapedArray]
-  block_infos: Sequence[BlockInfo | None]  # TODO(necula): can this be None?
+  block_infos: Sequence[BlockInfo | None]

   replace = dataclasses.replace

@@ -362,14 +362,15 @@ def read_env(atom: jax_core.Atom):
   def read_block_info_env(atom: jax_core.Atom):
     if isinstance(atom, jax_core.Literal):
       return None
-    return block_info_env.get(atom, None)
+    return block_info_env.get(atom)

   def write_env(var: jax_core.Var, val):
     env[var] = val

   if block_infos is not None:
     for invar, block_info in zip(jaxpr.invars, block_infos):
-      block_info_env[invar] = block_info
+      if block_info is not None:
+        block_info_env[invar] = block_info

   map(write_env, jaxpr.invars, args)

@@ -393,7 +394,7 @@ def write_env(var: jax_core.Var, val):
       raise  # We only add the extra info to the innermost exception.
     except Exception as e:
       if not pallas_call._verbose_errors_enabled():
-          raise
+        raise
       inval_types = map(lambda t: getattr(t, "type", None), invals)
       raise LoweringError(
           f"Exception while lowering eqn:\n  {eqn}\nWith context:\n  "
@@ -474,14 +475,14 @@ def _atomic_lowering_rule(
     args_tree,
     atomic_type: primitives.AtomicOpType,
 ):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   ptr, indexers, val, mask = args_tree.unflatten(args_flat)
   *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in)
   if len(indexers) != 1:
     raise NotImplementedError("Only single indexer is supported.")
   idx = indexers[0]
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, idx)
   val = _ensure_ir_value(val, value_aval)
   if mask is not None:
     mask = _ensure_ir_value(mask, mask_aval)
@@ -1631,36 +1632,21 @@ def _reshape_lowering_rule(


 def _compute_pointers_from_indices(
-    root_ptr: ir.Value,
-    block_info: BlockInfo | None,
-    nd_indexer: NDIndexer,
-    array_shape_dtype: Any,
+    root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer
 ) -> ir.Value:
-  if block_info is None:  # TODO(necula): is this branch dead?
-    full_shape = array_shape_dtype.shape
-    num_mapped_dims = 0
-    block_shape = array_shape_dtype.shape
-  else:
-    full_shape = block_info.full_shape_dtype.shape
-    num_mapped_dims = sum(
-        b is pallas_core.mapped for b in block_info.block_shape
-    )
-    block_shape = block_info.block_shape
+  full_shape = block_info.full_shape_dtype.shape
+  num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape)
   strides = pallas_utils.strides_from_shape(full_shape)
   indexer_shape = nd_indexer.get_indexer_shape()
   int_indexer_shape = nd_indexer.int_indexer_shape
   _check_tensor_size(indexer_shape)
   indices = nd_indexer.indices
   other_shape = indexer_shape[len(int_indexer_shape) :]
   other_shape_idx = 0
-  if block_info is None:
-    start_index_offsets = [None] * len(indices)
-  else:
-    start_index_offsets = block_info.start_indices
   assert len(indices) + num_mapped_dims == len(full_shape)
-  assert len(start_index_offsets) == len(full_shape)
+  assert len(block_info.start_indices) == len(full_shape)

-  array_dtype = jnp.dtype(array_shape_dtype.dtype)
+  array_dtype = jnp.dtype(block_info.full_shape_dtype.dtype)
   full_size = math.prod(full_shape) * array_dtype.itemsize
   # Use 64-bit indexing when offset might be >= 2**32 bytes.
   offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32)
@@ -1671,7 +1657,7 @@ def _compute_pointers_from_indices(

   indexer_iter = iter(indices)
   for dim_stride, dim_block_size, start_offset in zip(
-      strides, block_shape, start_index_offsets
+      strides, block_info.block_shape, block_info.start_indices
   ):
     if dim_block_size is pallas_core.mapped:
       index = _ir_constant(0, offset_eltype)
@@ -1831,6 +1817,8 @@ def _masked_load_lowering_rule(
     cache_modifier,
     is_volatile,
 ):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   ptr, indexers, mask, other = args_tree.unflatten(args_flat)
   *_, mask_aval, other_aval = args_tree.unflatten(ctx.avals_in)
   if len(indexers) > 1:
@@ -1839,9 +1827,7 @@ def _masked_load_lowering_rule(
   if not tt_dialect.PointerType.isinstance(ptr.type):
     assert len(ctx.avals_in) == 1
     return ptr
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, idx)
   if mask is not None:
     mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape())
   if other is not None:
@@ -1931,14 +1917,14 @@ def _store(
 def _masked_swap_lowering_rule(
     ctx: LoweringRuleContext, *args_flat, args_tree, eviction_policy
 ):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   ptr, indexers, value, mask = args_tree.unflatten(args_flat)
   *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in)
   if len(indexers) > 1:
     raise NotImplementedError("No support for multiple indexers yet.")
   idx = indexers[0]
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, idx)
   other = None
   if value is not None:
     value = _ensure_ir_value(value, value_aval)
@@ -1954,16 +1940,16 @@ def _masked_swap_lowering_rule(

 @register_lowering(sp.addupdate_p)
 def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   indexers = tree_util.tree_unflatten(tree, idx)
   if not tt_dialect.PointerType.isinstance(ptr.type):
     assert len(indexers) == 0
     return ptr
   if len(indexers) > 1:
     raise NotImplementedError("No support for multiple indexers yet.")
   indexer = indexers[0]
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], indexer, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, indexer)
   op = tt_dialect.RMWOp.FADD
   if isinstance(_element_type(value.type), ir.IntegerType):
     op = tt_dialect.RMWOp.ADD
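
For reference, here is a minimal, self-contained sketch of the invariant this change leans on: every memory-ref operand of a lowering rule now carries a BlockInfo, so rules unpack the first entry and assert it is present instead of branching on None. The BlockInfo, LoweringRuleContext, and example_load_rule below are simplified stand-ins for illustration only, not the actual definitions in jax/_src/pallas/triton/lowering.py.

import dataclasses
from typing import Any, Sequence

# Simplified stand-ins; the real types live in jax/_src/pallas/triton/lowering.py.
MAPPED = object()  # plays the role of the pallas_core.mapped sentinel


@dataclasses.dataclass
class BlockInfo:
  full_shape_dtype: Any                    # jax.ShapeDtypeStruct in the real code
  start_indices: Sequence[Any]
  block_shape: tuple[Any, ...]             # ints or MAPPED sentinels


@dataclasses.dataclass
class LoweringRuleContext:
  block_infos: Sequence[BlockInfo | None]  # still Optional: non-ref args carry no BlockInfo


def example_load_rule(ctx: LoweringRuleContext, ptr: Any, idx: Any) -> Any:
  # Post-change pattern: the ref being indexed always has a BlockInfo, so unpack
  # it up front and assert, instead of threading a None fallback through the
  # pointer computation (the old "block_info is None" branch is gone).
  block_info, *_ = ctx.block_infos
  assert block_info is not None
  num_mapped_dims = sum(b is MAPPED for b in block_info.block_shape)
  # A real rule would now call:
  #   _compute_pointers_from_indices(ptr, block_info, idx)
  return ptr, num_mapped_dims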
