# flake8: noqa: B950


import contextlib
import functools
import unittest

import torch
import torch.nn.functional as F
from torch import nn
from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.utils.checkpoint import create_selective_checkpoint_contexts


if torch.distributed.is_available():
    from torch.distributed._tensor.experimental import local_map
    from torch.distributed.tensor.placement_types import Replicate, Shard

from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_CROSSREF,
    TEST_WITH_TORCHDYNAMO,
    TEST_WITH_TORCHINDUCTOR,
    TestCase,
)


nested_compile_region = torch.compiler.nested_compile_region


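# Shared skip conditions: local_map requires torch.distributed, and these tests
# call torch.compile explicitly, so they are skipped when the suite is already
# being run under the dynamo/inductor test wrappers.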
def get_skip_reasons():
    msg = ""
    if not torch.distributed.is_available():
        msg += "Torch distributed not available. "
    if TEST_WITH_TORCHINDUCTOR or TEST_WITH_TORCHDYNAMO:
        msg += "Already manually torch.compile'd. "

    return msg != "", msg


class MyTransform(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ...  # (rest of MyTransform elided in this excerpt)


def context_parallel_attention(query, key, value):
    ...  # (implementation elided in this excerpt)
    return out


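# Selective activation checkpointing (SAC) policy used by the tests below:
# save the output of aten.mul.Scalar ops during the forward and recompute
# everything else in the backward.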
# NOTE: we use this function directly in the node checks
def save_scalar_muls(ctx, op, *args, **kwargs):
    if op == torch.ops.aten.mul.Scalar:
        return torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE
    return torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE


def create_model(attention_fn, nheads, dim1, dim2, sac_policy=None):
    class LocalMapTransformerBlock(nn.Module):
        def __init__(self, nheads, dim1, dim2):
            super().__init__()
            ...  # (elided in this excerpt)
            self.wo = nn.Linear(dim1, dim1, bias=bias)
            self.w1 = nn.Linear(dim1, dim2, bias=bias)
            self.w2 = nn.Linear(dim2, dim1, bias=bias)
            if sac_policy:
                self.sac_context_fn = functools.partial(
                    create_selective_checkpoint_contexts, sac_policy
                )
            else:
                self.sac_context_fn = None

        def _forward(self, x):
            q = self.wq(x)
            k = self.wk(x)
            v = self.wv(x)
            ...  # (elided in this excerpt)
            o = o0 + o
            return o

        def forward(self, x):
            if self.sac_context_fn is not None:
                return torch.utils.checkpoint.checkpoint(
                    self._forward,
                    x,
                    use_reentrant=False,
                    context_fn=self.sac_context_fn,
                )
            return self._forward(x)

    return LocalMapTransformerBlock(nheads, dim1, dim2)


def get_local_mapped_functions():
    assert torch.distributed.is_available()

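    # Build the same context-parallel attention wrapper two ways: via the
    # @local_map decorator and via a direct local_map(...) call, so the tests
    # can compare the graphs captured for each spelling.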
    @local_map(
        out_placements=((Shard(0), Shard(1), Shard(2)),),
        in_placements=(
            (Shard(0), Shard(1), Shard(2)),  # query
            (Shard(0), Shard(1), Replicate()),  # key
            (Shard(0), Shard(1), Replicate()),  # value
        ),
        redistribute_inputs=True,
        in_grad_placements=None,
        device_mesh=None,
    )
    def cp_decorated(query, key, value):
        return context_parallel_attention(query, key, value)

    cp_function = local_map(
        context_parallel_attention,
        out_placements=(Shard(0), Shard(1), Shard(2)),
        in_placements=(
            (Shard(0), Shard(1), Shard(2)),  # query
            (Shard(0), Shard(1), Replicate()),  # key
            (Shard(0), Shard(1), Replicate()),  # value
        ),
        redistribute_inputs=True,
        in_grad_placements=None,
        device_mesh=None,
    )

    return cp_decorated, cp_function


class TestLocalMap(TestCase):
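    # Force the math SDPA backend for every test so scaled_dot_product_attention
    # takes the portable decomposition rather than flash/efficient-attention
    # kernels; the models here run on CPU tensors.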
    def setUp(self):
        self.exit_stack = contextlib.ExitStack()
        self.exit_stack.enter_context(sdpa_kernel(backends=[SDPBackend.MATH]))

    def tearDown(self):
        self.exit_stack.close()

    @unittest.skipIf(*get_skip_reasons())
    def test_simple(self):
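        # Compile the block once with the decorated wrapper and once with the
        # functional wrapper; both should capture the same graph, which is also
        # pinned by an expected-inline check.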
        cp_decorated, cp_function = get_local_mapped_functions()
        bs = 8 * 1
        dim1 = 96
        dim2 = dim1 * 4
        ...  # (elided in this excerpt)

        backend = EagerAndRecordGraphs()

        model = create_model(cp_decorated, nheads, dim1, dim2)
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        with LocalMapWrappedHigherOrderVariable.enable():
            out = torch.compile(model, backend=backend)(*inputs)
            out.sum().backward()

        model = create_model(cp_function, nheads, dim1, dim2)
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        with LocalMapWrappedHigherOrderVariable.enable():
            out = torch.compile(model, backend=backend)(*inputs)
            out.sum().backward()

        if not TEST_WITH_CROSSREF:
            self.assertEqual(len(backend.graphs), 2)
            self.assertEqual(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                normalize_gm(backend.graphs[1].print_readable(print_output=False)),
            )
            self.assertExpectedInline(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                """\
... (expected graph elided in this excerpt) ...
""",
            )

    @unittest.skipIf(*get_skip_reasons())
    def test_sac(self):
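        # Same setup as test_simple, but the block is wrapped in selective
        # activation checkpointing; the checks below verify that the recompute
        # tags on the forward graph match the save_scalar_muls policy.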
        cp_decorated, cp_function = get_local_mapped_functions()
        bs = 8 * 1
        dim1 = 96
        dim2 = dim1 * 4
        nheads = 16
        seq_len = 16

        from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm

        backend = AotEagerAndRecordGraphs()

        model = create_model(
            cp_decorated, nheads, dim1, dim2, sac_policy=save_scalar_muls
        )
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        with LocalMapWrappedHigherOrderVariable.enable():
            out = torch.compile(model, backend=backend)(*inputs)
            out.sum().backward()

        model = create_model(
            cp_function, nheads, dim1, dim2, sac_policy=save_scalar_muls
        )
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        with LocalMapWrappedHigherOrderVariable.enable():
            out = torch.compile(model, backend=backend)(*inputs)
            out.sum().backward()

        if not TEST_WITH_CROSSREF:
            self.assertEqual(len(backend.graphs), 2)
            self.assertEqual(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                normalize_gm(backend.graphs[1].print_readable(print_output=False)),
            )
            self.assertEqual(
                normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
                normalize_gm(backend.fw_graphs[1].print_readable(print_output=False)),
            )
            self.assertEqual(
                normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)),
                normalize_gm(backend.bw_graphs[1].print_readable(print_output=False)),
            )
            self.assertEqual(
                len(
                    backend.graphs[0].graph.find_nodes(
                        op="call_function",
                        target=torch._higher_order_ops.wrap.tag_activation_checkpoint,
                    )
                ),
                1,
            )
            # TODO: add joint to the testing compile backend
            fw_outs = {
                n.name
                for n in backend.fw_graphs[0].graph.find_nodes(op="output")[0].args[0]
            }
            bw_ins = {
                n.name for n in backend.bw_graphs[0].graph.find_nodes(op="placeholder")
            }
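            # MUST_SAVE nodes should be forward outputs that feed the backward
            # graph; MUST_RECOMPUTE nodes should never be backward inputs (they
            # may still be forward outputs for post-graph bytecode).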
            for node in backend.fw_graphs[0].graph.nodes:
                if "recompute" in node.meta:
                    expected = save_scalar_muls(None, node.target, None, None)
                    actual = node.meta["recompute"]
                    self.assertEqual(expected, actual)
                    if actual == torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE:
                        self.assertTrue(node.name in fw_outs and node.name in bw_ins)
                    elif (
                        actual == torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
                    ):
                        # can still be in fw_outs for post-graph bytecode
                        self.assertFalse(node.name in bw_ins)

    @unittest.skipIf(*get_skip_reasons())
    def test_sac_deferred(self):
        # This test is in a bit of a weird state: it needs the compositional
        # compile API so that inlining can be deferred until AOTAutograd stage 1
        # and then performed by stage 2, but we can't do that today.
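        # Until then, compiling with defer_inlining() is expected to fail with an
        # AttributeError (see the TODOs below), so the backward-graph checks are
        # commented out and only the captured forward graphs are verified.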

        cp_decorated, cp_function = get_local_mapped_functions()
        bs = 8 * 1
        dim1 = 96
        dim2 = dim1 * 4
        nheads = 16
        seq_len = 16

        from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm

        backend = AotEagerAndRecordGraphs()

        model = create_model(
            cp_decorated, nheads, dim1, dim2, sac_policy=save_scalar_muls
        )
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        try:
            with (
                LocalMapWrappedHigherOrderVariable.enable(),
                torch._higher_order_ops.local_map.defer_inlining(),
            ):
                out = torch.compile(model, backend=backend)(*inputs)
                out.sum().backward()
        except AttributeError as e:
            # TODO: get rid of this when we can install as a subgraph
            self.assertTrue(
                "module 'torch._higher_order_ops.local_map' has no attribute 'call_local_map'"
                in str(e)
            )

        model = create_model(
            cp_function, nheads, dim1, dim2, sac_policy=save_scalar_muls
        )
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        try:
            with (
                LocalMapWrappedHigherOrderVariable.enable(),
                torch._higher_order_ops.local_map.defer_inlining(),
            ):
                out = torch.compile(model, backend=backend)(*inputs)
                out.sum().backward()
        except AttributeError as e:
            # TODO: get rid of this when we can install as a subgraph
            self.assertTrue(
                "module 'torch._higher_order_ops.local_map' has no attribute 'call_local_map'"
                in str(e)
            )

        # TODO: re-enable tests on backward when we can install as a subgraph
        if not TEST_WITH_CROSSREF:
            self.assertEqual(len(backend.graphs), 2)
            self.assertEqual(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                normalize_gm(backend.graphs[1].print_readable(print_output=False)),
            )
            self.assertEqual(
                normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
                normalize_gm(backend.fw_graphs[1].print_readable(print_output=False)),
            )
            # self.assertEqual(
            #     normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)),
            #     normalize_gm(backend.bw_graphs[1].print_readable(print_output=False)),
            # )
            self.assertEqual(
                len(
                    backend.graphs[0].graph.find_nodes(
                        op="call_function",
                        target=torch._higher_order_ops.wrap.tag_activation_checkpoint,
                    )
                ),
                1,
            )
            # TODO: add joint to the testing compile backend
            fw_outs = {
                n.name
                for n in backend.fw_graphs[0].graph.find_nodes(op="output")[0].args[0]
            }
            # bw_ins = {
            #     n.name for n in backend.bw_graphs[0].graph.find_nodes(op="placeholder")
            # }
            for node in backend.fw_graphs[0].graph.nodes:
                if "recompute" in node.meta:
                    expected = save_scalar_muls(None, node.target, None, None)
                    actual = node.meta["recompute"]
                    self.assertEqual(expected, actual)
                    if actual == torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE:
                        self.assertTrue(node.name in fw_outs)
                        # self.assertTrue(node.name in fw_outs and node.name in bw_ins)
                    # elif (
                    #     actual == torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
                    # ):
                    #     # can still be in fw_outs for post-graph bytecode
                    #     self.assertFalse(node.name in bw_ins)


if __name__ == "__main__":
    run_tests()