@@ -36,6 +36,10 @@ def supports_host_descriptor():
3636 return is_cuda () and torch .cuda .get_device_capability ()[0 ] >= 9
3737
3838
39+ def is_blackwell ():
40+ return is_cuda () and torch .cuda .get_device_capability ()[0 ] == 10
41+
42+
3943@triton .jit
4044def _attn_fwd_inner (acc , l_i , m_i , q , #
4145 desc_k , desc_v , #
@@ -115,7 +119,7 @@ def _host_descriptor_pre_hook(nargs):
115119if "PYTEST_VERSION" in os .environ :
116120 # Use a single config in testing for reproducibility
117121 configs = [
118- triton .Config (dict (BLOCK_M = 64 , BLOCK_N = 64 ), num_stages = 4 , num_warps = 4 , pre_hook = _host_descriptor_pre_hook ),
122+ triton .Config (dict (BLOCK_M = 64 , BLOCK_N = 64 ), num_stages = 2 , num_warps = 4 , pre_hook = _host_descriptor_pre_hook ),
119123 ]
120124
121125
@@ -484,10 +488,10 @@ def forward(ctx, q, k, v, causal, sm_scale, warp_specialize=True):
484488 y_dim = q .shape [0 ] * q .shape [1 ] * q .shape [2 ]
485489
486490 dummy_block = [1 , 1 ]
487- desc_q = TensorDescriptor (q , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM , 1 ], block_shape = dummy_block )
488- desc_v = TensorDescriptor (v , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM , 1 ], block_shape = dummy_block )
489- desc_k = TensorDescriptor (k , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM , 1 ], block_shape = dummy_block )
490- desc_o = TensorDescriptor (o , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM , 1 ], block_shape = dummy_block )
491+ desc_q = TensorDescriptor (q , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = dummy_block )
492+ desc_v = TensorDescriptor (v , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = dummy_block )
493+ desc_k = TensorDescriptor (k , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = dummy_block )
494+ desc_o = TensorDescriptor (o , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = dummy_block )
491495 else :
492496 desc_q = q
493497 desc_v = v
@@ -510,7 +514,7 @@ def grid(META):
510514 q .shape [0 ], q .shape [1 ], #
511515 desc_q , desc_k , desc_v , desc_o , #
512516 N_CTX = q .shape [2 ], #
513- HEAD_DIM = HEAD_DIM , #
517+ HEAD_DIM = HEAD_DIM_K , #
514518 FP8_OUTPUT = q .dtype == torch .float8_e5m2 , #
515519 STAGE = stage , #
516520 warp_specialize = warp_specialize , #
@@ -568,17 +572,12 @@ def backward(ctx, do):
568572attention = _attention .apply
569573
570574
571- @pytest .mark .parametrize ('Z, H, N_CTX, HEAD_DIM' , [
572- (1 , 2 , 1024 , 64 ),
573- (4 , 48 , 128 , 64 ),
574- (4 , 48 , 256 , 64 ),
575- (4 , 48 , 512 , 64 ),
576- (4 , 48 , 1024 , 64 ),
577- (4 , 48 , 2048 , 64 ),
578- (4 , 48 , 4096 , 64 ),
579- ])
580- @pytest .mark .parametrize ("causal" , [True ])
581- @pytest .mark .parametrize ("warp_specialize" , [False , True ])
575+ @pytest .mark .parametrize ("Z" , [1 , 4 ])
576+ @pytest .mark .parametrize ("H" , [2 , 48 ])
577+ @pytest .mark .parametrize ("N_CTX" , [128 , 1024 , (2 if is_hip () else 4 ) * 1024 ])
578+ @pytest .mark .parametrize ("HEAD_DIM" , [64 , 128 ])
579+ @pytest .mark .parametrize ("causal" , [True ]) # FIXME: Non-causal tests do not pass at the moment.
580+ @pytest .mark .parametrize ("warp_specialize" , [False , True ] if is_blackwell () else [False ])
582581def test_op (Z , H , N_CTX , HEAD_DIM , causal , warp_specialize , dtype = torch .float16 ):
583582 torch .manual_seed (20 )
584583 q = (torch .empty ((Z , H , N_CTX , HEAD_DIM ), dtype = dtype , device = DEVICE ).normal_ (mean = 0.0 , std = 0.5 ).requires_grad_ ())
0 commit comments