4040from jax ._src import state
4141from jax ._src import traceback_util
4242from jax ._src .cloud_tpu_init import is_cloud_tpu_older_than
43+ from jax ._src .export ._export import export
4344from jax ._src .interpreters import mlir
4445from jax ._src .interpreters import partial_eval as pe
4546from jax ._src .lax import lax as lax_internal
8990# The value interpreted as a dynamic dimension by MLIR.
9091MLIR_DYNAMIC = - 9223372036854775808
9192
93+ # TODO(mvoz): Find a way to make this a contract we can share with the
94+ # export specialization step in XLA export.
95+ DIM_UPPER_BOUND = np .iinfo (np .int32 ).max
96+ DIM_LOWER_BOUND = - 128
97+
9298partial = functools .partial
9399map , unsafe_map = safe_map , map # pylint: disable=redefined-builtin
94100zip , unsafe_zip = safe_zip , zip # pylint: disable=redefined-builtin
@@ -102,17 +108,49 @@ class MeshContext:
102108
103109# Note - On Export Placeholders
104110#
105- # Mosaic uses vector IR, which does not have a concept of dynamic
106- # dimensions. We need to come up with a way to represent dynamic dimensions in
107- # vector IR, and so we use placeholders, which are later replaced during
108- # specialization.
111+ # Since the vector dialect used by Mosaic does not support dynamic shapes,
112+ # we replace all top-level symbolic dimensions with placeholder
113+ # constants (between max(int32) - 128 and max(int32)) and we keep a
114+ # mapping from the placeholder constants to SHLO functions that encode
115+ # the symbolic dimension expression, as a function of the dimension
116+ # variables.
117+ #
118+ # The calling convention of the produced MLIR module is the same as
119+ # regular mosaic module, except we add on two new attributes to the custom call
120+ # *per* intermediary placeholder dimension.
121+ #
122+ # The attributes are:
123+ #
124+ # tpu.dynamic_dimension_mapping_arg_name_<placeholder>
125+ # tpu.dynamic_dimension_mapping_module_<placeholder>
126+ #
127+ # The first attribute is a comma-separated list of the dimension variables
128+ # that are used to compute the symbolic dimension expression for the
129+ # placeholder. The second attribute is the MLIR module that contains the
130+ # SHLO functions that compute the symbolic dimension expression for the
131+ # placeholder.
109132class LoweringDynamicShapeEnv :
110- dim_expr_to_placeholder : dict [Any , ir .Value ] = {}
133+ dim_expr_to_placeholder : dict [shape_poly ._DimExpr , int ] = {}
134+ placeholder_to_dim_expr : dict [int , shape_poly ._DimExpr ] = {}
111135
112136 def to_placeholder (self , dim_expr : Any ) -> ir .Value :
137+ if jax_core .is_constant_dim (dim_expr ):
138+ # avoid ints, these are not dynamic
139+ return dim_expr
113140 if dim_expr not in self .dim_expr_to_placeholder :
114- next_val = np .iinfo (np .int32 ).max - len (self .dim_expr_to_placeholder )
141+ next_val = DIM_UPPER_BOUND - len (self .dim_expr_to_placeholder )
142+ if next_val < DIM_LOWER_BOUND :
143+ # In practice, even with the largest of programs, we see rarely see
144+ # anything even close to this limit. It is arbitrary, and can be safely
145+ # increased if needed.
146+ raise ValueError (
147+ "Too many dynamic shapes in the input. Mosaic currently only"
148+ " supports up to 128 dynamic dimension values."
149+ )
115150 self .dim_expr_to_placeholder [dim_expr ] = next_val
151+ # Reverse mapping - this is consumed to generate a table that is either
152+ # input<>placeholder or intermediary computation<>placeholder.
153+ self .placeholder_to_dim_expr [next_val ] = dim_expr
116154 return self .dim_expr_to_placeholder [dim_expr ]
117155
118156
@@ -622,6 +660,7 @@ def lower_jaxpr_to_module(
622660 "Pallas TPU requires a libTPU version that's at most a month old"
623661 )
624662 debug_info = jaxpr .debug_info
663+ _mosaic_lowering_dynamic_shape_env = None
625664 if dynamic_shape_replacement_enabled :
626665 _mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv ()
627666
@@ -663,10 +702,12 @@ def dynamic_shape_replacement_fn(
663702 for_verification = for_verification ,
664703 forward_compatible = lowering_context .is_forward_compat (),
665704 dynamic_shape_replacement_fn = dynamic_shape_replacement_fn ,
705+ dynamic_shape_replacement_enabled = dynamic_shape_replacement_enabled ,
666706 )
667707 m .body .append (func_op )
668708 sym_tab .insert (func_op )
669709 window_params = []
710+ static_grid = None
670711 grid = mosaic_grid_mapping .grid
671712 if grid :
672713 for i , bm in enumerate (grid_mapping .block_mappings ):
@@ -738,14 +779,67 @@ def dynamic_shape_replacement_fn(
738779 ]
739780 static_grid = dynamic_shape_replacement_fn (static_grid )
740781 func_op .attributes ["iteration_bounds" ] = ir .DenseI64ArrayAttr .get (static_grid )
741-
742782 func_op .attributes ["scalar_prefetch" ] = ir .IntegerAttr .get (
743783 ir .IntegerType .get_signless (64 ), len (mosaic_grid_mapping .scalar_prefetch_types ))
744784 func_op .attributes ["scratch_operands" ] = ir .IntegerAttr .get (
745785 ir .IntegerType .get_signless (64 ), len (mosaic_grid_mapping .scratch_types ))
746786 func_op .attributes ["dimension_semantics" ] = (
747787 mosaic_grid_mapping .get_dimension_semantics ()
748788 )
789+ if dynamic_shape_replacement_enabled :
790+ if _mosaic_lowering_dynamic_shape_env is None :
791+ raise ValueError (
792+ "Dynamic shape env is None, invariant violated. Unreachable?"
793+ )
794+
795+ # Now we can use jax to compute the dynamic shape graph
796+
797+ if static_grid is not None :
798+ grid_vars = [
799+ _mosaic_lowering_dynamic_shape_env .placeholder_to_dim_expr .get (g , g )
800+ for g in static_grid
801+ ]
802+ else :
803+ grid_vars = []
804+
805+ invars = [invar .aval for invar in jaxpr .invars ]
806+ # Faux shape for grid, just to get the avals
807+ invars .append (jax .ShapeDtypeStruct (grid_vars , jax .numpy .int32 ))
808+ args_dimvars = shape_poly .all_dim_vars (invars )
809+
810+ # This is dimexpr var -> placeholder value for when we jit the dim expr
811+ env : dict [str , int ] = {}
812+ for aval in args_dimvars :
813+ env [aval ] = _mosaic_lowering_dynamic_shape_env .to_placeholder (aval )
814+
815+ for (
816+ placeholder ,
817+ dim_expr ,
818+ ) in _mosaic_lowering_dynamic_shape_env .placeholder_to_dim_expr .items ():
819+ top_level_names = list (env .keys ())
820+ if dim_expr not in top_level_names :
821+ jitted_eval = jax .jit (
822+ jax_core .evaluate_shape ,
823+ static_argnames = (
824+ "shape" ,
825+ "dim_vars" ,
826+ ),
827+ keep_unused = True ,
828+ )
829+ stablehlo = export (
830+ jitted_eval , platforms = [str (jax .devices ()[0 ].platform )]
831+ )(
832+ (dim_expr ,), tuple (args_dimvars ), * (env [v ] for v in args_dimvars )
833+ ).mlir_module ()
834+ arg_name = args_dimvars
835+ # See Note - On Export Placeholders for more details.
836+ m .operation .attributes [
837+ "tpu.dynamic_dimension_mapping_module_" + str (placeholder )
838+ ] = ir .StringAttr .get (str (stablehlo ))
839+ arg_name_str = "," .join (arg_name )
840+ m .operation .attributes [
841+ "tpu.dynamic_dimension_mapping_arg_name_" + str (placeholder )
842+ ] = ir .StringAttr .get (arg_name_str )
749843 return m , mosaic_grid_mapping .get_extra_args ()
750844
751845
@@ -828,6 +922,7 @@ def lower_jaxpr_to_func(
828922 dynamic_shape_replacement_fn : (
829923 Callable [[tuple [jax .DimSize , ...]], tuple [int , ...]] | None
830924 ) = None ,
925+ dynamic_shape_replacement_enabled : bool = False ,
831926) -> func .FuncOp :
832927 num_grid = len (mosaic_grid_mapping .grid_types )
833928 num_scalar_prefetch = len (mosaic_grid_mapping .scalar_prefetch_types )
@@ -874,6 +969,12 @@ def body_func(*args):
874969 )
875970 body_func .__name__ = name
876971 body = func .FuncOp .from_py_func (* arg_types , name = name )(body_func )
972+ if dynamic_shape_replacement_enabled :
973+ # Skip verification for dynamic shape replacement - you can potentially
974+ # produce ir like ex: add(x[placeholder_0, placeholder_1], y[128, 128])
975+ # which is not valid, but we don't care since we'll run the verifier again
976+ # after the dynamic shape replacement pass.
977+ return body .func_op
877978 try :
878979 body .func_op .verify ()
879980 except ir .MLIRError as e :
@@ -3851,3 +3952,15 @@ def _platform_index_lowering(
38513952
38523953
38533954lowering_rules [jax ._src .lax .control_flow .platform_index_p ] = _platform_index_lowering
3955+
3956+
3957+ def _dim_as_value_lowering (ctx : mlir .LoweringRuleContext , * , dim ):
3958+ placeholder = ctx .lowering_context .dynamic_shape_replacement_fn ((dim ,))[0 ]
3959+ return ir_constant (
3960+ placeholder , mlir_type = _dtype_to_ir_type (jnp .dtype ("int32" ))
3961+ )
3962+
3963+
3964+ import jax ._src .export .shape_poly as shape_poly
3965+
3966+ lowering_rules [shape_poly .dim_as_value_p ] = _dim_as_value_lowering
0 commit comments