Skip to content

Commit 6213478

Browse files
suofacebook-github-bot
authored andcommitted
implement __len__ for MeshTrait (meta-pytorch#821)
Summary: Pull Request resolved: meta-pytorch#821 straightforward enough ghstack-source-id: 302243074 exported-using-ghexport Reviewed By: mariusae Differential Revision: D79483634 fbshipit-source-id: 785ab0cec0494bf5276179e87d27898d9caefba6
1 parent 83b35d8 commit 6213478

File tree

4 files changed

+47
-3
lines changed

4 files changed

+47
-3
lines changed

python/monarch/_src/actor/actor_mesh.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -624,9 +624,6 @@ def items(self) -> Iterable[Tuple[Point, R]]:
624624
def __iter__(self) -> Iterator[Tuple[Point, R]]:
625625
return iter(self.items())
626626

627-
def __len__(self) -> int:
628-
return len(self._shape)
629-
630627
def __repr__(self) -> str:
631628
return f"ValueMesh({self._shape})"
632629

python/monarch/_src/actor/shape.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,5 +224,8 @@ def size(self, dim: Union[None, str, Sequence[str]] = None) -> int:
224224
def sizes(self) -> dict[str, int]:
225225
return dict(zip(self._labels, self._ndslice.sizes))
226226

227+
def __len__(self) -> int:
228+
return len(self._ndslice)
229+
227230

228231
__all__ = ["NDSlice", "Shape", "MeshTrait"]

python/tests/test_mesh_trait.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Iterable
8+
9+
from monarch._src.actor.shape import MeshTrait, NDSlice, Shape, Slice
10+
11+
12+
class Mesh(MeshTrait):
13+
"""
14+
A simple implementor of MeshTrait.
15+
"""
16+
17+
def __init__(self, shape: Shape, values: list[int]) -> None:
18+
self._shape = shape
19+
self._values = values
20+
21+
def _new_with_shape(self, shape: Shape) -> "Mesh":
22+
return Mesh(shape, self._values)
23+
24+
@property
25+
def _ndslice(self) -> NDSlice:
26+
return self._shape.ndslice
27+
28+
@property
29+
def _labels(self) -> Iterable[str]:
30+
return self._shape.labels
31+
32+
33+
def test_len() -> None:
34+
s = Slice(offset=0, sizes=[2, 3], strides=[3, 1])
35+
shape = Shape(["label0", "label1"], s)
36+
37+
mesh = Mesh(shape, [1, 2, 3, 4, 5, 6])
38+
assert 6 == len(mesh)

python/tests/test_python_actors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,3 +1004,9 @@ def s(t):
10041004
b = PythonTask.spawn_blocking(lambda: s(0))
10051005
r = PythonTask.select_one([a.task(), b.task()]).block_on()
10061006
assert r == (0, 1)
1007+
1008+
1009+
def test_mesh_len():
1010+
proc_mesh = local_proc_mesh(gpus=12).get()
1011+
s = proc_mesh.spawn("sync_actor", SyncActor).get()
1012+
assert 12 == len(s)

0 commit comments

Comments
 (0)