Skip to content

Commit 01a110c

Browse files
Better mosaic lowering for dynamic shapes, extend an interpreter into shape_poly dimexpr and lower them alongside the graph if we are in a dynamic export regime.
PiperOrigin-RevId: 738171437
1 parent 0fb5974 commit 01a110c

File tree

3 files changed

+129
-12
lines changed

3 files changed

+129
-12
lines changed

jax/_src/pallas/core.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from jax._src import state
3636
from jax._src import tree_util
3737
from jax._src import util
38+
from jax._src.export._export import export
3839
from jax._src.interpreters import mlir
3940
from jax._src.interpreters import partial_eval as pe
4041
from jax._src.state import discharge as state_discharge
@@ -1165,14 +1166,16 @@ def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh, **kwargs):
11651166

11661167

11671168
def lower_as_mlir(
1168-
f, *args, dynamic_shapes=False, device=None, **kwargs
1169+
f, *args, dynamic_shapes=False, device=None, static_argnames=(), **kwargs
11691170
) -> mlir.ir.Module:
11701171
with pallas_export_experimental(dynamic_shapes):
1171-
lowered = jax.jit(f, device=device).lower(*args, **kwargs)
1172-
stablehlo = lowered.compiler_ir(dialect="stablehlo")
1172+
f = jax.jit(f, device=device, static_argnames=static_argnames)
1173+
exported = export(f, platforms=["tpu"])(*args, **kwargs)
1174+
stablehlo = exported.mlir_module()
11731175

11741176
return stablehlo # type: ignore[return-value]
11751177

1178+
11761179
_out_shape_to_aval_mapping: dict[
11771180
type[Any], Callable[[Any], jax_core.AbstractValue]
11781181
] = {}

jax/_src/pallas/mosaic/lowering.py

Lines changed: 120 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from jax._src import state
4141
from jax._src import traceback_util
4242
from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
43+
from jax._src.export._export import export
4344
from jax._src.interpreters import mlir
4445
from jax._src.interpreters import partial_eval as pe
4546
from jax._src.lax import lax as lax_internal
@@ -89,6 +90,11 @@
8990
# The value interpreted as a dynamic dimension by MLIR.
9091
MLIR_DYNAMIC = -9223372036854775808
9192

93+
# TODO(mvoz): Find a way to make this a contract we can share with the
94+
# export specialization step in XLA export.
95+
DIM_UPPER_BOUND = np.iinfo(np.int32).max
96+
DIM_LOWER_BOUND = -128
97+
9298
partial = functools.partial
9399
map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin
94100
zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin
@@ -102,17 +108,49 @@ class MeshContext:
102108

103109
# Note - On Export Placeholders
104110
#
105-
# Mosaic uses vector IR, which does not have a concept of dynamic
106-
# dimensions. We need to come up with a way to represent dynamic dimensions in
107-
# vector IR, and so we use placeholders, which are later replaced during
108-
# specialization.
111+
# Since the vector dialect used by Mosaic does not support dynamic shapes,
112+
# we replace all top-level symbolic dimensions with placeholder
113+
# constants (between max(int32) - 128 and max(int32)) and we keep a
114+
# mapping from the placeholder constants to SHLO functions that encode
115+
# the symbolic dimension expression, as a function of the dimension
116+
# variables.
117+
#
118+
# The calling convention of the produced MLIR module is the same as
119+
# regular mosaic module, except we add on two new attributes to the custom call
120+
# *per* intermediary placeholder dimension.
121+
#
122+
# The attributes are:
123+
#
124+
# tpu.dynamic_dimension_mapping_arg_name_<placeholder>
125+
# tpu.dynamic_dimension_mapping_module_<placeholder>
126+
#
127+
# The first attribute is a comma-separated list of the dimension variables
128+
# that are used to compute the symbolic dimension expression for the
129+
# placeholder. The second attribute is the MLIR module that contains the
130+
# SHLO functions that compute the symbolic dimension expression for the
131+
# placeholder.
109132
class LoweringDynamicShapeEnv:
110-
dim_expr_to_placeholder: dict[Any, ir.Value] = {}
133+
dim_expr_to_placeholder: dict[shape_poly._DimExpr, int] = {}
134+
placeholder_to_dim_expr: dict[int, shape_poly._DimExpr] = {}
111135

112136
def to_placeholder(self, dim_expr: Any) -> ir.Value:
137+
if jax_core.is_constant_dim(dim_expr):
138+
# avoid ints, these are not dynamic
139+
return dim_expr
113140
if dim_expr not in self.dim_expr_to_placeholder:
114-
next_val = np.iinfo(np.int32).max - len(self.dim_expr_to_placeholder)
141+
next_val = DIM_UPPER_BOUND - len(self.dim_expr_to_placeholder)
142+
if next_val < DIM_LOWER_BOUND:
143+
# In practice, even with the largest of programs, we see rarely see
144+
# anything even close to this limit. It is arbitrary, and can be safely
145+
# increased if needed.
146+
raise ValueError(
147+
"Too many dynamic shapes in the input. Mosaic currently only"
148+
" supports up to 128 dynamic dimension values."
149+
)
115150
self.dim_expr_to_placeholder[dim_expr] = next_val
151+
# Reverse mapping - this is consumed to generate a table that is either
152+
# input<>placeholder or intermediary computation<>placeholder.
153+
self.placeholder_to_dim_expr[next_val] = dim_expr
116154
return self.dim_expr_to_placeholder[dim_expr]
117155

118156

@@ -622,6 +660,7 @@ def lower_jaxpr_to_module(
622660
"Pallas TPU requires a libTPU version that's at most a month old"
623661
)
624662
debug_info = jaxpr.debug_info
663+
_mosaic_lowering_dynamic_shape_env = None
625664
if dynamic_shape_replacement_enabled:
626665
_mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv()
627666

@@ -663,10 +702,12 @@ def dynamic_shape_replacement_fn(
663702
for_verification=for_verification,
664703
forward_compatible=lowering_context.is_forward_compat(),
665704
dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
705+
dynamic_shape_replacement_enabled=dynamic_shape_replacement_enabled,
666706
)
667707
m.body.append(func_op)
668708
sym_tab.insert(func_op)
669709
window_params = []
710+
static_grid = None
670711
grid = mosaic_grid_mapping.grid
671712
if grid:
672713
for i, bm in enumerate(grid_mapping.block_mappings):
@@ -738,14 +779,67 @@ def dynamic_shape_replacement_fn(
738779
]
739780
static_grid = dynamic_shape_replacement_fn(static_grid)
740781
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)
741-
742782
func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
743783
ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types))
744784
func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
745785
ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scratch_types))
746786
func_op.attributes["dimension_semantics"] = (
747787
mosaic_grid_mapping.get_dimension_semantics()
748788
)
789+
if dynamic_shape_replacement_enabled:
790+
if _mosaic_lowering_dynamic_shape_env is None:
791+
raise ValueError(
792+
"Dynamic shape env is None, invariant violated. Unreachable?"
793+
)
794+
795+
# Now we can use jax to compute the dynamic shape graph
796+
797+
if static_grid is not None:
798+
grid_vars = [
799+
_mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.get(g, g)
800+
for g in static_grid
801+
]
802+
else:
803+
grid_vars = []
804+
805+
invars = [invar.aval for invar in jaxpr.invars]
806+
# Faux shape for grid, just to get the avals
807+
invars.append(jax.ShapeDtypeStruct(grid_vars, jax.numpy.int32))
808+
args_dimvars = shape_poly.all_dim_vars(invars)
809+
810+
# This is dimexpr var -> placeholder value for when we jit the dim expr
811+
env: dict[str, int] = {}
812+
for aval in args_dimvars:
813+
env[aval] = _mosaic_lowering_dynamic_shape_env.to_placeholder(aval)
814+
815+
for (
816+
placeholder,
817+
dim_expr,
818+
) in _mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.items():
819+
top_level_names = list(env.keys())
820+
if dim_expr not in top_level_names:
821+
jitted_eval = jax.jit(
822+
jax_core.evaluate_shape,
823+
static_argnames=(
824+
"shape",
825+
"dim_vars",
826+
),
827+
keep_unused=True,
828+
)
829+
stablehlo = export(
830+
jitted_eval, platforms=[str(jax.devices()[0].platform)]
831+
)(
832+
(dim_expr,), tuple(args_dimvars), *(env[v] for v in args_dimvars)
833+
).mlir_module()
834+
arg_name = args_dimvars
835+
# See Note - On Export Placeholders for more details.
836+
m.operation.attributes[
837+
"tpu.dynamic_dimension_mapping_module_" + str(placeholder)
838+
] = ir.StringAttr.get(str(stablehlo))
839+
arg_name_str = ",".join(arg_name)
840+
m.operation.attributes[
841+
"tpu.dynamic_dimension_mapping_arg_name_" + str(placeholder)
842+
] = ir.StringAttr.get(arg_name_str)
749843
return m, mosaic_grid_mapping.get_extra_args()
750844

751845

@@ -828,6 +922,7 @@ def lower_jaxpr_to_func(
828922
dynamic_shape_replacement_fn: (
829923
Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None
830924
) = None,
925+
dynamic_shape_replacement_enabled: bool = False,
831926
) -> func.FuncOp:
832927
num_grid = len(mosaic_grid_mapping.grid_types)
833928
num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types)
@@ -874,6 +969,12 @@ def body_func(*args):
874969
)
875970
body_func.__name__ = name
876971
body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
972+
if dynamic_shape_replacement_enabled:
973+
# Skip verification for dynamic shape replacement - you can potentially
974+
# produce ir like ex: add(x[placeholder_0, placeholder_1], y[128, 128])
975+
# which is not valid, but we don't care since we'll run the verifier again
976+
# after the dynamic shape replacement pass.
977+
return body.func_op
877978
try:
878979
body.func_op.verify()
879980
except ir.MLIRError as e:
@@ -3851,3 +3952,15 @@ def _platform_index_lowering(
38513952

38523953

38533954
lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering
3955+
3956+
3957+
def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, dim):
3958+
placeholder = ctx.lowering_context.dynamic_shape_replacement_fn((dim,))[0]
3959+
return ir_constant(
3960+
placeholder, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))
3961+
)
3962+
3963+
3964+
import jax._src.export.shape_poly as shape_poly
3965+
3966+
lowering_rules[shape_poly.dim_as_value_p] = _dim_as_value_lowering

tests/pallas/pallas_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2501,7 +2501,8 @@ def sym_matmul_kernel(x_ref, y_ref, z_ref):
25012501
)
25022502
assert exported_module is not None
25032503
self.assertIn(
2504-
"tensor<?x?xf32>, %arg6: tensor<?x?xf32>, %arg7: tensor<?x?xf32>",
2504+
"%arg0: tensor<?x?xf32> loc(unknown), %arg1: tensor<?x?xf32>"
2505+
" loc(unknown), %arg2: tensor<?x?xf32>",
25052506
str(exported_module),
25062507
)
25072508
x = jax.ShapeDtypeStruct((128, 1024), jax.numpy.float32)
@@ -2512,7 +2513,7 @@ def sym_matmul_kernel(x_ref, y_ref, z_ref):
25122513
)
25132514
assert exported_module is not None
25142515
self.assertIn(
2515-
"@sym_matmul(%arg0: tensor<128x1024xf32>, %arg1: tensor<1024x512xf32>",
2516+
"call @sym_matmul(%arg0, %arg1)",
25162517
str(exported_module),
25172518
)
25182519

0 commit comments

Comments
 (0)