Skip to content

Commit 124dd36

Browse files
xmfan and pytorchmergebot
authored and committed
[hop] support local_map + SAC (pytorch#163322)
Some ops, such as the local_map HOP in deferred mode, are not desugared by make_fx. This means that when we apply SAC tags, we need to define dispatch rules for the SAC torch dispatch modes, as pointed out here: pytorch#162246 (comment). This PR adds those rules. Additionally, it fixes a pre-existing issue where we weren't coercing the tangent layout (as AOTAutograd typically does) when partitioning the HOP joint. Pull Request resolved: pytorch#163322 Approved by: https://github.com/ezyang
1 parent 20eeb54 commit 124dd36

File tree

2 files changed

+286
-40
lines changed

2 files changed

+286
-40
lines changed

test/higher_order_ops/test_local_map.py

Lines changed: 264 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# flake8: noqa: B950
33

44

5+
import contextlib
6+
import functools
57
import unittest
68

79
import torch
@@ -12,19 +14,36 @@
1214
import torch.nn.functional as F
1315
from torch import nn
1416
from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable
17+
from torch.nn.attention import sdpa_kernel, SDPBackend
18+
from torch.utils.checkpoint import create_selective_checkpoint_contexts
1519

1620

1721
if torch.distributed.is_available():
1822
from torch.distributed._tensor.experimental import local_map
1923
from torch.distributed.tensor.placement_types import Replicate, Shard
2024

21-
from torch.testing._internal.common_utils import run_tests, TEST_WITH_CROSSREF, TestCase
22-
from torch.testing._internal.triton_utils import requires_cuda_and_triton
25+
from torch.testing._internal.common_utils import (
26+
run_tests,
27+
TEST_WITH_CROSSREF,
28+
TEST_WITH_TORCHDYNAMO,
29+
TEST_WITH_TORCHINDUCTOR,
30+
TestCase,
31+
)
2332

2433

2534
nested_compile_region = torch.compiler.nested_compile_region
2635

2736

37+
def get_skip_reasons():
38+
msg = ""
39+
if not torch.distributed.is_available():
40+
msg += "Torch distributed not available. "
41+
if TEST_WITH_TORCHINDUCTOR or TEST_WITH_TORCHDYNAMO:
42+
msg += "Already manually torch.compile'd. "
43+
44+
return msg != "", msg
45+
46+
2847
class MyTransform(torch.autograd.Function):
2948
@staticmethod
3049
def forward(ctx, x):
@@ -42,7 +61,14 @@ def context_parallel_attention(query, key, value):
4261
return out
4362

4463

45-
def create_model(attention_fn, nheads, dim1, dim2):
64+
# NOTE: we use this function directly in the node checks
65+
def save_scalar_muls(ctx, op, *args, **kwargs):
66+
if op == torch.ops.aten.mul.Scalar:
67+
return torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE
68+
return torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
69+
70+
71+
def create_model(attention_fn, nheads, dim1, dim2, sac_policy=None):
4672
class LocalMapTransformerBlock(nn.Module):
4773
def __init__(self, nheads, dim1, dim2):
4874
super().__init__()
@@ -54,8 +80,14 @@ def __init__(self, nheads, dim1, dim2):
5480
self.wo = nn.Linear(dim1, dim1, bias=bias)
5581
self.w1 = nn.Linear(dim1, dim2, bias=bias)
5682
self.w2 = nn.Linear(dim2, dim1, bias=bias)
57-
58-
def forward(self, x):
83+
if sac_policy:
84+
self.sac_context_fn = functools.partial(
85+
create_selective_checkpoint_contexts, sac_policy
86+
)
87+
else:
88+
self.sac_context_fn = None
89+
90+
def _forward(self, x):
5991
q = self.wq(x)
6092
k = self.wk(x)
6193
v = self.wv(x)
@@ -78,41 +110,63 @@ def forward(self, x):
78110
o = o0 + o
79111
return o
80112

113+
def forward(self, x):
114+
if self.sac_context_fn is not None:
115+
return torch.utils.checkpoint.checkpoint(
116+
self._forward,
117+
x,
118+
use_reentrant=False,
119+
context_fn=self.sac_context_fn,
120+
)
121+
return self._forward(x)
122+
81123
return LocalMapTransformerBlock(nheads, dim1, dim2)
82124

83125

84-
class TestLocalMap(TestCase):
85-
@requires_cuda_and_triton
86-
@unittest.skipIf(
87-
not torch.distributed.is_available(), "Torch distributed not available."
126+
def get_local_mapped_functions():
    """Build two ``local_map``-wrapped versions of ``context_parallel_attention``.

    The first wrapper is created with the ``@local_map(...)`` decorator form
    around a thin forwarding function; the second passes
    ``context_parallel_attention`` directly to ``local_map(...)``. Both share
    the same input placements and options. They differ only in how
    ``out_placements`` is spelled: a one-element tuple of placements for the
    decorator form, a bare placements tuple for the functional form.

    Returns:
        tuple: ``(cp_decorated, cp_function)``.

    Raises:
        AssertionError: if ``torch.distributed`` is not available.
    """
    assert torch.distributed.is_available(), (
        "get_local_mapped_functions requires torch.distributed"
    )

    # Options common to both wrappers, kept in one place so the two variants
    # cannot drift apart.
    shared_kwargs = dict(
        in_placements=(
            (Shard(0), Shard(1), Shard(2)),  # query
            (Shard(0), Shard(1), Replicate()),  # key
            (Shard(0), Shard(1), Replicate()),  # value
        ),
        redistribute_inputs=True,
        in_grad_placements=None,
        device_mesh=None,
    )

    @local_map(
        out_placements=((Shard(0), Shard(1), Shard(2)),),
        **shared_kwargs,
    )
    def cp_decorated(query, key, value):
        return context_parallel_attention(query, key, value)

    cp_function = local_map(
        context_parallel_attention,
        out_placements=(Shard(0), Shard(1), Shard(2)),
        **shared_kwargs,
    )

    return cp_decorated, cp_function
157+
158+
159+
class TestLocalMap(TestCase):
160+
def setUp(self):
    """Run the base fixture, then pin SDPA to the math backend for the test.

    The ``sdpa_kernel(backends=[SDPBackend.MATH])`` context is entered through
    an ``ExitStack`` stored on ``self.exit_stack`` so that ``tearDown`` can
    exit it again.
    """
    # Let the base TestCase perform its own per-test setup first; overriding
    # setUp without chaining would silently skip it.
    super().setUp()
    self.exit_stack = contextlib.ExitStack()
    self.exit_stack.enter_context(sdpa_kernel(backends=[SDPBackend.MATH]))

def tearDown(self):
    """Exit the SDPA backend context entered in ``setUp``, then run the base teardown."""
    try:
        self.exit_stack.close()
    finally:
        # Always give the base class a chance to clean up, even if closing
        # the stack raised.
        super().tearDown()
166+
167+
@unittest.skipIf(*get_skip_reasons())
89168
def test_simple(self):
90-
@local_map(
91-
out_placements=((Shard(0), Shard(1), Shard(2)),),
92-
in_placements=(
93-
(Shard(0), Shard(1), Shard(2)), # query
94-
(Shard(0), Shard(1), Replicate()), # key
95-
(Shard(0), Shard(1), Replicate()), # value
96-
),
97-
redistribute_inputs=True,
98-
in_grad_placements=None,
99-
device_mesh=None,
100-
)
101-
def cp_decorated(query, key, value):
102-
return context_parallel_attention(query, key, value)
103-
104-
cp_function = local_map(
105-
context_parallel_attention,
106-
out_placements=(Shard(0), Shard(1), Shard(2)),
107-
in_placements=(
108-
(Shard(0), Shard(1), Shard(2)), # query
109-
(Shard(0), Shard(1), Replicate()), # key
110-
(Shard(0), Shard(1), Replicate()), # value
111-
),
112-
redistribute_inputs=True,
113-
in_grad_placements=None,
114-
device_mesh=None,
115-
)
169+
cp_decorated, cp_function = get_local_mapped_functions()
116170
bs = 8 * 1
117171
dim1 = 96
118172
dim2 = dim1 * 4
@@ -123,21 +177,24 @@ def cp_decorated(query, key, value):
123177

124178
backend = EagerAndRecordGraphs()
125179

126-
model = create_model(cp_decorated, nheads, dim1, dim2).cuda()
127-
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True).cuda(),)
180+
model = create_model(cp_decorated, nheads, dim1, dim2)
181+
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
128182
with LocalMapWrappedHigherOrderVariable.enable():
129183
out = torch.compile(model, backend=backend)(*inputs)
130184
out.sum().backward()
131185

132-
model = create_model(cp_function, nheads, dim1, dim2).cuda()
133-
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True).cuda(),)
186+
model = create_model(cp_function, nheads, dim1, dim2)
187+
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
134188
with LocalMapWrappedHigherOrderVariable.enable():
135189
out = torch.compile(model, backend=backend)(*inputs)
136190
out.sum().backward()
137191

138192
if not TEST_WITH_CROSSREF:
139193
self.assertEqual(len(backend.graphs), 2)
140-
# should see local_map_hop in both
194+
self.assertEqual(
195+
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
196+
normalize_gm(backend.graphs[1].print_readable(print_output=False)),
197+
)
141198
self.assertExpectedInline(
142199
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
143200
"""\
@@ -193,10 +250,177 @@ def forward(self, q_1: "f32[8, 16, 16, 6]", k_1: "f32[8, 16, 16, 6]", v_1: "f32[
193250
""",
194251
)
195252

253+
@unittest.skipIf(*get_skip_reasons())
def test_sac(self):
    """Check selective activation checkpointing (SAC) together with local_map.

    Two models are built with ``create_model`` and the ``save_scalar_muls``
    SAC policy: one using the decorator-form local_map attention, one using
    the functional form. Each is compiled with the ``AotEagerAndRecordGraphs``
    backend and run forward and backward. Unless ``TEST_WITH_CROSSREF`` is
    set, the test then asserts that:

    * exactly two Dynamo graphs were recorded, and the Dynamo, forward and
      backward graphs of the two variants print identically after
      ``normalize_gm``;
    * the first Dynamo graph has exactly one ``tag_activation_checkpoint``
      call;
    * every forward-graph node carrying ``"recompute"`` metadata has the value
      ``save_scalar_muls`` returns for its target; MUST_SAVE nodes appear both
      in the forward outputs and the backward placeholders, and
      MUST_RECOMPUTE nodes do not appear among the backward placeholders.
    """
    from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm

    policy = torch.utils.checkpoint.CheckpointPolicy

    def readable(gm):
        # Normalized printed form of a recorded graph module, for comparison.
        return normalize_gm(gm.print_readable(print_output=False))

    bs = 8 * 1
    dim1 = 96
    dim2 = dim1 * 4
    nheads = 16
    seq_len = 16

    backend = AotEagerAndRecordGraphs()

    # Decorator form first, then functional form; graphs are recorded on the
    # shared backend in that order.
    for attention_fn in get_local_mapped_functions():
        model = create_model(
            attention_fn, nheads, dim1, dim2, sac_policy=save_scalar_muls
        )
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        with LocalMapWrappedHigherOrderVariable.enable():
            out = torch.compile(model, backend=backend)(*inputs)
            out.sum().backward()

    if TEST_WITH_CROSSREF:
        return

    self.assertEqual(len(backend.graphs), 2)
    for recorded in (backend.graphs, backend.fw_graphs, backend.bw_graphs):
        self.assertEqual(readable(recorded[0]), readable(recorded[1]))

    tagged = backend.graphs[0].graph.find_nodes(
        op="call_function",
        target=torch._higher_order_ops.wrap.tag_activation_checkpoint,
    )
    self.assertEqual(len(tagged), 1)

    # TODO: add joint to the testing compile backend
    fw_graph = backend.fw_graphs[0].graph
    fw_outs = {n.name for n in fw_graph.find_nodes(op="output")[0].args[0]}
    bw_ins = {
        n.name for n in backend.bw_graphs[0].graph.find_nodes(op="placeholder")
    }
    for node in fw_graph.nodes:
        if "recompute" not in node.meta:
            continue
        expected = save_scalar_muls(None, node.target, None, None)
        actual = node.meta["recompute"]
        self.assertEqual(expected, actual)
        if actual == policy.MUST_SAVE:
            # assertIn reports which node is missing on failure.
            self.assertIn(node.name, fw_outs)
            self.assertIn(node.name, bw_ins)
        elif actual == policy.MUST_RECOMPUTE:
            # can still be in fw_outs for post-graph bytecode
            self.assertNotIn(node.name, bw_ins)
325+
326+
@unittest.skipIf(*get_skip_reasons())
def test_sac_deferred(self):
    """SAC + local_map with inlining of the local_map HOP deferred.

    Same setup as the non-deferred SAC test, but compilation additionally runs
    under ``torch._higher_order_ops.local_map.defer_inlining()``. An
    ``AttributeError`` about the missing ``call_local_map`` attribute is
    tolerated for now (its message is checked); any other exception
    propagates. Unless ``TEST_WITH_CROSSREF`` is set, the Dynamo and forward
    graphs of the two variants are compared and the ``"recompute"`` tags on
    forward nodes are checked against ``save_scalar_muls``. Backward-graph
    checks are disabled (see TODOs).
    """
    # This test is in a bit of a weird state, it needs compositional compile API
    # so that we can defer inlining for up until AOTAutograd stage 1.
    # Then we should be inlined by stage 2. But we can't do that today.
    from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm

    policy = torch.utils.checkpoint.CheckpointPolicy
    tolerated_error = "module 'torch._higher_order_ops.local_map' has no attribute 'call_local_map'"

    def readable(gm):
        # Normalized printed form of a recorded graph module, for comparison.
        return normalize_gm(gm.print_readable(print_output=False))

    bs = 8 * 1
    dim1 = 96
    dim2 = dim1 * 4
    nheads = 16
    seq_len = 16

    backend = AotEagerAndRecordGraphs()

    # Decorator form first, then functional form; graphs are recorded on the
    # shared backend in that order.
    for attention_fn in get_local_mapped_functions():
        model = create_model(
            attention_fn, nheads, dim1, dim2, sac_policy=save_scalar_muls
        )
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        try:
            with (
                LocalMapWrappedHigherOrderVariable.enable(),
                torch._higher_order_ops.local_map.defer_inlining(),
            ):
                out = torch.compile(model, backend=backend)(*inputs)
                out.sum().backward()
        except AttributeError as e:
            # TODO: get rid of this when we can install as a subgraph
            self.assertIn(tolerated_error, str(e))

    # TODO: re-enable tests on backward when we can install as a subgraph
    if TEST_WITH_CROSSREF:
        return

    self.assertEqual(len(backend.graphs), 2)
    # TODO: also compare backend.bw_graphs[0] against backend.bw_graphs[1]
    # once the backward can be compiled.
    for recorded in (backend.graphs, backend.fw_graphs):
        self.assertEqual(readable(recorded[0]), readable(recorded[1]))

    tagged = backend.graphs[0].graph.find_nodes(
        op="call_function",
        target=torch._higher_order_ops.wrap.tag_activation_checkpoint,
    )
    self.assertEqual(len(tagged), 1)

    # TODO: add joint to the testing compile backend
    fw_graph = backend.fw_graphs[0].graph
    fw_outs = {n.name for n in fw_graph.find_nodes(op="output")[0].args[0]}
    for node in fw_graph.nodes:
        if "recompute" not in node.meta:
            continue
        expected = save_scalar_muls(None, node.target, None, None)
        actual = node.meta["recompute"]
        self.assertEqual(expected, actual)
        if actual == policy.MUST_SAVE:
            self.assertIn(node.name, fw_outs)
        # TODO: once backward graphs are available, also assert that
        # MUST_SAVE nodes are placeholders of backend.bw_graphs[0] and that
        # MUST_RECOMPUTE nodes are not (they can still be in fw_outs for
        # post-graph bytecode).
200424

201425

202426
if __name__ == "__main__":

0 commit comments

Comments
 (0)