Skip to content

Commit fb832af

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Respect the original memory kind on reshape, transpose and replicate methods of PositionalSharding. Fixes #25769
PiperOrigin-RevId: 713446871
1 parent cbcc883 commit fb832af

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

jax/_src/sharding_impls.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -772,15 +772,18 @@ def __repr__(self) -> str:
772772
return f'{cls_name}({body}{mem}, shape={self.shape})'
773773

774774
def reshape(self, *shape) -> PositionalSharding:
775-
return self._remake(self._devices, self._ids.reshape(*shape))
775+
return self._remake(self._devices, self._ids.reshape(*shape),
776+
memory_kind=self.memory_kind)
776777

777778
def transpose(self, *axes) -> PositionalSharding:
778-
return self._remake(self._devices, self._ids.transpose(*axes))
779+
return self._remake(self._devices, self._ids.transpose(*axes),
780+
memory_kind=self.memory_kind)
779781
T = property(transpose)
780782

781783
def replicate(self, axis=None, keepdims=True) -> PositionalSharding:
782784
new_ids = self._ids.sum(axis=axis, keepdims=keepdims) # union
783-
return self._remake(self._devices, new_ids)
785+
return self._remake(self._devices, new_ids,
786+
memory_kind=self.memory_kind)
784787

785788
def check_compatible_aval(self, aval_shape: Shape) -> None:
786789
if len(aval_shape) != len(self.shape) and not self.is_fully_replicated:

0 commit comments

Comments
 (0)