Skip to content

Commit 931566b

Browse files
ahmadsharif1 authored and facebook-github-bot committed
Add a __str__ function for Point (meta-pytorch#950)
Summary: Pull Request resolved: meta-pytorch#950. `str(current_rank())` in an actor doesn't print anything meaningful. I added a `__repr__` function to address that. It prints something like this: ``` rank=0/2 coords={hosts=0/1,gpus=0/2} ``` Reviewed By: dulinriley Differential Revision: D80649558 fbshipit-source-id: cd8d126b98d6ae0006872c882d022205410f998d
1 parent fe1dcbc commit 931566b

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

monarch_hyperactor/src/shape.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,36 @@ impl PyPoint {
210210
}
211211
}
212212

213+
fn __repr__(&self, py: Python) -> PyResult<String> {
214+
let shape = self.shape.bind(py).get();
215+
let inner_shape = &shape.inner;
216+
let slice = inner_shape.slice();
217+
218+
let total_size = slice.len();
219+
let current_rank = self.rank;
220+
221+
let coords = slice
222+
.coordinates(current_rank)
223+
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
224+
225+
let labels = inner_shape.labels();
226+
let sizes = slice.sizes();
227+
228+
let coords_parts: Vec<String> = labels
229+
.iter()
230+
.zip(coords.iter())
231+
.zip(sizes.iter())
232+
.map(|((label, &coord), &size)| format!("{}={}/{}", label, coord, size))
233+
.collect();
234+
235+
let coords_str = coords_parts.join(",");
236+
237+
Ok(format!(
238+
"rank={}/{} coords={{{}}}",
239+
current_rank, total_size, coords_str
240+
))
241+
}
242+
213243
fn __len__(&self, py: Python) -> usize {
214244
self.shape.bind(py).get().__len__()
215245
}

python/tests/test_python_actors.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@ class RunIt(Actor):
163163
async def run(self, fn):
164164
return fn()
165165

166+
@endpoint
167+
async def return_current_rank_str(self):
168+
return str(current_rank())
169+
166170

167171
@pytest.mark.timeout(60)
168172
async def test_rank_size():
@@ -175,6 +179,17 @@ async def test_rank_size():
175179
assert 4 == await acc.accumulate(lambda: current_size()["gpus"])
176180

177181

182+
@pytest.mark.timeout(60)
183+
async def test_rank_string():
184+
proc = await local_proc_mesh(gpus=2)
185+
r = await proc.spawn("runit", RunIt)
186+
vm = r.return_current_rank_str.call().get()
187+
r0 = vm.flatten("r").slice(r=0).item()
188+
r1 = vm.flatten("r").slice(r=1).item()
189+
assert r0 == "rank=0/2 coords={hosts=0/1,gpus=0/2}"
190+
assert r1 == "rank=1/2 coords={hosts=0/1,gpus=1/2}"
191+
192+
178193
class SyncActor(Actor):
179194
@endpoint
180195
def sync_endpoint(self, a_counter: Counter):

0 commit comments

Comments
 (0)