@@ -638,6 +638,7 @@ class PhysicalDeviceMesh(ABC):
638
638
num_devices_per_host : int
639
639
mesh_id : int
640
640
operation_executables : dict
641
+ one_replica_ids : dict
641
642
642
643
def get_signature (self ) -> str :
643
644
"""Return a signature string that contains the mesh shape and GPU
@@ -648,6 +649,27 @@ def get_signature(self) -> str:
648
649
ret = ret .replace (" " , "-" )
649
650
return ret
650
651
652
+ def _compute_one_replica_ids (self , indices , aval_shape , sharding_spec ):
653
+ # Tuple (aval_shape, sharding_spec) is 1-1 mapped to indices
654
+ # used to compute one_replica_ids
655
+ if (aval_shape , sharding_spec ) in self .one_replica_ids :
656
+ return self .one_replica_ids [(aval_shape , sharding_spec )]
657
+
658
+ one_replica_indices = []
659
+ one_replica_host_local_ids = []
660
+ seen_index_hashes = set ()
661
+ for i , index in enumerate (indices ):
662
+ hashed_index = _hashable_index (index )
663
+ if hashed_index not in seen_index_hashes :
664
+ one_replica_indices .append (i )
665
+ one_replica_host_local_ids .append (
666
+ divmod (i , self .num_devices_per_host ))
667
+ seen_index_hashes .add (hashed_index )
668
+ self .one_replica_ids [(
669
+ aval_shape ,
670
+ sharding_spec )] = one_replica_indices , one_replica_host_local_ids
671
+ return one_replica_indices , one_replica_host_local_ids
672
+
651
673
@property
652
674
def shape (self ):
653
675
return self .num_hosts , self .num_devices_per_host
@@ -845,6 +867,7 @@ def __init__(self, devices: Sequence["Device"] = None):
845
867
self .mesh_id = - 1
846
868
self .device_strs = []
847
869
self .operation_executables = {}
870
+ self .one_replica_ids = {}
848
871
849
872
self .backend = xb .get_backend (global_config .backend )
850
873
@@ -974,6 +997,7 @@ def __init__(self,
974
997
self .workers = None
975
998
self .service_server = None
976
999
self .operation_executables = {}
1000
+ self .one_replica_ids = {}
977
1001
self .namespace = namespace
978
1002
979
1003
if devices is not None :
@@ -1508,8 +1532,6 @@ def __init__(self,
1508
1532
self .shape = self .aval .shape
1509
1533
self .dtype = self .aval .dtype
1510
1534
self ._npy_value = None
1511
- self ._one_replica_host_local_ids = None
1512
- self ._one_replica_buffer_ids = None
1513
1535
self ._fetched_np_buffers = None
1514
1536
self ._fetched_np_buffers_ref = None
1515
1537
self .skip_shard_args_check = False
@@ -1616,34 +1638,16 @@ def load(cls, path: str, aval: ShapedArray, device_mesh: PhysicalDeviceMesh,
1616
1638
return DistributedArray (device_mesh , aval , sharding_spec , ary_ref ,
1617
1639
indices )
1618
1640
1619
- def _compute_one_replica_ids (self ):
1620
- one_replica_indices = []
1621
- one_replica_host_local_ids = []
1622
- seen_index_hashes = set ()
1623
- for i , index in enumerate (self .indices ):
1624
- hashed_index = _hashable_index (index )
1625
- if hashed_index not in seen_index_hashes :
1626
- one_replica_indices .append (i )
1627
- one_replica_host_local_ids .append (
1628
- divmod (i , self .device_mesh .num_devices_per_host ))
1629
- seen_index_hashes .add (hashed_index )
1630
- self ._one_replica_buffer_ids = one_replica_indices
1631
- self ._one_replica_host_local_ids = one_replica_host_local_ids
1632
-
1633
- # TODO(yonghao): to make ._value faster(in reorder buffer), cache different
1634
- # buffers with the same mesh shape and sharding spec.
1635
1641
@property
1636
1642
def one_replica_buffer_ids (self ):
1637
1643
"""Indices of buffers containing one complete copy of the array data."""
1638
- if self ._one_replica_buffer_ids is None :
1639
- self ._compute_one_replica_ids ()
1640
- return self ._one_replica_buffer_ids
1644
+ return self .device_mesh ._compute_one_replica_ids (
1645
+ self .indices , self .aval .shape , self .sharding_spec )[0 ]
1641
1646
1642
1647
@property
1643
1648
def one_replica_host_local_ids (self ):
1644
- if self ._one_replica_host_local_ids is None :
1645
- self ._compute_one_replica_ids ()
1646
- return self ._one_replica_host_local_ids
1649
+ return self .device_mesh ._compute_one_replica_ids (
1650
+ self .indices , self .aval .shape , self .sharding_spec )[1 ]
1647
1651
1648
1652
@property
1649
1653
def _value (self ):
0 commit comments