@@ -108,6 +108,7 @@ def pallas_call_tpu_lowering_rule(
108108 * in_nodes ,
109109 jaxpr : jax_core .Jaxpr ,
110110 grid_mapping : core .GridMapping ,
111+ mesh : pallas_core .Mesh | None ,
111112 input_output_aliases : tuple [tuple [int , int ], ...],
112113 debug : bool ,
113114 interpret : bool ,
@@ -116,7 +117,8 @@ def pallas_call_tpu_lowering_rule(
116117 out_avals : tuple [jax_core .AbstractValue , ...],
117118):
118119 """Lowers a pallas_call to a Mosaic TPU custom call."""
119- del interpret
120+ del mesh , interpret # Unused.
121+
120122 debug_info = jaxpr ._debug_info
121123 if debug :
122124 print (f"\n The kernel jaxpr for pallas_call { debug_info .func_src_info } :" )
@@ -126,11 +128,11 @@ def pallas_call_tpu_lowering_rule(
126128 else :
127129 mosaic_params = {}
128130
129- mesh = None
131+ jax_mesh = None
130132 axis_context = ctx .module_context .axis_context
131133 if axis_context is not None :
132134 if isinstance (axis_context , sharding_impls .SPMDAxisContext ):
133- mesh = axis_context .mesh
135+ jax_mesh = axis_context .mesh
134136 mlir_ctx = mlir .JaxIrContext ()
135137 mlir_ctx .append_dialect_registry (mlir .upstream_dialects )
136138 mlir_ctx .load_all_available_dialects ()
@@ -147,7 +149,7 @@ def lower_module(for_verification: bool):
147149 grid_mapping ,
148150 jaxpr ,
149151 dimension_semantics = dimension_semantics ,
150- mesh = mesh ,
152+ mesh = jax_mesh ,
151153 for_verification = for_verification ,
152154 dynamic_shape_replacement_enabled = pallas_core .dynamic_shapes_export_enabled (),
153155 )
@@ -164,11 +166,11 @@ def lower_module(for_verification: bool):
164166 )
165167
166168 if promela_dump_path := _DUMP_PROMELA_TO .value :
167- num_devices = 1 if mesh is None else mesh .devices .size
169+ num_devices = 1 if jax_mesh is None else jax_mesh .devices .size
168170 num_cores = (
169171 jax .devices ()[0 ].num_cores
170- if mesh is None
171- else mesh .devices [0 ].num_cores
172+ if jax_mesh is None
173+ else jax_mesh .devices [0 ].num_cores
172174 )
173175 verification_module , _ = lower_module (for_verification = True )
174176 model = verification .export_promela_model (
0 commit comments