@@ -114,9 +114,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
   return sdy_sharding


-@util.cache(max_size=128, trace_context_in_key=False)
-def get_replicated_hlo_sharding():
-  return xc.HloSharding.replicate()
+replicated_hlo_sharding = xc.HloSharding.replicate()


 @use_cpp_class(xc.SingleDeviceSharding)
@@ -183,7 +181,7 @@ def _device_assignment(self) -> XLADeviceAssignment:
     return (self._device,)

   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
-    return get_replicated_hlo_sharding()
+    return replicated_hlo_sharding

   def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
     sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True)
@@ -401,7 +399,7 @@ def _op_sharding_to_pos_sharding(
 def _positional_sharding_to_xla_hlo_sharding(
     self, num_dimensions: int) -> xc.HloSharding:
   if self.shape == (1,) * self.ndim:
-    return get_replicated_hlo_sharding()
+    return replicated_hlo_sharding

   pbuf = xc.OpSharding()
   shape = self.shape[self.ndim - num_dimensions:]  # 'rank promotion' of val
@@ -603,7 +601,7 @@ def __reduce__(self):
   @functools.cached_property
   def _hlo_sharding_hash(self):
     if self.is_fully_replicated:
-      return hash(get_replicated_hlo_sharding())
+      return hash(replicated_hlo_sharding)
     return hash(self._hlo_sharding)

   def __eq__(self, other):
@@ -669,7 +667,7 @@ def is_fully_addressable(self) -> bool:

   @classmethod
   def get_replicated(cls, device_assignment, *, memory_kind: str | None = None):
-    return cls(tuple(device_assignment), get_replicated_hlo_sharding(),
+    return cls(tuple(device_assignment), replicated_hlo_sharding,
                memory_kind=memory_kind)

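Taken together, the hunks all follow from the first one: the cached helper get_replicated_hlo_sharding() is replaced by a module-level constant. xc.HloSharding.replicate() takes no arguments and always produces the same replicated sharding, so a @util.cache wrapper buys nothing over computing the value once at import time. A minimal sketch of the invariants this relies on, assuming jaxlib's xla_client is importable under the xc alias this file already uses:

# Sketch only: checks the properties that make a shared module-level
# constant a safe stand-in for the removed cached helper.
from jax._src.lib import xla_client as xc

replicated_hlo_sharding = xc.HloSharding.replicate()

# replicate() is deterministic, and HloSharding supports value equality
# and hashing, so one shared instance behaves like a freshly built one
# in the __eq__ / _hlo_sharding_hash paths touched above.
assert replicated_hlo_sharding == xc.HloSharding.replicate()
assert hash(replicated_hlo_sharding) == hash(xc.HloSharding.replicate())
assert replicated_hlo_sharding.is_replicated()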