@@ -56,12 +56,16 @@ def _attn_fwd_inner(acc, l_i, m_i, q,  #
     # causal = False
     else:
         lo, hi = 0, N_CTX
-    offsetkv_y = offset_y + lo
+    offsetk_y = offset_y + lo
+    if dtype == tl.float8e5:
+        offsetv_y = offset_y * HEAD_DIM + lo
+    else:
+        offsetv_y = offset_y + lo
     # loop over k, v and update accumulator
     for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        k = desc_k.load([offsetkv_y, 0]).T
+        k = desc_k.load([offsetk_y, 0]).T
         qk = tl.dot(q, k)
         if STAGE == 2:
             mask = offs_m[:, None] >= (start_n + offs_n[None, :])
@@ -86,15 +90,19 @@ def _attn_fwd_inner(acc, l_i, m_i, q,  #
         else:
             acc = acc * alpha[:, None]
         # prepare p and v for the dot
-        v = desc_v.load([offsetkv_y, 0])
+        if dtype == tl.float8e5:
+            v = desc_v.load([0, offsetv_y]).T
+        else:
+            v = desc_v.load([offsetv_y, 0])
         p = p.to(dtype)
         # note that this non transposed v for FP8 is only supported on Blackwell
         acc = tl.dot(p, v, acc)
         # update m_i and l_i
         # place this at the end of the loop to reduce register pressure
         l_i = l_i * alpha + l_ij
         m_i = m_ij
-        offsetkv_y += BLOCK_N
+        offsetk_y += BLOCK_N
+        offsetv_y += BLOCK_N
     return acc, l_i, m_i

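An editorial note on the hunk above (not part of the patch): the FP8 branch fetches V as a [HEAD_DIM, BLOCK_N] tile and flips it back with .T, so the dot product itself is unchanged; only the memory layout of the second operand differs, presumably to satisfy the FP8 MMA operand-layout requirement on pre-Blackwell parts. A minimal PyTorch sketch of that equivalence, with illustrative tile sizes:

# Editorial sketch: the FP8 branch loads V transposed and undoes the transpose
# in-register; numerically this matches the [BLOCK_N, HEAD_DIM] operand that the
# FP16 branch loads directly.
import torch

BLOCK_M, BLOCK_N, HEAD_DIM = 16, 32, 64          # illustrative tile sizes
p = torch.rand(BLOCK_M, BLOCK_N)
v_row_major = torch.rand(BLOCK_N, HEAD_DIM)      # layout the FP16 path loads
v_transposed = v_row_major.T.contiguous()        # layout the FP8 path loads

acc_fp16_path = p @ v_row_major
acc_fp8_path = p @ v_transposed.T                # mirrors desc_v.load([...]).T
torch.testing.assert_close(acc_fp16_path, acc_fp8_path)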
@@ -105,7 +113,10 @@ def _host_descriptor_pre_hook(nargs):
     if not isinstance(nargs["desc_q"], TensorDescriptor):
         return
     nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM]
-    nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
+    if nargs["FP8_OUTPUT"]:
+        nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N]
+    else:
+        nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
     nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
     nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM]

@@ -120,7 +131,7 @@ def _host_descriptor_pre_hook(nargs):
 configs = [
     triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \
     for BM in [64, 128]\
-    for BN in [64, 128]\
+    for BN in [32, 64, 128]\
     for s in NUM_STAGES_OPTIONS \
     for w in [4, 8]\
 ]
@@ -134,7 +145,8 @@ def _host_descriptor_pre_hook(nargs):
 def keep(conf):
     BLOCK_M = conf.kwargs["BLOCK_M"]
     BLOCK_N = conf.kwargs["BLOCK_N"]
-    return not (torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8)
+    return not (is_cuda() and torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128
+                and conf.num_warps == 8)


 def prune_invalid_configs(configs, named_args, **kwargs):
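For orientation (not from this diff): keep() and prune_invalid_configs() are consumed by triton.autotune elsewhere in the tutorial. A hedged sketch of that wiring; the key list shown here is an assumption for illustration, not taken from the patch:

# Hedged sketch of how the two filters are typically wired into the autotuner.
filtered = list(filter(keep, configs))  # drop 8-warp small-tile configs on CC 9.x CUDA devices
attn_autotune = triton.autotune(
    configs=filtered,
    key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],   # assumed key set
    prune_configs_by={"early_config_prune": prune_invalid_configs},
)
# applied as a decorator, e.g. @attn_autotune above the @triton.jit kernel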
@@ -174,8 +186,12 @@ def _attn_fwd(sm_scale, M,  #
     y_dim = Z * H * N_CTX
     desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                      block_shape=[BLOCK_M, HEAD_DIM])
-    desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
-                                     block_shape=[BLOCK_N, HEAD_DIM])
+    if FP8_OUTPUT:
+        desc_v = _maybe_make_tensor_desc(desc_v, shape=[HEAD_DIM, y_dim], strides=[N_CTX, 1],
+                                         block_shape=[HEAD_DIM, BLOCK_N])
+    else:
+        desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
+                                         block_shape=[BLOCK_N, HEAD_DIM])
     desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                      block_shape=[BLOCK_N, HEAD_DIM])
     desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
@@ -494,7 +510,12 @@ def forward(ctx, q, k, v, causal, sm_scale, warp_specialize=True):

            dummy_block = [1, 1]
            desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
-            desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
+            if q.dtype == torch.float8_e5m2:
+                desc_v = TensorDescriptor(v, shape=[HEAD_DIM_K, y_dim], strides=[q.shape[2], 1],
+                                          block_shape=dummy_block)
+            else:
+                desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1],
+                                          block_shape=dummy_block)
            desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
            desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
        else:
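The FP8 descriptor above assumes V is stored transposed per (batch, head), i.e. each [N_CTX, HEAD_DIM] slice is laid out head-dim-major so that shape=[HEAD_DIM_K, y_dim] with strides=[q.shape[2], 1] addresses it correctly. The test further down in this diff prepares V exactly that way; a hedged standalone sketch of the same host-side preparation, with illustrative sizes:

# Hedged host-side sketch, mirroring the test below: keep the logical
# (Z, H, N_CTX, HEAD_DIM) shape that attention() expects, but make the
# underlying storage head-dim-major within each (batch, head) slice.
import torch

Z, H, N_CTX, HEAD_DIM = 1, 2, 128, 64            # illustrative sizes
device = "cuda" if torch.cuda.is_available() else "cpu"
v = torch.randn(Z, H, N_CTX, HEAD_DIM, dtype=torch.float16, device=device)
v = v.permute(0, 1, 3, 2).contiguous()           # physically transpose the last two dims
v = v.permute(0, 1, 3, 2)                        # restore the logical shape
v = v.to(torch.float8_e5m2)
print(v.shape, v.stride())                       # (..., N_CTX, HEAD_DIM) with strides (..., 1, N_CTX)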
@@ -579,48 +600,74 @@ def backward(ctx, do):


 attention = _attention.apply

+TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
+

 @pytest.mark.parametrize("Z", [1, 4])
 @pytest.mark.parametrize("H", [2, 48])
 @pytest.mark.parametrize("N_CTX", [128, 1024, (2 if is_hip() else 4) * 1024])
 @pytest.mark.parametrize("HEAD_DIM", [64, 128])
 @pytest.mark.parametrize("causal", [True])  # FIXME: Non-causal tests do not pass at the moment.
 @pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False])
-def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, dtype=torch.float16):
+@pytest.mark.parametrize("mode", ["fwd", "bwd"])
+@pytest.mark.parametrize("provider", ["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []))
+def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtype=torch.float16):
+    if mode == "fwd" and "fp16" in provider:
+        pytest.skip("Avoid running the forward computation twice.")
+    if mode == "bwd" and "fp8" in provider:
+        pytest.skip("Backward pass with FP8 is not supported.")
     torch.manual_seed(20)
     q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     sm_scale = 0.5
-    dout = torch.randn_like(q)
     # reference implementation
+    ref_dtype = dtype
+    if mode == "fwd" and "fp8" in provider:
+        ref_dtype = torch.float32
+    q = q.to(ref_dtype)
+    k = k.to(ref_dtype)
+    v = v.to(ref_dtype)
     M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
     if causal:
         p[:, :, M == 0] = float("-inf")
-    p = torch.softmax(p.float(), dim=-1).half()
+    p = torch.softmax(p.float(), dim=-1)
+    p = p.to(ref_dtype)
     # p = torch.exp(p)
-    ref_out = torch.matmul(p, v)
-    ref_out.backward(dout)
-    ref_dv, v.grad = v.grad.clone(), None
-    ref_dk, k.grad = k.grad.clone(), None
-    ref_dq, q.grad = q.grad.clone(), None
+    ref_out = torch.matmul(p, v).half()
+    if mode == "bwd":
+        dout = torch.randn_like(q)
+        ref_out.backward(dout)
+        ref_dv, v.grad = v.grad.clone(), None
+        ref_dk, k.grad = k.grad.clone(), None
+        ref_dq, q.grad = q.grad.clone(), None
     # triton implementation
+    if mode == "fwd" and "fp8" in provider:
+        q = q.to(torch.float8_e5m2)
+        k = k.to(torch.float8_e5m2)
+        v = v.permute(0, 1, 3, 2).contiguous()
+        v = v.permute(0, 1, 3, 2)
+        v = v.to(torch.float8_e5m2)
     tri_out = attention(q, k, v, causal, sm_scale, warp_specialize).half()
+    if mode == "fwd":
+        atol = 3 if "fp8" in provider else 1e-2
+        torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0)
+        return
     tri_out.backward(dout)
     tri_dv, v.grad = v.grad.clone(), None
     tri_dk, k.grad = k.grad.clone(), None
     tri_dq, q.grad = q.grad.clone(), None
     # compare
-    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
+    torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0)
     rtol = 0.0
     # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
     # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
     if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
         rtol = 1e-2
-    torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
-    torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
-    torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=rtol)
+    torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol)
+    torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol)
+    torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=rtol)


 try:
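A side note on the loose atol=3 used for the FP8 forward check above: float8_e5m2 carries only two mantissa bits, so quantization alone introduces relative errors of up to roughly 25 percent, which presumably is why the FP16 tolerance of 1e-2 cannot be reused. A tiny editorial illustration of that coarseness (not part of the patch):

# float8_e5m2 has 2 mantissa bits: representable values near 1.0 are 1.0, 1.25, 1.5, ...
import torch

x = torch.tensor([1.0, 1.1, 1.2, 1.3])
print(x.to(torch.float8_e5m2).to(torch.float32))  # tensor([1.0000, 1.0000, 1.2500, 1.2500])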