Skip to content

Commit 01738a3

Browse files
dzmitry-hubapytorchmergebot
authored andcommitted
Continue local tensor mode enablement for DTensor tests (pytorch#165451)
Pull Request resolved: pytorch#165451 Approved by: https://github.com/ezyang, https://github.com/albanD
1 parent a2f34bd commit 01738a3

File tree

3 files changed

+74
-10
lines changed

3 files changed

+74
-10
lines changed

test/distributed/tensor/test_dtensor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,19 @@ def test_metadata_consistency_check(self):
10201020
self.fail("Unexpected ValueError raised with run_check=False")
10211021

10221022

1023+
DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
1024+
DTensorMeshTest,
1025+
skipped_tests=[
1026+
# Submeshes are not supported by local tensor mode
1027+
"test_from_local_sub_mesh",
1028+
"test_default_value_sub_mesh",
1029+
"test_redistribute_sub_mesh",
1030+
# Local tensor mode doesn't support tensors of different types on different ranks
1031+
"test_metadata_consistency_check",
1032+
],
1033+
)
1034+
1035+
10231036
class TestDTensorPlacementTypes(DTensorTestBase):
10241037
@property
10251038
def world_size(self):
@@ -1086,6 +1099,11 @@ def test_split_tensor_1D(self) -> None:
10861099
assert_array_equal(expected_is_tensor_empty, is_tensor_empty)
10871100

10881101

1102+
TestDTensorPlacementTypesWithLocalTensor = create_local_tensor_test_class(
1103+
TestDTensorPlacementTypes,
1104+
)
1105+
1106+
10891107
class TestDTensorSpec(DTensorTestBase):
10901108
@property
10911109
def world_size(self):
@@ -1265,5 +1283,9 @@ def test_default_shard_order(self):
12651283
)
12661284

12671285

1286+
TestDTensorSpecWithLocalTensor = create_local_tensor_test_class(
1287+
TestDTensorSpec,
1288+
)
1289+
12681290
if __name__ == "__main__":
12691291
run_tests()

torch/distributed/_local_tensor/__init__.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
from types import TracebackType
5252
from typing import Any, Callable, Generator, Optional, Union
5353

54+
import numpy as np
55+
5456
import torch
5557
from torch import Size, SymBool, SymInt, Tensor
5658
from torch._C import DispatchKey, DispatchKeySet, ScriptObject
@@ -70,11 +72,13 @@
7072
from . import _c10d
7173

7274

73-
def _int_on_rank(i: "LocalIntNode | ConstantIntNode", r: int) -> int:
75+
def _int_on_rank(i: "int | LocalIntNode | ConstantIntNode", r: int) -> int:
7476
if isinstance(i, LocalIntNode):
7577
return i._local_ints[r]
7678
elif isinstance(i, ConstantIntNode):
7779
return i.val
80+
elif isinstance(i, int):
81+
return i
7882
else:
7983
raise AssertionError(type(i))
8084

@@ -216,7 +220,7 @@ def is_constant(self) -> bool:
216220
return False
217221

218222
def sym_max(
219-
self, other: "LocalIntNode | ConstantIntNode"
223+
self, other: "int | LocalIntNode | ConstantIntNode"
220224
) -> "LocalIntNode | ConstantIntNode":
221225
return LocalIntNode(
222226
{
@@ -226,36 +230,50 @@ def sym_max(
226230
)
227231

228232
def add(
229-
self, other: "LocalIntNode | ConstantIntNode"
233+
self, other: "int | LocalIntNode | ConstantIntNode"
230234
) -> "LocalIntNode | ConstantIntNode":
231235
return LocalIntNode(
232236
{r: self._local_ints[r] + _int_on_rank(other, r) for r in self._local_ints}
233237
)
234238

235239
def sub(
236-
self, other: "LocalIntNode | ConstantIntNode"
240+
self, other: "int | LocalIntNode | ConstantIntNode"
237241
) -> "LocalIntNode | ConstantIntNode":
238242
return LocalIntNode(
239243
{r: self._local_ints[r] - _int_on_rank(other, r) for r in self._local_ints}
240244
)
241245

242246
def mul(
243-
self, other: "LocalIntNode | ConstantIntNode"
247+
self, other: "int | LocalIntNode | ConstantIntNode"
244248
) -> "LocalIntNode | ConstantIntNode":
245249
return LocalIntNode(
246250
{r: self._local_ints[r] * _int_on_rank(other, r) for r in self._local_ints}
247251
)
248252

249-
def eq(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
253+
def mod(
254+
self, other: "int | LocalIntNode | ConstantIntNode"
255+
) -> "LocalIntNode | ConstantIntNode":
256+
return LocalIntNode(
257+
{r: self._local_ints[r] % _int_on_rank(other, r) for r in self._local_ints}
258+
)
259+
260+
def int_floordiv(
261+
self, other: "int | LocalIntNode | ConstantIntNode"
262+
) -> "LocalIntNode | ConstantIntNode":
263+
return LocalIntNode(
264+
{r: self._local_ints[r] // _int_on_rank(other, r) for r in self._local_ints}
265+
)
266+
267+
def eq(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool:
250268
r = {self._local_ints[r] == _int_on_rank(other, r) for r in self._local_ints}
251269
return torch._C._get_constant_bool_symnode(len(r) == 1 and next(iter(r)))
252270

253-
def gt(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
271+
def gt(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool:
254272
r = {self._local_ints[r] > _int_on_rank(other, r) for r in self._local_ints}
255273
assert len(r) == 1, (self, other)
256274
return torch._C._get_constant_bool_symnode(next(iter(r)))
257275

258-
def lt(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
276+
def lt(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool:
259277
r = {self._local_ints[r] < _int_on_rank(other, r) for r in self._local_ints}
260278
assert len(r) == 1, (self, other)
261279
return torch._C._get_constant_bool_symnode(next(iter(r)))
@@ -437,6 +455,27 @@ def __torch_dispatch__( # type: ignore[override]
437455
with LocalTensorMode(local_tensor._ranks):
438456
return func(*args, **kwargs)
439457

458+
def numpy(self, *, force: bool = False) -> np.ndarray:
459+
return self.reconcile().numpy(force=force)
460+
461+
def __lt__(
462+
self, other: torch.Tensor | bool | complex | float | int
463+
) -> torch.Tensor:
464+
self_rec = self.reconcile()
465+
other_rec = other
466+
if isinstance(other, LocalTensor):
467+
other_rec = other.reconcile()
468+
return self_rec < other_rec
469+
470+
def __gt__(
471+
self, other: torch.Tensor | bool | complex | float | int
472+
) -> torch.Tensor:
473+
self_rec = self.reconcile()
474+
other_rec = other
475+
if isinstance(other, LocalTensor):
476+
other_rec = other.reconcile()
477+
return self_rec > other_rec
478+
440479
def tolist(self) -> list[Any]:
441480
"""
442481
Reconcile and convert result to list.

torch/distributed/_local_tensor/_c10d.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,6 @@ def _local_all_gather_(
320320

321321
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
322322

323-
assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor"
324323
for i in range(len(output_tensors)):
325324
assert isinstance(output_tensors[i], LocalTensor), (
326325
"Output tensor must be a LocalTensor"
@@ -333,7 +332,11 @@ def _local_all_gather_(
333332

334333
# For each rank in the group, gather from their input tensor
335334
for i, rank_i in enumerate(group_ranks):
336-
output_tensors[i].copy_(input_tensor._local_tensors[rank_i])
335+
# allgather object happens to create pure tensor, so we special case it here
336+
source_tensor = input_tensor
337+
if isinstance(input_tensor, LocalTensor):
338+
source_tensor = input_tensor._local_tensors[rank_i]
339+
output_tensors[i].copy_(source_tensor)
337340

338341
work = FakeWork()
339342
work_so = Work.boxed(work)

0 commit comments

Comments
 (0)