4141from jax ._src import dtypes
4242from jax ._src import effects
4343from jax ._src import linear_util as lu
44- from jax ._src import mesh as mesh_lib
4544from jax ._src import op_shardings
4645from jax ._src import sharding_specs
4746from jax ._src import profiler
6564from jax ._src .lib .mlir .dialects import hlo
6665from jax ._src .partition_spec import PartitionSpec , UnconstrainedSingleton
6766from jax ._src .sharding import Sharding as JSharding
67+ from jax ._src .mesh import AbstractMesh , Mesh
6868from jax ._src .sharding_impls import (
6969 ArrayMapping , ArrayMappingOrAutoOrUnspecified , AUTO , UNSPECIFIED ,
7070 UnspecifiedValue , get_array_mapping as _get_array_mapping ,
@@ -98,7 +98,6 @@ class WeakRefList(list):
9898Replicated = sharding_specs .Replicated
9999
100100AvalDimSharding = Union [Unstacked , Chunked , NoSharding ]
101- Mesh = mesh_lib .Mesh
102101MeshAxisName = sharding_impls .MeshAxisName
103102MeshDimAssignment = Union [ShardedAxis , Replicated ]
104103ShardingSpec = sharding_specs .ShardingSpec
@@ -1723,20 +1722,19 @@ def _get_and_check_device_assignment(
17231722 devices : Sequence [xc .Device ] | None ,
17241723) -> tuple [xc .Client , tuple [xc .Device , ...]]:
17251724 first_sharding_info = None
1726- if devices is None :
1727- devices = ()
1728- else :
1729- devices = tuple (devices )
1725+ devices = () if devices is None else tuple (devices )
17301726
1731- for i , s_type , source_info in shardings :
1732- if isinstance (i , UnspecifiedValue ):
1727+ for sh , s_type , source_info in shardings :
1728+ if isinstance (sh , UnspecifiedValue ):
1729+ continue
1730+ if isinstance (sh , NamedSharding ) and isinstance (sh .mesh , AbstractMesh ):
17331731 continue
1734-
17351732 if first_sharding_info is None :
17361733 first_sharding_info = (
1737- (i .mesh ._flat_devices_tuple , s_type , source_info ) if isinstance (i , AUTO )
1738- else (i ._device_assignment , s_type , source_info ))
1739- arr_device_assignment = i .mesh ._flat_devices_tuple if isinstance (i , AUTO ) else i ._device_assignment
1734+ (sh .mesh ._flat_devices_tuple , s_type , source_info ) if isinstance (sh , AUTO )
1735+ else (sh ._device_assignment , s_type , source_info ))
1736+ arr_device_assignment = (sh .mesh ._flat_devices_tuple if isinstance (sh , AUTO )
1737+ else sh ._device_assignment )
17401738 if not devices :
17411739 if first_sharding_info [0 ] != arr_device_assignment :
17421740 raise DeviceAssignmentMismatchError ([
@@ -1837,7 +1835,8 @@ class SemanticallyEqualShardings:
18371835 def __init__ (self , shardings : tuple [GSPMDSharding | UnspecifiedValue , ...],
18381836 avals : tuple [core .AbstractValue ]):
18391837 gspmd_shardings = [
1840- s if isinstance (s , (UnspecifiedValue , AUTO ))
1838+ s if (isinstance (s , (UnspecifiedValue , AUTO )) or
1839+ (isinstance (s , NamedSharding ) and isinstance (s .mesh , AbstractMesh )))
18411840 else to_gspmd_sharding (s , a .ndim ) # pytype: disable=attribute-error
18421841 for s , a in zip (shardings , avals )]
18431842 self ._gspmd_shardings = gspmd_shardings
@@ -1895,7 +1894,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
18951894 propagated_out_mem_kinds : tuple [None | str , ...],
18961895 platforms : tuple [str , ...],
18971896 lowering_parameters : mlir .LoweringParameters ,
1898- abstract_mesh : mesh_lib . AbstractMesh | None ):
1897+ abstract_mesh : AbstractMesh | None ):
18991898 jaxpr = closed_jaxpr .jaxpr
19001899 in_shardings = semantic_in_shardings .shardings
19011900 out_shardings = semantic_out_shardings .shardings
@@ -2082,6 +2081,40 @@ def write(var, val):
20822081 return tuple (safe_map (read , jaxpr .outvars ))
20832082
20842083
2084+ def _get_num_devices (shardings , device_assignment , lowering_platforms ,
2085+ prim_requires_devices ) -> int :
2086+ ext_abstract_mesh , concrete_sharding = None , False
2087+ for s in shardings :
2088+ if isinstance (s , UnspecifiedValue ):
2089+ continue
2090+ elif isinstance (s , NamedSharding ) and isinstance (s .mesh , AbstractMesh ):
2091+ if ext_abstract_mesh is not None and ext_abstract_mesh != s .mesh :
2092+ raise ValueError ("AbstractMesh should be the same across all "
2093+ f"shardings. Got { ext_abstract_mesh } and { s .mesh } " )
2094+ ext_abstract_mesh = s .mesh
2095+ else :
2096+ concrete_sharding = True
2097+ if (concrete_sharding and ext_abstract_mesh is not None and
2098+ len (device_assignment ) != ext_abstract_mesh .size ):
2099+ raise ValueError (
2100+ f"AbstractMesh size: { ext_abstract_mesh .size } does not match the"
2101+ f" device assignment size: { len (device_assignment )} " )
2102+ if concrete_sharding :
2103+ return len (device_assignment )
2104+ if ext_abstract_mesh is None :
2105+ return len (device_assignment )
2106+ if lowering_platforms is None :
2107+ raise ValueError (
2108+ "Passing lowering_platforms via"
2109+ " jit(f).trace(*args).lower(lowering_platforms=...) is required when"
2110+ " only AbstractMesh exists in a jitted computation." )
2111+ if prim_requires_devices :
2112+ raise ValueError (
2113+ "AbstractMesh cannot be used when jaxpr contains primitives that"
2114+ " require devices to be present during lowering." )
2115+ return ext_abstract_mesh .size
2116+
2117+
20852118MaybeLayout = Sequence [Union [DeviceLocalLayout , AutoLayout , None ]]
20862119
20872120
@@ -2126,7 +2159,7 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):
21262159
21272160 @lru_cache (maxsize = 128 )
21282161 def _abstract_to_concrete_mesh (abstract_mesh ):
2129- return mesh_lib . Mesh (
2162+ return Mesh (
21302163 np_dev .reshape (abstract_mesh .axis_sizes ), abstract_mesh .axis_names ,
21312164 axis_types = abstract_mesh .axis_types )
21322165
@@ -2153,7 +2186,7 @@ def lower_sharding_computation(
21532186 donated_invars : Sequence [bool ],
21542187 * ,
21552188 keep_unused : bool ,
2156- context_mesh : mesh_lib . Mesh | None ,
2189+ context_mesh : Mesh | None ,
21572190 compiler_options_kvs : tuple [tuple [str , Any ], ...],
21582191 lowering_platforms : tuple [str , ...] | None ,
21592192 lowering_parameters : mlir .LoweringParameters ,
@@ -2211,6 +2244,7 @@ def lower_sharding_computation(
22112244 ((js , MismatchType .SHARDING_INSIDE_COMPUTATION , source_info )
22122245 for js , source_info in unique_intermediate_shardings )),
22132246 devices_from_context )
2247+ unique_intermediate_shardings = [js for js , _ in unique_intermediate_shardings ]
22142248
22152249 if config .sharding_in_types .value :
22162250 out_shardings = _concretize_abstract_shardings (
@@ -2221,21 +2255,31 @@ def lower_sharding_computation(
22212255 platforms = lowering_platforms or (
22222256 getattr (backend , "_raw_platform" , backend .platform ),)
22232257
2258+ prim_requires_devices = dispatch .jaxpr_has_prim_requiring_devices (jaxpr )
2259+
2260+ # TODO(yashkatariya): All device specific logic should go in compilation
2261+ # but this requires a big refactor. The current `_get_num_devices` logic
2262+ # is good enough to lower with AbstractMesh but cannot be compiled. Once
2263+ # I refactor, this will also work well with mesh being provided at
2264+ # compile time.
2265+ num_devices = _get_num_devices (
2266+ it .chain (unique_in_shardings , unique_out_shardings ,
2267+ unique_intermediate_shardings ),
2268+ device_assignment , lowering_platforms , prim_requires_devices )
2269+
22242270 committed = bool (
2225- devices_from_context or
2226- len (device_assignment ) > 1 or
2227- any (not isinstance (i , UnspecifiedValue ) for i in unique_in_shardings ) or
2228- any (not isinstance (js , UnspecifiedValue ) for js , _ in unique_intermediate_shardings ) or
2229- any (not isinstance (o , UnspecifiedValue ) for o in unique_out_shardings ))
2271+ devices_from_context
2272+ or num_devices > 1
2273+ or any (not isinstance (s , UnspecifiedValue ) for s in it .chain (
2274+ unique_in_shardings , unique_out_shardings , unique_intermediate_shardings )))
22302275
22312276 da_object = _create_da_object (tuple (device_assignment ))
22322277
22332278 transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds (jaxpr )
22342279 all_default_mem_kind = are_all_shardings_default_mem_kind (
22352280 da_object ,
22362281 it .chain (unique_in_shardings , unique_out_shardings ,
2237- [js for js , _ in unique_intermediate_shardings ],
2238- transfer_mem_kind_in_jaxpr )) # pytype: disable=wrong-arg-types
2282+ unique_intermediate_shardings , transfer_mem_kind_in_jaxpr )) # pytype: disable=wrong-arg-types
22392283
22402284 if all_default_mem_kind :
22412285 propagated_out_mem_kinds = (None ,) * len (global_out_avals )
@@ -2244,12 +2288,11 @@ def lower_sharding_computation(
22442288 closed_jaxpr , in_shardings )
22452289
22462290 # 2. Build up the HLO
2247- prim_requires_devices = dispatch .jaxpr_has_prim_requiring_devices (jaxpr )
22482291
22492292 abstract_mesh = None
22502293 if prim_requires_devices :
22512294 for sharding in it .chain (unique_in_shardings , unique_out_shardings ,
2252- [ js for js , _ in unique_intermediate_shardings ] ):
2295+ unique_intermediate_shardings ):
22532296 if isinstance (sharding , NamedSharding ):
22542297 if (abstract_mesh is not None and
22552298 abstract_mesh != sharding .mesh .abstract_mesh ):
@@ -2267,7 +2310,7 @@ def lower_sharding_computation(
22672310 (module , keepalive , host_callbacks , unordered_effects , ordered_effects ,
22682311 nreps , tuple_args , shape_poly_state ) = _cached_lowering_to_hlo (
22692312 closed_jaxpr , api_name , fun_name , backend , semantic_in_shardings ,
2270- semantic_out_shardings , in_layouts , out_layouts , len ( da_object ) ,
2313+ semantic_out_shardings , in_layouts , out_layouts , num_devices ,
22712314 tuple (da_object ) if prim_requires_devices else None , donated_invars ,
22722315 name_stack , all_default_mem_kind , inout_aliases ,
22732316 propagated_out_mem_kinds , platforms ,
@@ -2310,7 +2353,7 @@ def lower_sharding_computation(
23102353 all_default_mem_kind = all_default_mem_kind ,
23112354 all_args_info = all_args_info ,
23122355 pgle_profiler = pgle_profiler ,
2313- intermediate_shardings = [ s for s , _ in unique_intermediate_shardings ] ,
2356+ intermediate_shardings = unique_intermediate_shardings ,
23142357 context_mesh = context_mesh )
23152358
23162359
@@ -2480,7 +2523,7 @@ def _register_out_sharding_handler(
24802523
24812524def _gspmd_to_named_sharding (
24822525 out_s : GSPMDSharding , orig_in_s : NamedSharding ) -> NamedSharding :
2483- assert isinstance (orig_in_s .mesh , mesh_lib . Mesh )
2526+ assert isinstance (orig_in_s .mesh , Mesh )
24842527 return sharding_impls ._gspmd_to_named_sharding_via_mesh (out_s , orig_in_s .mesh )
24852528
24862529_register_out_sharding_handler (NamedSharding , _gspmd_to_named_sharding )
@@ -2532,7 +2575,7 @@ def _get_out_sharding_from_orig_sharding(
25322575
25332576def maybe_recover_user_shardings (
25342577 old_shardings , new_shardings , old_avals , new_avals ,
2535- intermediate_shardings = None , context_mesh : mesh_lib . Mesh | None = None ):
2578+ intermediate_shardings = None , context_mesh : Mesh | None = None ):
25362579 if all (not isinstance (o , sharding_impls .GSPMDSharding ) for o in new_shardings ):
25372580 return new_shardings
25382581
@@ -2832,8 +2875,14 @@ def from_hlo(name: str,
28322875 all_args_info : AllArgsInfo | None = None ,
28332876 pgle_profiler : profiler .PGLEProfiler | None = None ,
28342877 intermediate_shardings : Sequence [JSharding ] | None = None ,
2835- context_mesh : mesh_lib . Mesh | None = None
2878+ context_mesh : Mesh | None = None ,
28362879 ) -> MeshExecutable :
2880+ if any (isinstance (s , NamedSharding ) and isinstance (s .mesh , AbstractMesh )
2881+ for s in it .chain (in_shardings , out_shardings )):
2882+ raise RuntimeError (
2883+ "A jitted computation cannot contain AbstractMesh in in_shardings and"
2884+ " out_shardings during compilation. You can use `jax.export` to "
2885+ " lower with an AbstractMesh and later compile with concrete devices." )
28372886 if shape_poly_state is not None and shape_poly_state .uses_dim_vars :
28382887 hlo = mlir .refine_polymorphic_shapes (hlo )
28392888 if isinstance (device_assignment , xc .DeviceList ):
0 commit comments