
Commit 4c2c401

SherlockNoMad authored and pytorchmergebot committed
Record redistribute_local_tensor in DebugMode (pytorch#163704)
An explicit redistribute_local_tensor API call can also result in communication, so record it!

Pull Request resolved: pytorch#163704
Approved by: https://github.com/ezyang
1 parent 5d0f639 commit 4c2c401
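
For context, here is a minimal usage sketch (not part of this commit) of where the newly recorded entries surface. It assumes a small multi-rank job launched with torchrun and a gloo process group; the mesh and tensor sizes are illustrative only.

# Hedged sketch: run under e.g. `torchrun --nproc_per_node=2 demo.py`.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
from torch.utils._debug_mode import DebugMode

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))
x = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

with DebugMode() as debug_mode:
    # redistribute() funnels through redistribute_local_tensor(), which this
    # commit wraps in record_redistribute_calls(); the collectives it issues
    # therefore show up nested under a redistribute_input(...) entry.
    x.redistribute(mesh, [Replicate()])

if dist.get_rank() == 0:
    print(debug_mode.debug_string())
dist.destroy_process_group()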

File tree

5 files changed: +145 -113 lines changed

test/distributed/tensor/debug/test_debug_mode.py

Lines changed: 26 additions & 20 deletions
@@ -50,8 +50,9 @@ def test_debug_mode_mm(self):
  torch.mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
  aten::mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
  redistribute_input(1, [S(0)] -> [R])
- _c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
- _c10d_functional::wait_tensor(t: f32[8, 32])
+ redistribute_input(t: f32[1, 32], [S(0)] -> [R])
+ _c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
+ _c10d_functional::wait_tensor(t: f32[8, 32])
  aten::mm(t: f32[1, 8], t: f32[8, 32])
  <method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 32][S(0)])
  aten::sum(dt: f32[8, 32][S(0)])
@@ -90,7 +91,8 @@ def test_debug_mode_backward(self):
  <method 'add' of 'torch._C.TensorBase' objects>(dt: f32[8, 8][S(0)], dt: f32[8, 8][S(1)])
  aten::add.Tensor(dt: f32[8, 8][S(0)], dt: f32[8, 8][S(1)])
  redistribute_input(1, [S(1)] -> [S(0)])
- _dtensor::shard_dim_alltoall(t: f32[8, 1], 1, 0, 0)
+ redistribute_input(t: f32[8, 1], [S(1)] -> [S(0)])
+ _dtensor::shard_dim_alltoall(t: f32[8, 1], 1, 0, 0)
  aten::add.Tensor(t: f32[1, 8], t: f32[1, 8])
  <method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 8][S(0)])
  aten::sum(dt: f32[8, 8][S(0)])
@@ -100,12 +102,14 @@ def test_debug_mode_backward(self):
  aten::ones_like(t: f32[], pin_memory=False, memory_format=torch.preserve_format)
  aten::expand(dt: f32[][R], [8, 8])
  aten::expand(t: f32[], [8, 8])
- aten::split.Tensor(t: f32[8, 8], 1, 1)
- aten::clone(t: f32[8, 1])
+ redistribute_input(t: f32[8, 8], [R] -> [S(1)])
+ aten::split.Tensor(t: f32[8, 8], 1, 1)
+ aten::clone(t: f32[8, 1])
  aten::_to_copy(t: f32[8, 1], dtype=torch.float32, layout=torch.strided, device=cpu)
- aten::detach(t: f32[8, 1])
- aten::split.Tensor(t: f32[8, 8], 1)
- aten::clone(t: f32[1, 8])
+ redistribute_input(t: f32[8, 8], [R] -> [S(0)])
+ aten::detach(t: f32[8, 1])
+ aten::split.Tensor(t: f32[8, 8], 1)
+ aten::clone(t: f32[1, 8])
  aten::_to_copy(t: f32[1, 8], dtype=torch.float32, layout=torch.strided, device=cpu)
  aten::detach(t: f32[1, 8])""",
  )
@@ -150,19 +154,21 @@ def test_debug_mode_einsum(self):
  aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16])
  aten::bmm(dt: f32[1, 96, 8][P, R], dt: f32[1, 8, 16][R, P])
  redistribute_input(0, [P, R] -> [S(2), S(2)])
- aten::chunk(t: f32[1, 96, 8], 4, 2)
- aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
- _c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1)
- _c10d_functional::wait_tensor(t: f32[1, 96, 2])
- aten::chunk(t: f32[1, 96, 2], 2, 2)
- aten::clone(t: f32[1, 96, 1])
+ redistribute_input(t: f32[1, 96, 8], [P, R] -> [S(2), S(2)])
+ aten::chunk(t: f32[1, 96, 8], 4, 2)
+ aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
+ _c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1)
+ _c10d_functional::wait_tensor(t: f32[1, 96, 2])
+ aten::chunk(t: f32[1, 96, 2], 2, 2)
+ aten::clone(t: f32[1, 96, 1])
  redistribute_input(1, [R, P] -> [S(1), S(1)])
- aten::chunk(t: f32[1, 8, 16], 4, 1)
- aten::clone(t: f32[1, 2, 16])
- aten::chunk(t: f32[1, 2, 16], 2, 1)
- aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
- _c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
- _c10d_functional::wait_tensor(t: f32[1, 1, 16])
+ redistribute_input(t: f32[1, 8, 16], [R, P] -> [S(1), S(1)])
+ aten::chunk(t: f32[1, 8, 16], 4, 1)
+ aten::clone(t: f32[1, 2, 16])
+ aten::chunk(t: f32[1, 2, 16], 2, 1)
+ aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
+ _c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
+ _c10d_functional::wait_tensor(t: f32[1, 1, 16])
  aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16])
  aten::view(dt: f32[1, 96, 16][P, P], [16, 6, 1, 4, 4])
  aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4])

test/distributed/tensor/test_dtensor_ops.py

Lines changed: 3 additions & 3 deletions
@@ -670,7 +670,7 @@ def test_mean(self):
             .to(DEVICE_TYPE)
         )
 
-        for is_evenly_shardable in [True]:
+        for is_evenly_shardable in [True, False]:
             if is_evenly_shardable:
                 placement = [Shard(1)]
                 reduce_dim = 1
@@ -686,9 +686,9 @@ def test_mean(self):
             self.assertEqual(full_tensor, tensor.mean(dim=reduce_dim))
 
             if is_evenly_shardable:
-                self.assertFalse("redistribute_input" in debug_mode.debug_string())
+                self.assertTrue("[P] -> [R]" in debug_mode.debug_string())
             else:
-                self.assertTrue("redistribute_input" in debug_mode.debug_string())
+                self.assertTrue("[S(0)] -> [R])" in debug_mode.debug_string())


 # only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)

torch/distributed/tensor/_dispatch.py

Lines changed: 4 additions & 8 deletions
@@ -23,11 +23,8 @@
 )
 from torch.distributed.tensor._utils import try_find_mesh_from_args
 from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
-from torch.utils._debug_mode import DebugMode
-from torch.utils._python_dispatch import (
-    _get_current_dispatch_mode,
-    return_and_correct_aliasing,
-)
+from torch.utils._debug_mode import get_active_debug_mode
+from torch.utils._python_dispatch import return_and_correct_aliasing


 try:
@@ -338,8 +335,7 @@ def redistribute_local_args(
     suggested_input_schema: OpSchema,
     use_val_from_redistribute_schema: bool,
 ) -> None:
-    debug_mode = _get_current_dispatch_mode()
-    in_debug_mode = isinstance(debug_mode, DebugMode)
+    debug_mode = get_active_debug_mode()

     # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
     if op_info.args_tree_spec is not None:
@@ -359,7 +355,7 @@ def redistribute_local_args(
                 debug_mode.record_redistribute_calls(  # type: ignore[union-attr]
                     i, arg_spec, reshard_arg_spec
                 )
-                if in_debug_mode
+                if debug_mode is not None
                 else contextlib.nullcontext()
             )


torch/distributed/tensor/_redistribute.py

Lines changed: 96 additions & 80 deletions
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
+import contextlib
 import logging
 from functools import cache
 from typing import cast, NamedTuple, Optional
@@ -16,6 +17,7 @@
     Replicate,
     Shard,
 )
+from torch.utils._debug_mode import get_active_debug_mode


 logger = logging.getLogger(__name__)
@@ -187,92 +189,106 @@ def redistribute_local_tensor(
     else:
         transform_infos = _gen_transform_infos(current_spec, target_spec)

-    for transform_info in transform_infos:
-        i = transform_info.mesh_dim
-        current, target = transform_info.src_dst_placements
-        device_mesh.size(mesh_dim=i)
-
-        if current == target:
-            # short cut, just use the original local tensor
-            new_local_tensor = local_tensor
-            continue
+    debug_mode = get_active_debug_mode()
+    redistribute_context = (
+        debug_mode.record_redistribute_calls(  # type: ignore[union-attr]
+            local_tensor, current_spec, target_spec
+        )
+        if debug_mode is not None
+        else contextlib.nullcontext()
+    )
+
+    with redistribute_context:
+        for transform_info in transform_infos:
+            i = transform_info.mesh_dim
+            current, target = transform_info.src_dst_placements
+            device_mesh.size(mesh_dim=i)
+
+            if current == target:
+                # short cut, just use the original local tensor
+                new_local_tensor = local_tensor
+                continue

-        logger.debug("redistribute from %s to %s on mesh dim %s", current, target, i)
+            logger.debug(
+                "redistribute from %s to %s on mesh dim %s", current, target, i
+            )

-        if target.is_replicate():
-            # Case 1: target is Replicate
-            if current.is_partial():
-                partial_spec = cast(Partial, current)
-                new_local_tensor = partial_spec._reduce_value(
-                    local_tensor, device_mesh, i
-                )
-            elif current.is_shard():
-                current_placement = cast(Shard, current)
-                new_local_tensor = current_placement._to_replicate_tensor(
-                    local_tensor, device_mesh, i, transform_info.logical_shape
-                )
-            else:
-                raise RuntimeError(
-                    f"redistribute from {current} to {target} not supported yet"
-                )
-        elif target.is_shard():
-            # Case 2: target is Shard
-            target_placement = cast(Shard, target)
-            if current.is_partial():
-                partial_spec = cast(Partial, current)
-                new_local_tensor = partial_spec._reduce_shard_value(
-                    local_tensor, device_mesh, i, target_placement
-                )
-            elif current.is_replicate():
-                # split the tensor and return the corresponding cloned local shard
-                new_local_tensor = target_placement._replicate_to_shard(
-                    local_tensor, device_mesh, i, my_coordinate[i]
-                )
-            else:
-                assert current.is_shard(), (
-                    f"Current placement should be shard but found {current}"
-                )
-                shard_spec = cast(Shard, current)
-                if shard_spec.dim != target_placement.dim:
-                    new_local_tensor = shard_spec._to_new_shard_dim(
-                        local_tensor,
-                        device_mesh,
-                        i,
-                        transform_info.logical_shape,
-                        target_placement.dim,
+            if target.is_replicate():
+                # Case 1: target is Replicate
+                if current.is_partial():
+                    partial_spec = cast(Partial, current)
+                    new_local_tensor = partial_spec._reduce_value(
+                        local_tensor, device_mesh, i
                     )
-        elif target.is_partial():
-            if current.is_replicate():
-                partial_spec = cast(Partial, target)
-                # skip the replicate to partial transformation when we are in backward pass
-                # In this case we keep the grad as replicate, this is because we don't
-                # want to convert the replicated gradients back to partial, although
-                # that's logically conform with the same layout, converting the gradients
-                # back to partial is actually useless as you would have to do reduce later
-                # which would be more expensive than keeping it replicate! For this reason,
-                # we keep the replicate grad here.
-                new_local_tensor = (
-                    partial_spec._partition_value(local_tensor, device_mesh, i)
-                    if not is_backward
-                    else local_tensor
-                )
-            elif current.is_shard():
-                if not is_backward:
+                elif current.is_shard():
+                    current_placement = cast(Shard, current)
+                    new_local_tensor = current_placement._to_replicate_tensor(
+                        local_tensor, device_mesh, i, transform_info.logical_shape
+                    )
+                else:
                     raise RuntimeError(
                         f"redistribute from {current} to {target} not supported yet"
                     )
-                # for backward shard -> partial, we just need to convert the shard to replicate
-                current_placement = cast(Shard, current)
-                new_local_tensor = current_placement._to_replicate_tensor(
-                    local_tensor, device_mesh, i, transform_info.logical_shape
-                )
-            else:
-                # partial -> partial no op, should never hit
-                new_local_tensor = local_tensor
-
-        if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
-            new_local_tensor = new_local_tensor.wait()
-        local_tensor = new_local_tensor
+            elif target.is_shard():
+                # Case 2: target is Shard
+                target_placement = cast(Shard, target)
+                if current.is_partial():
+                    partial_spec = cast(Partial, current)
+                    new_local_tensor = partial_spec._reduce_shard_value(
+                        local_tensor, device_mesh, i, target_placement
+                    )
+                elif current.is_replicate():
+                    # split the tensor and return the corresponding cloned local shard
+                    new_local_tensor = target_placement._replicate_to_shard(
+                        local_tensor, device_mesh, i, my_coordinate[i]
+                    )
+                else:
+                    assert current.is_shard(), (
+                        f"Current placement should be shard but found {current}"
+                    )
+                    shard_spec = cast(Shard, current)
+                    if shard_spec.dim != target_placement.dim:
+                        new_local_tensor = shard_spec._to_new_shard_dim(
+                            local_tensor,
+                            device_mesh,
+                            i,
+                            transform_info.logical_shape,
+                            target_placement.dim,
+                        )
+            elif target.is_partial():
+                if current.is_replicate():
+                    partial_spec = cast(Partial, target)
+                    # skip the replicate to partial transformation when we are in backward pass
+                    # In this case we keep the grad as replicate, this is because we don't
+                    # want to convert the replicated gradients back to partial, although
+                    # that's logically conform with the same layout, converting the gradients
+                    # back to partial is actually useless as you would have to do reduce later
+                    # which would be more expensive than keeping it replicate! For this reason,
+                    # we keep the replicate grad here.
+                    new_local_tensor = (
+                        partial_spec._partition_value(local_tensor, device_mesh, i)
+                        if not is_backward
+                        else local_tensor
+                    )
+                elif current.is_shard():
+                    if not is_backward:
+                        raise RuntimeError(
+                            f"redistribute from {current} to {target} not supported yet"
+                        )
+                    # for backward shard -> partial, we just need to convert the shard to replicate
+                    current_placement = cast(Shard, current)
+                    new_local_tensor = current_placement._to_replicate_tensor(
+                        local_tensor, device_mesh, i, transform_info.logical_shape
+                    )
+                else:
+                    # partial -> partial no op, should never hit
+                    new_local_tensor = local_tensor
+
+            if not async_op and isinstance(
+                new_local_tensor, funcol.AsyncCollectiveTensor
+            ):
+                new_local_tensor = new_local_tensor.wait()
+            local_tensor = new_local_tensor
     return new_local_tensor

torch/utils/_debug_mode.py

Lines changed: 16 additions & 2 deletions
@@ -1,14 +1,19 @@
 # mypy: allow-untyped-defs
 import contextlib
+from typing import Optional

 import torch
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
 from torch.utils._dtype_abbrs import dtype_abbrs
-from torch.utils._python_dispatch import _get_current_dispatch_mode, TorchDispatchMode
+from torch.utils._python_dispatch import (
+    _get_current_dispatch_mode,
+    _get_current_dispatch_mode_stack,
+    TorchDispatchMode,
+)
 from torch.utils._pytree import tree_map


-__all__ = ["DebugMode"]
+__all__ = ["DebugMode", "get_active_debug_mode"]

 REDISTRIBUTE_FUNC = "redistribute_input"

@@ -168,3 +173,12 @@ def debug_string(self) -> str:
             for op, args, kwargs, depth in self.operators
         )
         return result
+
+
+def get_active_debug_mode() -> Optional[DebugMode]:
+    debug_mode = None
+    for mode in _get_current_dispatch_mode_stack():
+        if isinstance(mode, DebugMode):
+            debug_mode = mode
+            break
+    return debug_mode
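
Because get_active_debug_mode() walks the whole dispatch-mode stack rather than checking only the innermost mode (the old _get_current_dispatch_mode() plus isinstance pattern), a DebugMode nested under other TorchDispatchModes is still found. A small illustrative helper (not part of this commit) showing the guard pattern the call sites above use:

import contextlib
from torch.utils._debug_mode import get_active_debug_mode

def maybe_record_redistribute(arg, src_spec, dst_spec):
    # Returns DebugMode's recording context when a DebugMode is active anywhere
    # on the dispatch-mode stack, otherwise a no-op context manager.
    debug_mode = get_active_debug_mode()
    return (
        debug_mode.record_redistribute_calls(arg, src_spec, dst_spec)
        if debug_mode is not None
        else contextlib.nullcontext()
    )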
