@@ -5,12 +5,11 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.attention.varlen import varlen_attn
+from torch.nn.attention import varlen_attn
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import parametrize, run_tests
-from torch.utils._python_dispatch import TorchDispatchMode
 
 
 VarlenShape = namedtuple(
@@ -24,18 +23,6 @@
 }
 
 
-class OpLoggingMode(TorchDispatchMode):
-    """Logging mode that captures all dispatched operations"""
-
-    def __init__(self):
-        self.called_ops = []
-
-    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        op_name = str(func)
-        self.called_ops.append(op_name)
-        return func(*args, **(kwargs or {}))
-
-
 class AttentionBlock(nn.Module):
     def __init__(
         self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype
@@ -52,9 +39,12 @@ def __init__(
             embed_dim, embed_dim, bias=False, device=device, dtype=dtype
         )
 
-    def get_varlen_qkv(
+    def forward_varlen(
         self,
         x_packed: torch.Tensor,
+        cu_seq: torch.Tensor,
+        max_len: int,
+        is_causal: bool = False,
     ):
         qkv = self.qkv_proj(x_packed)
         q, k, v = qkv.chunk(3, dim=-1)
@@ -63,51 +53,24 @@ def get_varlen_qkv(
         k = k.view(-1, self.num_heads, self.head_dim)
         v = v.view(-1, self.num_heads, self.head_dim)
 
-        return q, k, v
-
-    def forward_varlen(
-        self,
-        x_packed: torch.Tensor,
-        cu_seq: torch.Tensor,
-        max_len: int,
-        is_causal: bool = False,
-    ):
-        q, k, v = self.get_varlen_qkv(x_packed)
-
-        attn_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal)
+        attn_out = varlen_attn(
+            q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal
+        )
         attn_out = attn_out.view(-1, self.embed_dim)
 
         return self.out_proj(attn_out)
 
-    def forward_sdpa(
-        self,
-        x_padded: torch.Tensor,
-        seq_lengths: torch.Tensor,
-        dtype: torch.dtype,
-        is_causal: bool = False,
-    ):
+    def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False):
         batch_size, seq_len, _ = x_padded.shape
 
         qkv = self.qkv_proj(x_padded)
         q, k, v = qkv.chunk(3, dim=-1)
 
-        mask = (
-            torch.arange(seq_len, device=x_padded.device)[None, :]
-            < seq_lengths[:, None]
-        )
-
-        attn_mask = mask[:, None, None, :].expand(
-            batch_size, self.num_heads, seq_len, seq_len
-        )
-
         q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
         k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
         v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
 
-        attn_out = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, is_causal=is_causal
-        )
-
+        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
         attn_out = (
             attn_out.transpose(1, 2)
             .contiguous()
@@ -128,9 +91,7 @@ def create_variable_length_batch(
     seq_lengths = torch.tensor(seq_lengths, device=device)
     total_tokens = seq_lengths.sum().item()
 
-    x_packed = torch.randn(
-        total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
-    )
+    x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype)
 
     cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
     cu_seq[1:] = seq_lengths.cumsum(0)
@@ -145,7 +106,6 @@ def create_variable_length_batch(
         end_idx = start_idx + seq_len
         x_padded[i, :seq_len] = x_packed[start_idx:end_idx]
         start_idx = end_idx
-    x_padded = x_padded.clone().detach().requires_grad_()
 
     return {
         "seq_lengths": seq_lengths,
@@ -173,11 +133,7 @@ def test_basic_functionality(self, device, dtype):
 
         total_tokens = shape.batch_size * shape.max_seq_len
         x_packed = torch.randn(
-            total_tokens,
-            shape.embed_dim,
-            device=device,
-            dtype=dtype,
-            requires_grad=True,
+            total_tokens, shape.embed_dim, device=device, dtype=dtype
         )
         cu_seq = torch.tensor(
             [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
@@ -191,128 +147,6 @@ def test_basic_functionality(self, device, dtype):
         self.assertEqual(output.device, torch.device(device))
         self.assertEqual(output.dtype, dtype)
 
-        varlen_grad_out = torch.ones_like(output)
-
-        varlen_grad = torch.autograd.grad(
-            outputs=output,
-            inputs=x_packed,
-            grad_outputs=varlen_grad_out,
-            retain_graph=True,
-            create_graph=False,
-            allow_unused=False,
-        )[0]
-
-        self.assertIsNotNone(varlen_grad)
-        self.assertEqual(varlen_grad.shape, x_packed.shape)
-        self.assertEqual(varlen_grad.dtype, x_packed.dtype)
-
-    @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
-    )
-    @parametrize("dtype", [torch.bfloat16, torch.float16])
-    def test_custom_op_compliance(self, device, dtype):
-        torch.manual_seed(42)
-
-        shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
-
-        attention_block = AttentionBlock(
-            shape.embed_dim, shape.num_heads, device, dtype
-        )
-
-        total_tokens = shape.batch_size * shape.max_seq_len
-        x_packed = torch.randn(
-            total_tokens,
-            shape.embed_dim,
-            device=device,
-            dtype=dtype,
-        )
-        cu_seq = torch.tensor(
-            [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
-        )
-
-        q, k, v = attention_block.get_varlen_qkv(x_packed)
-
-        torch.library.opcheck(
-            torch.ops.torch_attn._varlen_attn,
-            (q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False),
-        )
-
-        out, lse, rng_state = torch.ops.torch_attn._varlen_attn(
-            q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False
-        )
-        grad_out = torch.randn_like(out)
-
-        # we don't support double backward
-        # skipping test_autograd_registration, test_aot_dispatch_dynamic, test_aot_dispatch_static
-        torch.library.opcheck(
-            torch.ops.torch_attn._varlen_attn_backward,
-            (
-                grad_out,
-                q,
-                k,
-                v,
-                out,
-                lse,
-                cu_seq,
-                cu_seq,
-                shape.max_seq_len,
-                shape.max_seq_len,
-                False,
-                rng_state,
-            ),
-            test_utils=["test_schema", "test_faketensor"],
-        )
-
-    @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
-    )
-    @parametrize("dtype", [torch.bfloat16, torch.float16])
-    def test_custom_op_registration(self, device, dtype):
-        torch.manual_seed(42)
-
-        shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
-
-        attention_block = AttentionBlock(
-            shape.embed_dim, shape.num_heads, device, dtype
-        )
-
-        total_tokens = shape.batch_size * shape.max_seq_len
-        x_packed = torch.randn(
-            total_tokens,
-            shape.embed_dim,
-            device=device,
-            dtype=dtype,
-            requires_grad=True,
-        )
-        cu_seq = torch.tensor(
-            [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
-        )
-
-        compiled_forward = torch.compile(
-            attention_block.forward_varlen, backend="eager", fullgraph=True
-        )
-        with OpLoggingMode() as mode:
-            output = compiled_forward(
-                x_packed, cu_seq, shape.max_seq_len, is_causal=False
-            )
-
-        varlen_grad_out = torch.ones_like(output)
-        _ = torch.autograd.grad(
-            outputs=output,
-            inputs=x_packed,
-            grad_outputs=varlen_grad_out,
-            retain_graph=True,
-            create_graph=False,
-            allow_unused=False,
-        )[0]
-
-        called_ops = mode.called_ops
-
-        custom_ops_called = any(
-            "torch_attn._varlen_attn" in op for op in called_ops
-        ) and any("torch_attn._varlen_attn_backward" in op for op in called_ops)
-        assert custom_ops_called
-
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
     )
@@ -338,10 +172,7 @@ def test_varlen_vs_sdpa(self, device, dtype, is_causal):
             is_causal=is_causal,
         )
         sdpa_output = attention_block.forward_sdpa(
-            variable_length_batch_data["x_padded"],
-            variable_length_batch_data["seq_lengths"],
-            dtype=dtype,
-            is_causal=is_causal,
+            variable_length_batch_data["x_padded"], is_causal=is_causal
         )
 
         tolerances = default_tolerances[dtype]
@@ -355,44 +186,6 @@ def test_varlen_vs_sdpa(self, device, dtype, is_causal):
             torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances)
             start_idx = end_idx
 
-        varlen_grad_out = torch.ones_like(varlen_output)
-
-        sdpa_grad_out = torch.zeros_like(sdpa_output)
-
-        start_idx = 0
-        for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
-            end_idx = start_idx + seq_len
-            sdpa_grad_out[i, :seq_len] = varlen_grad_out[start_idx:end_idx]
-            start_idx = end_idx
-
-        varlen_grad = torch.autograd.grad(
-            outputs=varlen_output,
-            inputs=variable_length_batch_data["x_packed"],
-            grad_outputs=varlen_grad_out,
-            retain_graph=True,
-            create_graph=False,
-            allow_unused=False,
-        )[0]
-
-        sdpa_grad = torch.autograd.grad(
-            outputs=sdpa_output,
-            inputs=variable_length_batch_data["x_padded"],
-            grad_outputs=sdpa_grad_out,
-            retain_graph=True,
-            create_graph=False,
-            allow_unused=False,
-        )[0]
-
-        start_idx = 0
-        for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
-            end_idx = start_idx + seq_len
-
-            varlen_grad_seq = varlen_grad[start_idx:end_idx]
-            sdpa_grad_seq = sdpa_grad[i, :seq_len]
-
-            torch.testing.assert_close(varlen_grad_seq, sdpa_grad_seq, **tolerances)
-            start_idx = end_idx
-
 
 device_types = ("cuda",)
 
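For reference, the call pattern the updated tests exercise looks like the standalone sketch below (not part of the diff). It assumes the signature visible above, varlen_attn(q, k, v, cu_seq_q, cu_seq_k, max_q, max_k, is_causal=...), with q/k/v packed as (total_tokens, num_heads, head_dim) and cu_seq holding int32 cumulative sequence offsets, plus a CUDA build with Flash Attention support; the exact public API may differ.

# Minimal sketch of packed variable-length attention, assuming the call
# pattern shown in the diff above (signature and shapes are assumptions).
import torch
from torch.nn.attention import varlen_attn

device, dtype = "cuda", torch.bfloat16
num_heads, head_dim = 8, 64
seq_lengths = torch.tensor([5, 12, 7], device=device)

# Pack all sequences along one token dimension instead of padding to max length.
total_tokens = int(seq_lengths.sum())
q = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# cu_seq[i] is the starting offset of sequence i in the packed tensor.
cu_seq = torch.zeros(len(seq_lengths) + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)
max_len = int(seq_lengths.max())

out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=False)
print(out.shape)  # expected: packed output, e.g. (total_tokens, num_heads, head_dim)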