@@ -56,12 +56,16 @@ def _attn_fwd_inner(acc, l_i, m_i, q,  #
     # causal = False
     else:
         lo, hi = 0, N_CTX
-    offsetkv_y = offset_y + lo
+    offsetk_y = offset_y + lo
+    if dtype == tl.float8e5:
+        offsetv_y = offset_y * HEAD_DIM + lo
+    else:
+        offsetv_y = offset_y + lo
     # loop over k, v and update accumulator
     for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        k = desc_k.load([offsetkv_y, 0]).T
+        k = desc_k.load([offsetk_y, 0]).T
         qk = tl.dot(q, k)
         if STAGE == 2:
             mask = offs_m[:, None] >= (start_n + offs_n[None, :])
@@ -86,15 +90,19 @@ def _attn_fwd_inner(acc, l_i, m_i, q,  #
         else:
             acc = acc * alpha[:, None]
         # prepare p and v for the dot
-        v = desc_v.load([offsetkv_y, 0])
+        if dtype == tl.float8e5:
+            v = desc_v.load([0, offsetv_y]).T
+        else:
+            v = desc_v.load([offsetv_y, 0])
         p = p.to(dtype)
         # note that this non transposed v for FP8 is only supported on Blackwell
         acc = tl.dot(p, v, acc)
         # update m_i and l_i
         # place this at the end of the loop to reduce register pressure
         l_i = l_i * alpha + l_ij
         m_i = m_ij
-        offsetkv_y += BLOCK_N
+        offsetk_y += BLOCK_N
+        offsetv_y += BLOCK_N
     return acc, l_i, m_i


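Note (not part of the diff): a minimal standalone sketch of the offset arithmetic introduced above. When dtype is tl.float8e5, V is stored transposed per (batch, head) block, so a descriptor with shape [HEAD_DIM, y_dim] and strides [N_CTX, 1] addresses element (d, y) at d * N_CTX + y, and the per-head base offset (a multiple of N_CTX; written here as (z * H + h) * N_CTX, which may not be the kernel's exact expression) must be rescaled by HEAD_DIM, hence offsetv_y = offset_y * HEAD_DIM + lo. The sizes below are arbitrary.

import torch

# Standalone check: with the transposed FP8 layout for V, descriptor element (d, y)
# lives at flat offset d * N_CTX + y, so the per-(batch, head) base offset has to be
# rescaled by HEAD_DIM before it can be used as the second coordinate.
Z, H, N_CTX, HEAD_DIM = 2, 3, 8, 4
v = torch.arange(Z * H * N_CTX * HEAD_DIM, dtype=torch.float32).reshape(Z, H, N_CTX, HEAD_DIM)

# Same re-layout trick as the updated test: logical shape stays [Z, H, N_CTX, HEAD_DIM],
# but storage becomes [Z, H, HEAD_DIM, N_CTX] contiguous.
v_t = v.permute(0, 1, 3, 2).contiguous()
flat = v_t.reshape(-1)

for z in range(Z):
    for h in range(H):
        offset_y = (z * H + h) * N_CTX      # row in the flattened y_dim = Z * H * N_CTX axis
        offsetv_y = offset_y * HEAD_DIM     # rescaled offset for the [HEAD_DIM, y_dim] view
        for d in range(HEAD_DIM):
            start = d * N_CTX + offsetv_y
            row = flat[start:start + N_CTX]  # one head-dim row of this (z, h) block
            assert torch.equal(row, v[z, h, :, d])
print("transposed-V offset arithmetic checks out")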
@@ -105,7 +113,10 @@ def _host_descriptor_pre_hook(nargs):
     if not isinstance(nargs["desc_q"], TensorDescriptor):
         return
     nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM]
-    nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
+    if nargs["FP8_OUTPUT"]:
+        nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N]
+    else:
+        nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
     nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
     nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM]

@@ -120,7 +131,7 @@ def _host_descriptor_pre_hook(nargs):
 configs = [
     triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \
     for BM in [64, 128]\
-    for BN in [64, 128]\
+    for BN in [32, 64, 128]\
     for s in NUM_STAGES_OPTIONS \
     for w in [4, 8]\
 ]
@@ -134,7 +145,8 @@ def _host_descriptor_pre_hook(nargs):
 def keep(conf):
     BLOCK_M = conf.kwargs["BLOCK_M"]
     BLOCK_N = conf.kwargs["BLOCK_N"]
-    return not (torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8)
+    return not (is_cuda() and torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128
+                and conf.num_warps == 8)


 def prune_invalid_configs(configs, named_args, **kwargs):
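Note (not part of the diff): a hedged sketch of the autotuning space implied by the widened BN list plus the keep() filter, assuming NUM_STAGES_OPTIONS = [2, 3, 4] (the tutorial defines it elsewhere and the value is backend-dependent). It only mimics the filtering rule and does not call into Triton.

from itertools import product

NUM_STAGES_OPTIONS = [2, 3, 4]  # assumption, see note above
space = [
    dict(BLOCK_M=BM, BLOCK_N=BN, num_stages=s, num_warps=w)
    for BM, BN, s, w in product([64, 128], [32, 64, 128], NUM_STAGES_OPTIONS, [4, 8])
]

def keep_like(conf, capability_major=9, on_cuda=True):
    # Same rule as keep() above: on sm90, drop 8-warp configs whose tile is smaller than 128x128.
    return not (on_cuda and capability_major == 9
                and conf["BLOCK_M"] * conf["BLOCK_N"] < 128 * 128 and conf["num_warps"] == 8)

print(len(space), "configs in total,", sum(keep_like(c) for c in space), "kept on sm90")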
@@ -174,8 +186,12 @@ def _attn_fwd(sm_scale, M,  #
     y_dim = Z * H * N_CTX
     desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                      block_shape=[BLOCK_M, HEAD_DIM])
-    desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
-                                     block_shape=[BLOCK_N, HEAD_DIM])
+    if FP8_OUTPUT:
+        desc_v = _maybe_make_tensor_desc(desc_v, shape=[HEAD_DIM, y_dim], strides=[N_CTX, 1],
+                                         block_shape=[HEAD_DIM, BLOCK_N])
+    else:
+        desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
+                                         block_shape=[BLOCK_N, HEAD_DIM])
     desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                      block_shape=[BLOCK_N, HEAD_DIM])
     desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
@@ -494,7 +510,12 @@ def forward(ctx, q, k, v, causal, sm_scale, warp_specialize=True):

             dummy_block = [1, 1]
             desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
-            desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
+            if q.dtype == torch.float8_e5m2:
+                desc_v = TensorDescriptor(v, shape=[HEAD_DIM_K, y_dim], strides=[q.shape[2], 1],
+                                          block_shape=dummy_block)
+            else:
+                desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1],
+                                          block_shape=dummy_block)
             desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
             desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
         else:
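Note (not part of the diff): a small standalone check (arbitrary sizes) of why strides=[q.shape[2], 1], i.e. [N_CTX, 1], describes the FP8 V the test passes in: the permute/contiguous/permute round trip keeps the logical [Z, H, N_CTX, HEAD_DIM] shape but swaps the last two strides.

import torch

# After the re-layout, the sequence dimension has stride 1 and the head dimension
# has stride N_CTX, matching the [HEAD_DIM_K, y_dim] descriptor above.
Z, H, N_CTX, HEAD_DIM = 1, 2, 16, 8
v = torch.randn(Z, H, N_CTX, HEAD_DIM)
v = v.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)

print(v.shape)     # torch.Size([1, 2, 16, 8]) -- logical shape unchanged
print(v.stride())  # (256, 128, 1, 16)         -- sequence stride 1, head-dim stride N_CTX
assert v.stride(-2) == 1 and v.stride(-1) == N_CTX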
@@ -579,48 +600,74 @@ def backward(ctx, do):

 attention = _attention.apply

+TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
+

 @pytest.mark.parametrize("Z", [1, 4])
 @pytest.mark.parametrize("H", [2, 48])
 @pytest.mark.parametrize("N_CTX", [128, 1024, (2 if is_hip() else 4) * 1024])
 @pytest.mark.parametrize("HEAD_DIM", [64, 128])
 @pytest.mark.parametrize("causal", [True])  # FIXME: Non-causal tests do not pass at the moment.
 @pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False])
-def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, dtype=torch.float16):
+@pytest.mark.parametrize("mode", ["fwd", "bwd"])
+@pytest.mark.parametrize("provider", ["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []))
+def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtype=torch.float16):
+    if mode == "fwd" and "fp16" in provider:
+        pytest.skip("Avoid running the forward computation twice.")
+    if mode == "bwd" and "fp8" in provider:
+        pytest.skip("Backward pass with FP8 is not supported.")
     torch.manual_seed(20)
     q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     sm_scale = 0.5
-    dout = torch.randn_like(q)
     # reference implementation
+    ref_dtype = dtype
+    if mode == "fwd" and "fp8" in provider:
+        ref_dtype = torch.float32
+    q = q.to(ref_dtype)
+    k = k.to(ref_dtype)
+    v = v.to(ref_dtype)
     M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
     if causal:
         p[:, :, M == 0] = float("-inf")
-    p = torch.softmax(p.float(), dim=-1).half()
+    p = torch.softmax(p.float(), dim=-1)
+    p = p.to(ref_dtype)
     # p = torch.exp(p)
-    ref_out = torch.matmul(p, v)
-    ref_out.backward(dout)
-    ref_dv, v.grad = v.grad.clone(), None
-    ref_dk, k.grad = k.grad.clone(), None
-    ref_dq, q.grad = q.grad.clone(), None
+    ref_out = torch.matmul(p, v).half()
+    if mode == "bwd":
+        dout = torch.randn_like(q)
+        ref_out.backward(dout)
+        ref_dv, v.grad = v.grad.clone(), None
+        ref_dk, k.grad = k.grad.clone(), None
+        ref_dq, q.grad = q.grad.clone(), None
     # triton implementation
+    if mode == "fwd" and "fp8" in provider:
+        q = q.to(torch.float8_e5m2)
+        k = k.to(torch.float8_e5m2)
+        v = v.permute(0, 1, 3, 2).contiguous()
+        v = v.permute(0, 1, 3, 2)
+        v = v.to(torch.float8_e5m2)
     tri_out = attention(q, k, v, causal, sm_scale, warp_specialize).half()
+    if mode == "fwd":
+        atol = 3 if "fp8" in provider else 1e-2
+        torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0)
+        return
     tri_out.backward(dout)
     tri_dv, v.grad = v.grad.clone(), None
     tri_dk, k.grad = k.grad.clone(), None
     tri_dq, q.grad = q.grad.clone(), None
     # compare
-    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
+    torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0)
     rtol = 0.0
     # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
     # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
     if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
         rtol = 1e-2
-    torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
-    torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
-    torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=rtol)
+    torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol)
+    torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol)
+    torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=rtol)


 try:
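Note (not part of the diff): a rough, assumption-laden illustration of why the FP8 forward check above uses atol=3 while FP16 uses 1e-2. float8_e5m2 keeps only 2 mantissa bits, so a single round trip already loses roughly 2**-3 relative precision, and that error then accumulates across the N_CTX-term softmax(QK^T) @ V reduction.

import torch

# Measure the per-element round-off of a float8_e5m2 round trip on values with the
# same scale as the test inputs (N(0, 0.5)). Guarded because FP8 dtypes need a
# sufficiently recent PyTorch.
if hasattr(torch, "float8_e5m2"):
    x = torch.randn(4096) * 0.5
    err = (x.to(torch.float8_e5m2).to(torch.float32) - x).abs()
    print(f"max abs round-off for values ~ N(0, 0.5): {err.max():.3f}")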