
Commit 6eb71ce

mlazos authored and pytorchmergebot committed
[user-streams] Assign streams to gradient accum in bwd (pytorch#167513)
Pull Request resolved: pytorch#167513
Approved by: https://github.com/soulitzer
1 parent 2d14e86 commit 6eb71ce
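
Background for the change: when a tensor feeds more than one op, autograd sums the per-use gradients during backward, and in the AOTAutograd backward FX graph that sum appears as an explicit accumulation node (the new streams.py below identifies it via the is_gradient_acc entry in node.meta). Previously such nodes carried no stream annotation; with this commit they inherit one from an annotated neighbor. A minimal eager-mode illustration of the accumulation itself (CPU-only, no streams, purely for orientation, not taken from the diff):

import torch

x = torch.ones(2, 2, requires_grad=True)
y0 = 2 * x  # first use of x
z = 2 * x   # second use of x
(y0.sum() + z.sum()).backward()
# x.grad is the sum of both contributions (2 + 2 per element); in the traced
# backward graph this sum is the gradient accumulation node targeted here.
print(x.grad)  # tensor([[4., 4.], [4., 4.]])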

File tree

  test/dynamo/test_streams.py
  torch/_functorch/_aot_autograd/graph_capture.py
  torch/_functorch/_aot_autograd/streams.py

3 files changed: +119 -1 lines changed

test/dynamo/test_streams.py

Lines changed: 62 additions & 1 deletion
@@ -470,7 +470,7 @@ def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
         )
 
     @requires_cuda
-    def test_stream_backward(self) -> None:
+    def test_stream_backward_simple(self) -> None:
         def fn(x, y):
             s2 = torch.Stream()
             s0 = torch.Stream()
@@ -524,7 +524,68 @@ def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
         # Annotation: {'stream': 1}
         mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
 
+        # Annotation: {'stream': 0}
+        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
+        return (add_3, add_2)
+""",
+        )
+
+    @requires_cuda
+    def test_stream_backward_sync(self) -> None:
+        def fn(x, y):
+            s2 = torch.Stream()
+            s0 = torch.Stream()
+            with s0:
+                y0 = 2 * x + y
+            with s2:
+                z = 2 * x + y
+
+            return y0, z
+
+        inp = (
+            torch.ones(2, 2, device="cuda:0", requires_grad=True) + 1,
+            torch.ones(2, 2, device="cuda:0", requires_grad=True),
+        )
+        expected = fn(*inp)
+        (
+            actual,
+            _,
+            fw_graphs,
+            bw_graphs,
+        ) = extract_graph(fn, *inp)
+        self.assertEqual(len(fw_graphs), 1)
+        self.assertEqual(expected, actual)
+        self.assertExpectedInline(
+            print_graph(fw_graphs[0]),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
+        # Annotation: {'stream': 1}
+        mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
+        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
+
+        # Annotation: {'stream': 0}
+        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
+        return (add, add_1)
+""",
+        )
+
+        actual[1].sum().backward()
+        self.assertExpectedInline(
+            print_graph(bw_graphs[0]),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
+        # Annotation: {'stream': 0}
+        mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
+
         #
+        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
+
+        # Annotation: {'stream': 1}
+        mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
+
+        # Annotation: {'stream': 0}
         add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
         return (add_3, add_2)
 """,

torch/_functorch/_aot_autograd/graph_capture.py

Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,7 @@
     handle_effect_tokens_fn,
 )
 from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
+from .streams import assign_backward_streams
 from .utils import (
     call_and_expect_output_descs,
     copy_fwd_metadata_to_bw_nodes,
@@ -473,6 +474,9 @@ def aot_dispatch_autograd_graph(
     # fw node match might be erased
     copy_fwd_metadata_to_bw_nodes(fx_g)
 
+    # After copying metadata, assign streams to gradient accumulation nodes
+    assign_backward_streams(fx_g)
+
     fx_g.graph.eliminate_dead_code()
     if not aot_config.disable_functionalization:
         # There should be *NO* mutating ops in the graph at this point.
torch/_functorch/_aot_autograd/streams.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+from typing import Optional, TypeAlias
+
+import torch.fx
+import torch.fx.traceback
+from torch._dynamo.graph_utils import _get_flat_args
+
+
+Node: TypeAlias = torch.fx.Node
+
+
+def is_gradient_acc(node: Node) -> bool:
+    return node.meta.get("is_gradient_acc", False)
+
+
+def get_stream(node: Node) -> Optional[int]:
+    maybe_annotation = node.meta.get("custom", None)
+    if maybe_annotation is not None:
+        return node.meta["custom"].get("stream", None)
+    else:
+        return None
+
+
+def set_stream(node: Node, ind: int) -> None:
+    if "custom" in node.meta:
+        node.meta["custom"].update({"stream": ind})
+    else:
+        node.meta["custom"] = {"stream": ind}
+
+
+def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
+    """Assigns backward streams to gradient accumulation nodes"""
+
+    # NB: iterate in reverse order to more closely match eager
+    # the user node stream will be populated first
+    for node in reversed(list(gm.graph.nodes)):
+        if is_gradient_acc(node):
+            # Accumulation stream selection. Follow the rules from top to bottom to determine the accumulation stream:
+            # 1. Match first stream assignment of the first user with a stream
+            # 2. Match first stream assignment encountered in the args from left to right
+            # This differs from eager in some cases:
+            # Specifically the eager code uses the autograd node to determine the stream,
+            # crucially this does not necessarily correspond to the FX graph node. For example,
+            # in the backward for an add node with a constant we will passthrough and during backward tracing,
+            # no op will be added to the FX graph, so our stream assignment will differ in this case.
+            gradients = _get_flat_args(node, {})
+            users = list(node.users.keys())
+
+            # All gradients will be on same device, they will be coerced if they were not with a .to() node
+            for neighbor in users + gradients:
+                ind = get_stream(neighbor)
+                if ind is not None:
+                    set_stream(node, ind)
+                    break
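
The selection order above (first a user with a stream, then args left to right) can be exercised directly on a toy FX graph. A minimal sketch, not from the PR: it assumes a PyTorch build that includes this new module, and it sets the is_gradient_acc and custom metadata by hand, whereas in real runs AOTAutograd populates them during backward tracing:

import torch
import torch.fx
from torch._functorch._aot_autograd.streams import assign_backward_streams, get_stream

def f(x, y):
    a = x * 2
    b = y * 2
    return a + b  # stand-in for the gradient accumulation add

gm = torch.fx.symbolic_trace(f)
nodes = {n.name: n for n in gm.graph.nodes}

nodes["add"].meta["is_gradient_acc"] = True  # marked by AOTAutograd in real runs
nodes["mul"].meta["custom"] = {"stream": 0}  # first arg carries a stream annotation

# The add's only user is the graph output (no stream), so the scan falls
# through to the args and the accumulation inherits stream 0 from `mul`.
assign_backward_streams(gm)
print(get_stream(nodes["add"]))  # 0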
