File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed
Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -772,15 +772,18 @@ def __repr__(self) -> str:
772772 return f'{ cls_name } ({ body } { mem } , shape={ self .shape } )'
773773
774774 def reshape (self , * shape ) -> PositionalSharding :
775- return self ._remake (self ._devices , self ._ids .reshape (* shape ))
775+ return self ._remake (self ._devices , self ._ids .reshape (* shape ),
776+ memory_kind = self .memory_kind )
776777
777778 def transpose (self , * axes ) -> PositionalSharding :
778- return self ._remake (self ._devices , self ._ids .transpose (* axes ))
779+ return self ._remake (self ._devices , self ._ids .transpose (* axes ),
780+ memory_kind = self .memory_kind )
779781 T = property (transpose )
780782
781783 def replicate (self , axis = None , keepdims = True ) -> PositionalSharding :
782784 new_ids = self ._ids .sum (axis = axis , keepdims = keepdims ) # union
783- return self ._remake (self ._devices , new_ids )
785+ return self ._remake (self ._devices , new_ids ,
786+ memory_kind = self .memory_kind )
784787
785788 def check_compatible_aval (self , aval_shape : Shape ) -> None :
786789 if len (aval_shape ) != len (self .shape ) and not self .is_fully_replicated :
You can’t perform that action at this time.
0 commit comments