
Commit 20658fa

hawkinsp authored and Google-ML-Automation committed
Replace cached function get_replicated_hlo_sharding() with a constant.
Small cleanup, no functional changes intended.

PiperOrigin-RevId: 737727727
1 parent ebcae0d commit 20658fa
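
For context on the cleanup: a cached function that takes no arguments can only ever hold one cache entry, so it is equivalent to computing the value once at import time. Below is a minimal sketch of the before/after pattern, using functools.cache and hypothetical names (make_default, DEFAULT) rather than anything from this commit; the commit itself uses JAX's internal util.cache wrapper.

import functools

# Before: the cache's only possible key is the empty argument tuple,
# so it stores at most one entry.
@functools.cache
def make_default():
  return frozenset({1, 2, 3})

# After: build the value once at import time and share that instance.
DEFAULT = frozenset({1, 2, 3})

# The cached function already returned a single shared instance:
assert make_default() is make_default()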

File tree

2 files changed: +6 −8 lines

jax/_src/debugging.py

Lines changed: 1 addition & 1 deletion

@@ -446,7 +446,7 @@ def _hlo_sharding_callback(hlo_sharding: xc.HloSharding):
   if len(devices) == 1:
     # If we only have one device in our computation, we can construct a
     # replicated HloSharding and call it right now.
-    _hlo_sharding_callback(sharding_impls.get_replicated_hlo_sharding())
+    _hlo_sharding_callback(sharding_impls.replicated_hlo_sharding)
     return []
 
   key = xc.encode_inspect_sharding_callback(_hlo_sharding_callback)

jax/_src/sharding_impls.py

Lines changed: 5 additions & 7 deletions

@@ -114,9 +114,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
   return sdy_sharding
 
 
-@util.cache(max_size=128, trace_context_in_key=False)
-def get_replicated_hlo_sharding():
-  return xc.HloSharding.replicate()
+replicated_hlo_sharding = xc.HloSharding.replicate()
 
 
 @use_cpp_class(xc.SingleDeviceSharding)

@@ -183,7 +181,7 @@ def _device_assignment(self) -> XLADeviceAssignment:
     return (self._device,)
 
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
-    return get_replicated_hlo_sharding()
+    return replicated_hlo_sharding
 
   def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
     sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True)

@@ -401,7 +399,7 @@ def _op_sharding_to_pos_sharding(
   def _positional_sharding_to_xla_hlo_sharding(
       self, num_dimensions: int) -> xc.HloSharding:
     if self.shape == (1,) * self.ndim:
-      return get_replicated_hlo_sharding()
+      return replicated_hlo_sharding
 
     pbuf = xc.OpSharding()
     shape = self.shape[self.ndim - num_dimensions:]  # 'rank promotion' of val

@@ -603,7 +601,7 @@ def __reduce__(self):
   @functools.cached_property
   def _hlo_sharding_hash(self):
     if self.is_fully_replicated:
-      return hash(get_replicated_hlo_sharding())
+      return hash(replicated_hlo_sharding)
     return hash(self._hlo_sharding)
 
   def __eq__(self, other):

@@ -669,7 +667,7 @@ def is_fully_addressable(self) -> bool:
 
   @classmethod
   def get_replicated(cls, device_assignment, *, memory_kind: str | None = None):
-    return cls(tuple(device_assignment), get_replicated_hlo_sharding(),
+    return cls(tuple(device_assignment), replicated_hlo_sharding,
                memory_kind=memory_kind)
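
A quick sanity check of the assumptions behind this cleanup (a sketch, not part of the commit; it imports jaxlib's xla_client the same way jax/_src/sharding_impls.py does): xc.HloSharding.replicate() is deterministic, and the resulting object supports equality and hashing, which is what _hlo_sharding_hash above relies on. A single module-level instance therefore provides everything the cache did.

from jax._src.lib import xla_client as xc

a = xc.HloSharding.replicate()
b = xc.HloSharding.replicate()
assert a == b              # construction is deterministic
assert hash(a) == hash(b)  # hashable, so safe to share one instance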
