
Commit 2b92b31

ruisizhang123 authored and pytorchmergebot committed
[simplefsdp] fix DSV3 autobucketing issue (pytorch#167797)
Fix for this issue in the DSV3 autobucketing pass: pytorch/torchtitan#2037. Users should now be able to run DSV3 autobucketing end to end. It fixes three things:

(1) a bug in the NCCL estimation support for all-to-all;
(2) for dynamic token dispatch/combine in MoE, a fall_back value hint is added to all-to-all's collective size estimation;
(3) previously, for the schedulable-node check, I directly modified `is_wait` in bucketing.py; it is safer to add these criteria in overlap_scheduling.py as a separate function, `_schedulable_wait_node`.

Pull Request resolved: pytorch#167797
Approved by: https://github.com/eellison
1 parent db1551b commit 2b92b31

File tree

6 files changed: +234 -30 lines
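As a rough illustration of fix (3), here is a minimal usage sketch (not part of the commit) of the `_schedulable_wait_node` helper added to torch/_inductor/fx_passes/bucketing.py below: a scheduling pass keeps only the wait nodes whose producer is a callable, supported NCCL collective. The `pick_schedulable_waits` helper name is hypothetical.

import torch.fx as fx

from torch._inductor.fx_passes.bucketing import _schedulable_wait_node


def pick_schedulable_waits(graph: fx.Graph) -> list[fx.Node]:
    # Hypothetical helper: keep only wait_tensor nodes that the overlap
    # scheduler is allowed to reorder, i.e. waits whose producing node is a
    # call_function on a supported NCCL collective.
    return [n for n in graph.nodes if _schedulable_wait_node(n)]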

test/distributed/test_aten_comm_compute_reordering.py

Lines changed: 44 additions & 0 deletions
@@ -10,6 +10,7 @@
 
 # for some reason importing functional collectives after dynamo breaks collectives handling!
 import torch.distributed._functional_collectives as _functional_collectives
+import torch.fx as fx
 from torch._C import FileCheck
 from torch._dynamo.utils import counters, same
 from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
@@ -238,6 +239,49 @@ def func(a, *, tag, ranks, group_size):
         self.assertTrue(same(out, correct))
         self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)
 
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @torch._inductor.config.patch(get_patches())
+    def test_schedulable_wait(self):
+        """Test that if a wait node is scheduable or not."""
+        from torch._inductor.fx_passes.bucketing import _schedulable_wait_node
+
+        def test_graph():
+            graph = fx.Graph()
+
+            inp = graph.placeholder("inp")
+            group_size = graph.placeholder("group_size")
+            group_name = graph.placeholder("group_name")
+
+            ag_0_out = graph.call_function(
+                torch.ops._c10d_functional.all_gather_into_tensor.default,
+                args=(inp, group_size, group_name),
+            )
+            ag_0_wait = graph.call_function(
+                torch.ops._c10d_functional.wait_tensor.default,
+                args=(ag_0_out,),
+            )
+            ag_1_out = graph.call_function(
+                torch.ops._c10d_functional.all_gather_into_tensor.default,
+                args=(ag_0_wait, group_size, group_name),
+            )
+            ag_1_wait = graph.call_function(
+                torch.ops._c10d_functional.wait_tensor.default,
+                args=(ag_1_out,),
+            )
+            ag_2_wait = graph.call_function(
+                torch.ops._c10d_functional.wait_tensor.default,
+                args=(ag_1_wait,),
+            )
+
+            graph.output(ag_2_wait)
+            return graph
+
+        graph = test_graph()
+        schedulable = {"wait_tensor_default", "wait_tensor_default_1"}
+        for node in list(graph.nodes):
+            expected = node.name in schedulable
+            assert _schedulable_wait_node(node) is expected
+
     @torch._inductor.config.patch(get_patches())
     def test_reorder_compute_for_overlap_mul(self):
         def func(a, *, tag, ranks, group_size):

test/distributed/test_inductor_collectives.py

Lines changed: 141 additions & 2 deletions
@@ -23,7 +23,12 @@
     sink_waits_iterative,
 )
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
-from torch._inductor.fx_passes.bucketing import is_all_gather_into_tensor
+from torch._inductor.fx_passes.bucketing import (
+    is_all_gather_into_tensor,
+    is_all_reduce_tensor,
+    is_all_to_all_tensor,
+    is_reduce_scatter_tensor,
+)
 from torch._inductor.scheduler import (
     _get_mm_like_fn,
     BaseSchedulerNode,
@@ -2188,7 +2193,7 @@ def test_sync_decision_cross_ranks(self):
         self.assertEqual(saved_values, [wt1])
 
     @skip_if_lt_x_gpu(2)
-    def test_comm_analysis(self):
+    def test_all_gather_comm_analysis(self):
         store = c10d.FileStore(self.file_name, self.world_size)
         torch.cuda.set_device(self.rank)
         c10d.init_process_group(
@@ -2229,6 +2234,140 @@ def func(inp, group_size, group_name):
                 )
                 assert est_ms_nccl > 0
 
+    @skip_if_lt_x_gpu(2)
+    def test_reduce_scatter_comm_analysis(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        torch.cuda.set_device(self.rank)
+        c10d.init_process_group(
+            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
+        )
+        group = c10d.distributed_c10d._get_default_group()
+        group_name = "default"
+        torch._C._distributed_c10d._register_process_group(
+            group_name, torch.distributed.group.WORLD
+        )
+        group_size = group.size()
+
+        def func(inp, group_size, group_name):
+            rs_0_out = torch.ops._c10d_functional.reduce_scatter_tensor(
+                inp, "sum", group_size, group_name
+            )
+            rs_0_wait = torch.ops.c10d_functional.wait_tensor(rs_0_out)
+            rs_1_out = torch.ops._c10d_functional.reduce_scatter_tensor(
+                rs_0_wait, "sum", group_size, group_name
+            )
+            rs_1_wait = torch.ops.c10d_functional.wait_tensor(rs_1_out)
+            return rs_1_wait
+
+        gm = make_fx(func)(torch.ones(4, 4, device=self.device), group_size, group_name)
+        g = gm.graph
+        for n in g.nodes:
+            if is_reduce_scatter_tensor(n):
+                from torch._inductor.comm_analysis import (
+                    estimate_nccl_collective_runtime_from_fx_node,
+                )
+
+                est_ms = estimate_nccl_collective_runtime_from_fx_node(
+                    n, use_nccl_estimator=False
+                )
+                assert est_ms > 0
+                est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
+                    n, use_nccl_estimator=True
+                )
+                assert est_ms_nccl > 0
+
+    @skip_if_lt_x_gpu(2)
+    def test_all_reduce_comm_analysis(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        torch.cuda.set_device(self.rank)
+        c10d.init_process_group(
+            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
+        )
+        group = c10d.distributed_c10d._get_default_group()
+        group_name = "default"
+        torch._C._distributed_c10d._register_process_group(
+            group_name, torch.distributed.group.WORLD
+        )
+        group_size = group.size()
+
+        def func(inp, group_size, group_name):
+            ar_0_out = torch.ops._c10d_functional.all_reduce(inp, "sum", group_name)
+            ar_0_wait = torch.ops.c10d_functional.wait_tensor(ar_0_out)
+            ar_1_out = torch.ops._c10d_functional.all_reduce(
+                ar_0_wait, "sum", group_name
+            )
+            ar_1_wait = torch.ops.c10d_functional.wait_tensor(ar_1_out)
+            return ar_1_wait
+
+        gm = make_fx(func)(torch.ones(4, 4, device=self.device), group_size, group_name)
+        g = gm.graph
+        for n in g.nodes:
+            if is_all_reduce_tensor(n):
+                from torch._inductor.comm_analysis import (
+                    estimate_nccl_collective_runtime_from_fx_node,
+                )
+
+                est_ms = estimate_nccl_collective_runtime_from_fx_node(
+                    n, use_nccl_estimator=False
+                )
+                assert est_ms > 0
+                est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
+                    n, use_nccl_estimator=True
+                )
+                assert est_ms_nccl > 0
+
+    @skip_if_lt_x_gpu(2)
+    def test_all_to_all_comm_analysis(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        torch.cuda.set_device(self.rank)
+        c10d.init_process_group(
+            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
+        )
+        group = c10d.distributed_c10d._get_default_group()
+        group_name = "default"
+        torch._C._distributed_c10d._register_process_group(
+            group_name, torch.distributed.group.WORLD
+        )
+        group_size = group.size()
+
+        def func(inp, group_size, group_name):
+            chunk = inp.numel() // self.world_size
+            split_sizes = [chunk] * self.world_size
+            a2a_0_out = torch.ops._c10d_functional.all_to_all_single(
+                inp,
+                split_sizes,
+                split_sizes,
+                group_name,
+            )
+            a2a_0_wait = torch.ops.c10d_functional.wait_tensor(a2a_0_out)
+            a2a_1_out = torch.ops._c10d_functional.all_to_all_single(
+                a2a_0_wait,
+                split_sizes,
+                split_sizes,
+                group_name,
+            )
+            a2a_1_wait = torch.ops.c10d_functional.wait_tensor(a2a_1_out)
+            return a2a_1_wait
+
+        gm = make_fx(func)(
+            torch.ones(group_size * 4, 1, device=self.device), group_size, group_name
+        )
+        g = gm.graph
+        for n in g.nodes:
+            if is_all_to_all_tensor(n):
+                from torch._inductor.comm_analysis import (
+                    estimate_nccl_collective_runtime_from_fx_node,
+                )
+
+                est_ms = estimate_nccl_collective_runtime_from_fx_node(
+                    n, use_nccl_estimator=False
+                )
+                assert est_ms > 0
+                est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
+                    n, use_nccl_estimator=True
+                )
+                assert est_ms_nccl > 0
+
     @skip_if_lt_x_gpu(2)
     @requires_gloo()
     def test_regression_use_nccl_estimate_with_gloo(self):

torch/_inductor/comm_analysis.py

Lines changed: 7 additions & 7 deletions
@@ -23,6 +23,7 @@ class NCCL_COLL(IntEnum):
     ALL_GATHER = 1
     REDUCE_SCATTER = 2
     ALL_TO_ALL = 3
+    UNSUPPORTED = 4
 
 
 class NVIDIA_GPU_TYPE(IntEnum):
@@ -53,10 +54,10 @@ def get_collective_type_from_kernel_name(kernel_name: str) -> NCCL_COLL:
         return NCCL_COLL.ALL_GATHER
     elif "reduce_scatter" in kernel_name:
         return NCCL_COLL.REDUCE_SCATTER
-    elif "torch.ops._dtensor.shard_dim_alltoall.default" in kernel_name:
+    elif any(comm in kernel_name for comm in ("all_to_all", "alltoall")):
         return NCCL_COLL.ALL_TO_ALL
     else:
-        raise ValueError(f"Unsupported collective kernel: {kernel_name}")
+        return NCCL_COLL.UNSUPPORTED
 
 
 def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
@@ -340,13 +341,12 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
 
 
 def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
-    size = 0
+    sz_bytes = 0
     for node in fx_node.all_input_nodes:
         if (t := node.meta.get("val")) is not None:
-            size += t.numel() * t.element_size()
-
-    # TODO - symbolic
-    return size
+            numel = get_size_numel(t.size())
+            sz_bytes += numel * get_dtype_size(t.dtype)
+    return sz_bytes
 
 
 def estimate_nccl_collective_runtime_from_fx_node(
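A minimal sketch (not part of the commit) of how the revised classifier behaves, assuming the earlier branches of `get_collective_type_from_kernel_name` only match all-gather/all-reduce/reduce-scatter names as shown above: all-to-all kernel names are now recognized by substring, and unknown kernel names map to `NCCL_COLL.UNSUPPORTED` instead of raising.

from torch._inductor.comm_analysis import (
    NCCL_COLL,
    get_collective_type_from_kernel_name,
)

# All-to-all variants are matched by substring ("all_to_all" or "alltoall").
assert get_collective_type_from_kernel_name("all_to_all_single") is NCCL_COLL.ALL_TO_ALL
assert get_collective_type_from_kernel_name("shard_dim_alltoall") is NCCL_COLL.ALL_TO_ALL

# Non-collective kernel names no longer raise ValueError; they are reported as
# UNSUPPORTED so callers such as _schedulable_wait_node can simply skip them.
assert get_collective_type_from_kernel_name("mm_default") is NCCL_COLL.UNSUPPORTED

Returning a sentinel instead of raising is what lets the schedulable-wait check below treat "wait on something that is not an NCCL collective" as "not schedulable" rather than as an error.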

torch/_inductor/fx_passes/bucketing.py

Lines changed: 28 additions & 1 deletion
@@ -10,6 +10,10 @@
 import torch.utils._pytree as pytree
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import detect_fake_mode
+from torch._inductor.comm_analysis import (
+    get_collective_type_from_kernel_name,
+    NCCL_COLL,
+)
 from torch._inductor.runtime.runtime_utils import dynamo_timed
 from torch._logging import trace_structured
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -52,6 +56,23 @@ def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
     return (group_name, reduce_op, dtype)
 
 
+def _schedulable_wait_node(node: torch.fx.Node) -> bool:
+    """
+    Add additional check on if the wait node is schedulable
+    We should not schedule a fx node that is:
+    1. wait on a collective that is not callable
+    2. wait on a non-NCCL communication node
+    """
+    if not is_wait_tensor(node):
+        return False
+    assert isinstance(node.args[0], torch.fx.Node)
+    assert isinstance(node.args[0].target.name(), str)
+    is_callable: bool = node.args[0].op == "call_function"
+    coll: NCCL_COLL = get_collective_type_from_kernel_name(node.args[0].target.name())
+    is_collective: bool = coll != NCCL_COLL.UNSUPPORTED
+    return is_callable and is_collective
+
+
 def bucket_key(node: torch.fx.Node, mode: BucketMode | None = None) -> object | None:
     if is_all_gather_into_tensor(node):
         group_key_fn = (
@@ -138,7 +159,6 @@ def is_wait_tensor(node: torch.fx.Node) -> bool:
     return (
         node.op == "call_function"
        and node.target is torch.ops._c10d_functional.wait_tensor.default
-        and node.args[0].op == "call_function"
     )
 
 
@@ -149,6 +169,13 @@ def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
     )
 
 
+def is_all_to_all_tensor(node: torch.fx.Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target is torch.ops._c10d_functional.all_to_all_single.default
+    )
+
+
 def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
     return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]
 
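For context, a minimal sketch (not from the commit) of how the bucketing predicates, including the new `is_all_to_all_tensor`, can be used to label collective start nodes in an fx graph; `classify_collective` is a hypothetical helper.

from typing import Optional

import torch.fx as fx

from torch._inductor.fx_passes.bucketing import (
    is_all_gather_into_tensor,
    is_all_reduce_tensor,
    is_all_to_all_tensor,
    is_reduce_scatter_tensor,
)


def classify_collective(node: fx.Node) -> Optional[str]:
    # Hypothetical helper: map an fx.Node to a collective label, or None if
    # the node is not a recognized c10d functional collective.
    if is_all_gather_into_tensor(node):
        return "all_gather"
    if is_reduce_scatter_tensor(node):
        return "reduce_scatter"
    if is_all_reduce_tensor(node):
        return "all_reduce"
    if is_all_to_all_tensor(node):
        return "all_to_all"
    return None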

torch/_inductor/fx_passes/overlap_manual_scheduling.py

Lines changed: 7 additions & 4 deletions
@@ -8,9 +8,9 @@
 import torch.fx as fx
 from torch._dynamo.graph_deduplication import _stable_topological_sort
 from torch._inductor.fx_passes.bucketing import (
+    _schedulable_wait_node,
     is_all_gather_into_tensor as is_all_gather,
     is_reduce_scatter_tensor as is_reduce_scatter,
-    is_wait_tensor,
     merge_all_gather_bucket,
     merge_reduce_scatter_bucket,
 )
@@ -36,7 +36,10 @@ class ManualOverlapPreservingBucketer(OverlapPreservingBucketer):
     """
 
     def __init__(
-        self, node_users: dict[fx.Node, OrderedSet[fx.Node]], *args: Any, **kwargs: Any
+        self,
+        node_users: dict[fx.Node, OrderedSet[fx.Node]],
+        *args: Any,
+        **kwargs: Any,
     ):
         super().__init__(*args, **kwargs)
         self.node_users = node_users
@@ -97,7 +100,7 @@ def _bucket_group(self, coll_nodes: list[fx.Node]) -> None:
         )
 
         # Identify the new wait and start
-        new_waits = [n for n in new_nodes if is_wait_tensor(n)]
+        new_waits = [n for n in new_nodes if _schedulable_wait_node(n)]
         assert len(new_waits) == 1, f"Expected exactly one new wait, got {new_waits}"
         new_wait = new_waits[0]
         new_start = new_wait.args[0]
@@ -186,7 +189,7 @@ def __init__(
     def _identify_collectives(self) -> None:
         """Identify all collective operations."""
         for node in self.nodes:
-            if is_wait_tensor(node):
+            if _schedulable_wait_node(node):
                 start = node.args[0]
                 info = CollectiveInfo(
                     start_node=start,
