2121from typing import Any , Union
2222
2323from jax ._src import config
24- from jax ._src .util import use_cpp_class , cache , use_cpp_method , tuple_insert
24+ from jax ._src .util import use_cpp_class , cache , use_cpp_method
2525from jax ._src .lib import xla_client as xc
2626from jax ._src .lib .mlir .dialects import sdy
2727from jax ._src import mesh as mesh_lib
28- from jax ._src .partition_spec import PartitionSpec , UnconstrainedSingleton
28+ from jax ._src .partition_spec import PartitionSpec
2929from jax ._src import sharding as JSharding
3030from jax ._src import xla_bridge as xb
3131import numpy as np
@@ -198,7 +198,7 @@ def is_fully_addressable(self) -> bool:
198198 # Speed up `is_fully_addressable` since there is a high chance that the
199199 # mesh across multiple NamedSharding objects will be the same.
200200 if config .enable_empty_arrays .value :
201- client = self ._internal_device_list [0 ].client
201+ client = self ._internal_device_list [0 ].client # type: ignore
202202 return (len (self .mesh ._process_indices ) == 1 and
203203 next (iter (self .mesh ._process_indices )) ==
204204 xb .process_index (client ))
@@ -325,80 +325,6 @@ def __repr__(self):
325325 if self .replicated_axes else '' )
326326 return f"SdyArraySharding([{ dim_sharding_repr } ]{ device_id_repr } { rar } )"
327327
328- # TODO(yashkatariya): Remove this after jax 0.5.2 release
329- class ParsedPartitionSpec :
330- __slots__ = ('_user_spec' , 'partitions' )
331-
332- _user_spec : PartitionSpec | None
333- partitions : tuple [tuple [MeshAxisName , ...] | UnconstrainedSingleton , ...]
334-
335- def __init__ (self , user_spec , partitions ):
336- self ._user_spec = user_spec
337- assert None not in partitions , partitions
338- self .partitions = tuple (partitions )
339-
340- def get_partition_spec (self ) -> PartitionSpec :
341- if isinstance (self ._user_spec , PartitionSpec ):
342- return self ._user_spec
343- else :
344- return get_single_pspec (self )
345-
346- def insert_axis_partitions (self , dim , val ):
347- parts = self .partitions
348- too_short = dim - len (parts )
349- if too_short > 0 :
350- parts += ((),) * too_short
351- new_partitions = tuple_insert (parts , dim , val )
352- return ParsedPartitionSpec (None , new_partitions )
353-
354- @classmethod
355- def from_user_input (
356- cls ,
357- entry : PartitionSpec | None ,
358- arg_name : str ,
359- allow_unconstrained_dims : bool = False ,
360- ) -> ParsedPartitionSpec :
361- if entry is None :
362- return cls (entry , ())
363- if not isinstance (entry , PartitionSpec ):
364- raise TypeError (f"{ arg_name } are expected to be "
365- f"PartitionSpec instances or None, but got { entry } " )
366- axis_specs = []
367- for axis_spec in entry :
368- if axis_spec is None :
369- axis_spec = ()
370- elif isinstance (axis_spec , (list , tuple )):
371- axis_spec = tuple (axis_spec )
372- elif axis_spec is PartitionSpec .UNCONSTRAINED :
373- if not allow_unconstrained_dims :
374- raise ValueError (f"Unconstrained dims are not allowed: { entry } " )
375- axis_spec = PartitionSpec .UNCONSTRAINED
376- else :
377- axis_spec = (axis_spec ,)
378- axis_specs .append (axis_spec )
379- new_entry = PartitionSpec (
380- * [tuple (e ) if isinstance (e , (list , tuple )) else e for e in entry ])
381- return cls (new_entry , axis_specs )
382-
383- def __hash__ (self ):
384- return hash (self .partitions )
385-
386- def __eq__ (self , other ):
387- if not isinstance (other , ParsedPartitionSpec ):
388- return False
389- return self .partitions == other .partitions
390-
391- def __len__ (self ):
392- return len (self .partitions )
393-
394- def __getitem__ (self , i ):
395- return self .partitions [i ]
396-
397- def __iter__ (self ):
398- return iter (self .partitions )
399-
400- def __repr__ (self ):
401- return f"ParsedPartitionSpec(partitions={ self .partitions } )"
402328
403329@cache (max_size = 4096 , trace_context_in_key = False )
404330def named_sharding_to_xla_hlo_sharding (
@@ -491,18 +417,8 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
491417 partitions .append (None )
492418 return PartitionSpec (* partitions )
493419
494- get_single_pspec = lambda p : array_mapping_to_axis_resources (get_array_mapping (p )) # type: ignore
495-
496- # TODO(yashkatariya): Remove this after jax 0.5.2 release
497- def preprocess (mesh , spec , parsed_pspec , _manual_axes = frozenset ()):
498- if parsed_pspec is None :
499- spec = PartitionSpec () if spec is None else spec
500- parsed_pspec = ParsedPartitionSpec .from_user_input (
501- spec , "NamedSharding spec" , allow_unconstrained_dims = True )
502- _check_unique_resources (parsed_pspec , "NamedSharding spec" , mesh )
503- _check_mesh_resource_axis (mesh , parsed_pspec , _manual_axes )
504- return parsed_pspec
505420
421+ @cache (max_size = 128 , trace_context_in_key = False )
506422def check_pspec (mesh , spec , _manual_axes = frozenset ()):
507423 _check_unique_resources (spec , "NamedSharding spec" , mesh )
508424 _check_mesh_resource_axis (mesh , spec , _manual_axes )
@@ -517,13 +433,10 @@ def __init__(self, message, mesh, pspec):
517433 def __str__ (self ):
518434 return f"{ self .message } "
519435
520- def _check_unique_resources (
521- pspec : ParsedPartitionSpec | PartitionSpec , arg_name : str , mesh = None ,
522- ) -> None :
436+ def _check_unique_resources (pspec : PartitionSpec , arg_name : str , mesh = None
437+ ) -> None :
523438 resource_counts : dict [MeshAxisName , int ] = {}
524439 duplicate = False
525- pspec = (pspec .get_partition_spec () if isinstance (pspec , ParsedPartitionSpec )
526- else pspec )
527440 for d in pspec :
528441 if d is PartitionSpec .UNCONSTRAINED or d is None :
529442 continue
@@ -542,10 +455,8 @@ def _check_unique_resources(
542455 f' for { mesh_lib .show_axes (multiple_uses )} ' ),
543456 mesh = mesh , pspec = pspec )
544457
545- @ cache ( max_size = 128 , trace_context_in_key = False )
458+
546459def _check_mesh_resource_axis (mesh , pspec , _manual_axes ):
547- pspec = (pspec .get_partition_spec () if isinstance (pspec , ParsedPartitionSpec )
548- else pspec )
549460 for p in pspec :
550461 if p is PartitionSpec .UNCONSTRAINED or p is None :
551462 continue
0 commit comments