 try:
     from flash_attn import flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-except:
+except ImportError:
     # flash_attn may not be available but it is not required
     pass

 try:
     from sageattention import sageattn
-except:
+except ImportError:
     pass

 try:
     from apex.normalization import FusedRMSNorm as RMSNorm
-except:
+except ImportError:
     import warnings

     warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
@@ -98,7 +98,7 @@ def forward(self, x: Tensor):
         x_dtype = x.dtype
         # To handle float8 we need to convert the tensor to float
         x = x.float()
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
+        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
         return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)


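For context, a minimal sketch of a vanilla RMSNorm module whose `forward()` matches the hunk above; the constructor, field names, and the `1e-5` default are assumptions rather than taken from the diff. It only shows where `self.eps` comes from, which is what the fix switches the float8 path to use instead of the hard-coded `1e-6`.

```python
import torch
from torch import nn, Tensor


class VanillaRMSNorm(nn.Module):
    """Sketch only: a vanilla RMSNorm with a forward() like the hunk above."""

    def __init__(self, dim: int, eps: float = 1e-5):  # eps default is an assumption
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()  # upcast so float8/half inputs reduce in float32
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
```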
@@ -370,7 +370,7 @@ def forward(
         if self.use_sage_attn:
             # Handle GQA (Grouped Query Attention) if needed
             n_rep = self.n_local_heads // self.n_local_kv_heads
-            if n_rep >= 1:
+            if n_rep > 1:
                 xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                 xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

@@ -379,7 +379,7 @@ def forward(
             output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
         else:
             n_rep = self.n_local_heads // self.n_local_kv_heads
-            if n_rep >= 1:
+            if n_rep > 1:
                 xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                 xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

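The `unsqueeze/repeat/flatten` pattern in both branches expands the KV heads for grouped-query attention; changing the guard to `n_rep > 1` skips the expansion when query and KV head counts are equal, where it would be a semantic no-op that still copies memory. A standalone sketch with toy shapes (not part of the diff) showing the pattern is equivalent to `repeat_interleave` along the head dimension:

```python
import torch

# Assumed toy shapes: xk is [bsz, seqlen, n_kv_heads, head_dim], n_rep = n_heads // n_kv_heads.
bsz, seqlen, n_kv_heads, head_dim, n_rep = 2, 8, 4, 16, 2
xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim)

# Pattern used in the diff: insert a repeat axis, tile it, then merge it into the head axis.
expanded = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

# Equivalent to repeating each KV head n_rep times along the head dimension.
assert torch.equal(expanded, xk.repeat_interleave(n_rep, dim=2))
assert expanded.shape == (bsz, seqlen, n_kv_heads * n_rep, head_dim)
```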
@@ -456,51 +456,47 @@ def sage_attn(self, q: Tensor, k: Tensor, v: Tensor, x_mask: Tensor, softmax_sca
             bsz = q.shape[0]
             seqlen = q.shape[1]

-            # Transpose tensors to match SageAttention's expected format (HND layout)
-            q_transposed = q.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
-            k_transposed = k.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
-            v_transposed = v.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
-
-            # Handle masking for SageAttention
-            # We need to filter out masked positions - this approach handles variable sequence lengths
-            outputs = []
-            for b in range(bsz):
-                # Find valid token positions from the mask
-                valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1)
-                if valid_indices.numel() == 0:
-                    # If all tokens are masked, create a zero output
-                    batch_output = torch.zeros(
-                        seqlen, self.n_local_heads, self.head_dim,
-                        device=q.device, dtype=q.dtype
-                    )
-                else:
-                    # Extract only valid tokens for this batch
-                    batch_q = q_transposed[b, :, valid_indices, :]
-                    batch_k = k_transposed[b, :, valid_indices, :]
-                    batch_v = v_transposed[b, :, valid_indices, :]
-
-                    # Run SageAttention on valid tokens only
+            # Transpose to SageAttention's expected HND layout: [batch, heads, seq_len, head_dim]
+            q_transposed = q.permute(0, 2, 1, 3)
+            k_transposed = k.permute(0, 2, 1, 3)
+            v_transposed = v.permute(0, 2, 1, 3)
+
+            # Fast path: if all tokens are valid, run batched SageAttention directly
+            if x_mask.all():
+                output = sageattn(
+                    q_transposed, k_transposed, v_transposed,
+                    tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
+                )
+                # output: [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads, head_dim]
+                output = output.permute(0, 2, 1, 3)
+            else:
+                # Slow path: per-batch loop to handle variable-length masking
+                # SageAttention does not support attention masks natively
+                outputs = []
+                for b in range(bsz):
+                    valid_indices = x_mask[b].nonzero(as_tuple=True)[0]
+                    if valid_indices.numel() == 0:
+                        outputs.append(torch.zeros(
+                            seqlen, self.n_local_heads, self.head_dim,
+                            device=q.device, dtype=q.dtype,
+                        ))
+                        continue
+
                     batch_output_valid = sageattn(
-                        batch_q.unsqueeze(0),  # Add batch dimension back
-                        batch_k.unsqueeze(0),
-                        batch_v.unsqueeze(0),
-                        tensor_layout="HND",
-                        is_causal=False,
-                        sm_scale=softmax_scale
+                        q_transposed[b : b + 1, :, valid_indices, :],
+                        k_transposed[b : b + 1, :, valid_indices, :],
+                        v_transposed[b : b + 1, :, valid_indices, :],
+                        tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
                     )
-
-                    # Create output tensor with zeros for masked positions
+
                     batch_output = torch.zeros(
-                        seqlen, self.n_local_heads, self.head_dim,
-                        device=q.device, dtype=q.dtype
+                        seqlen, self.n_local_heads, self.head_dim,
+                        device=q.device, dtype=q.dtype,
                     )
-                    # Place valid outputs back in the right positions
                     batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)
-
-                outputs.append(batch_output)
-
-            # Stack batch outputs and reshape to expected format
-            output = torch.stack(outputs, dim=0)  # [batch, seq_len, heads, head_dim]
+                    outputs.append(batch_output)
+
+                output = torch.stack(outputs, dim=0)
         except NameError as e:
             raise RuntimeError(
                 f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
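A standalone sketch (toy sizes, not part of the diff) of the slow path's gather/attend/scatter: valid tokens are selected per sample, attention runs over them only, and the results are written back into a zero-padded `[seq_len, heads, head_dim]` tensor, as in the `batch_output[valid_indices] = ...` line above. `sageattn` is stubbed with torch's `scaled_dot_product_attention`, so only the shapes and indexing are being illustrated, not the SageAttention kernel itself.

```python
import torch
import torch.nn.functional as F

bsz, seqlen, heads, head_dim = 1, 6, 2, 8
q = torch.randn(bsz, seqlen, heads, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
x_mask = torch.tensor([[True, True, True, True, False, False]])

# HND layout: [batch, heads, seq_len, head_dim]
q_t, k_t, v_t = (t.permute(0, 2, 1, 3) for t in (q, k, v))

b = 0
valid = x_mask[b].nonzero(as_tuple=True)[0]  # positions of unpadded tokens
out_valid = F.scaled_dot_product_attention(  # [1, heads, n_valid, head_dim]
    q_t[b : b + 1, :, valid, :],
    k_t[b : b + 1, :, valid, :],
    v_t[b : b + 1, :, valid, :],
)

# Scatter results back into a zero-padded [seq_len, heads, head_dim] tensor.
out = torch.zeros(seqlen, heads, head_dim)
out[valid] = out_valid.squeeze(0).permute(1, 0, 2)
print(out.shape)  # torch.Size([6, 2, 8])
```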
@@ -1113,10 +1109,9 @@ def patchify_and_embed(

         x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)

-        x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device)
-        for i in range(bsz):
-            x[i, :image_seq_len] = x[i]
-            x_mask[i, :image_seq_len] = True
+        # x.shape[1] == image_seq_len after patchify, so this was assigning to itself.
+        # The mask can be set without a loop since all samples have the same image_seq_len.
+        x_mask = torch.ones(bsz, image_seq_len, dtype=torch.bool, device=device)

         x = self.x_embedder(x)

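A quick standalone check (toy sizes, assumed values, not part of the diff) of the claim in the new comment: after the view/permute/flatten patchify, `x.shape[1]` already equals `image_seq_len`, so the removed loop copied each sample onto itself and the mask can simply be all `True`.

```python
import torch

bsz, channels, height, width, pH, pW = 2, 4, 8, 8, 2, 2
x = torch.randn(bsz, channels, height, width)

# Same patchify reshape as in the hunk above.
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)

image_seq_len = (height // pH) * (width // pW)
assert x.shape == (bsz, image_seq_len, pH * pW * channels)

# All samples share the same image_seq_len, so the mask is all True.
x_mask = torch.ones(bsz, image_seq_len, dtype=torch.bool)
```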
@@ -1389,4 +1384,4 @@ def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs):
         axes_dims=[40, 40, 40],
         axes_lens=[300, 512, 512],
         **kwargs,
-    )
+    )