Skip to content
This repository was archived by the owner on Oct 19, 2024. It is now read-only.

Commit f7f9338

Browse files
authored
Moved _compute_one_replica_ids from DistributedArray to PhysicalDeviceMesh (#914)
1 parent 4617a01 commit f7f9338

File tree

1 file changed

+28
-24
lines changed

1 file changed

+28
-24
lines changed

alpa/device_mesh.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,7 @@ class PhysicalDeviceMesh(ABC):
638638
num_devices_per_host: int
639639
mesh_id: int
640640
operation_executables: dict
641+
one_replica_ids: dict
641642

642643
def get_signature(self) -> str:
643644
"""Return a signature string that contains the mesh shape and GPU
@@ -648,6 +649,27 @@ def get_signature(self) -> str:
648649
ret = ret.replace(" ", "-")
649650
return ret
650651

652+
def _compute_one_replica_ids(self, indices, aval_shape, sharding_spec):
653+
# Tuple (aval_shape, sharding_spec) is 1-1 mapped to indices
654+
# used to compute one_replica_ids
655+
if (aval_shape, sharding_spec) in self.one_replica_ids:
656+
return self.one_replica_ids[(aval_shape, sharding_spec)]
657+
658+
one_replica_indices = []
659+
one_replica_host_local_ids = []
660+
seen_index_hashes = set()
661+
for i, index in enumerate(indices):
662+
hashed_index = _hashable_index(index)
663+
if hashed_index not in seen_index_hashes:
664+
one_replica_indices.append(i)
665+
one_replica_host_local_ids.append(
666+
divmod(i, self.num_devices_per_host))
667+
seen_index_hashes.add(hashed_index)
668+
self.one_replica_ids[(
669+
aval_shape,
670+
sharding_spec)] = one_replica_indices, one_replica_host_local_ids
671+
return one_replica_indices, one_replica_host_local_ids
672+
651673
@property
652674
def shape(self):
653675
return self.num_hosts, self.num_devices_per_host
@@ -845,6 +867,7 @@ def __init__(self, devices: Sequence["Device"] = None):
845867
self.mesh_id = -1
846868
self.device_strs = []
847869
self.operation_executables = {}
870+
self.one_replica_ids = {}
848871

849872
self.backend = xb.get_backend(global_config.backend)
850873

@@ -974,6 +997,7 @@ def __init__(self,
974997
self.workers = None
975998
self.service_server = None
976999
self.operation_executables = {}
1000+
self.one_replica_ids = {}
9771001
self.namespace = namespace
9781002

9791003
if devices is not None:
@@ -1508,8 +1532,6 @@ def __init__(self,
15081532
self.shape = self.aval.shape
15091533
self.dtype = self.aval.dtype
15101534
self._npy_value = None
1511-
self._one_replica_host_local_ids = None
1512-
self._one_replica_buffer_ids = None
15131535
self._fetched_np_buffers = None
15141536
self._fetched_np_buffers_ref = None
15151537
self.skip_shard_args_check = False
@@ -1616,34 +1638,16 @@ def load(cls, path: str, aval: ShapedArray, device_mesh: PhysicalDeviceMesh,
16161638
return DistributedArray(device_mesh, aval, sharding_spec, ary_ref,
16171639
indices)
16181640

1619-
def _compute_one_replica_ids(self):
1620-
one_replica_indices = []
1621-
one_replica_host_local_ids = []
1622-
seen_index_hashes = set()
1623-
for i, index in enumerate(self.indices):
1624-
hashed_index = _hashable_index(index)
1625-
if hashed_index not in seen_index_hashes:
1626-
one_replica_indices.append(i)
1627-
one_replica_host_local_ids.append(
1628-
divmod(i, self.device_mesh.num_devices_per_host))
1629-
seen_index_hashes.add(hashed_index)
1630-
self._one_replica_buffer_ids = one_replica_indices
1631-
self._one_replica_host_local_ids = one_replica_host_local_ids
1632-
1633-
# TODO(yonghao): to make ._value faster(in reorder buffer), cache different
1634-
# buffers with the same mesh shape and sharding spec.
16351641
@property
16361642
def one_replica_buffer_ids(self):
16371643
"""Indices of buffers containing one complete copy of the array data."""
1638-
if self._one_replica_buffer_ids is None:
1639-
self._compute_one_replica_ids()
1640-
return self._one_replica_buffer_ids
1644+
return self.device_mesh._compute_one_replica_ids(
1645+
self.indices, self.aval.shape, self.sharding_spec)[0]
16411646

16421647
@property
16431648
def one_replica_host_local_ids(self):
1644-
if self._one_replica_host_local_ids is None:
1645-
self._compute_one_replica_ids()
1646-
return self._one_replica_host_local_ids
1649+
return self.device_mesh._compute_one_replica_ids(
1650+
self.indices, self.aval.shape, self.sharding_spec)[1]
16471651

16481652
@property
16491653
def _value(self):

0 commit comments

Comments
 (0)