@@ -137,9 +137,19 @@ def test_trtllm_batch_decode_fmha(
     }
 
     sm_scale = float(1.0 / (head_dim**0.5))
-    q = torch.randn(batch_size, num_qo_heads, head_dim, device=device).to(
-        dtype_map[q_dtype]
-    )
+    if q_dtype == "fp8":
+        q = torch.randn(
+            batch_size, num_qo_heads, head_dim, dtype=torch.bfloat16, device=device
+        )
+        q, q_scale = to_float8(q)
+        # The reference implementation has functional issues or low precision with fp8; use bfloat16 and fake quantization instead.
+        ref_q = q.bfloat16() * q_scale
+    else:
+        q = torch.randn(
+            batch_size, num_qo_heads, head_dim, dtype=dtype_map[q_dtype], device=device
+        )
+        q_scale = 1.0
+        ref_q = q
 
     # Sequence lengths and block tables
     seq_lens = [torch.randint(1, MAX_SEQ_LEN, (1,)).item() for _ in range(batch_size)]
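Side note on the `to_float8` helper called by the new code: it is defined elsewhere in the test file and not shown in this hunk. A minimal sketch of the calling convention assumed here (per-tensor quantization returning the fp8 tensor plus a dequantization scale, so that `x ≈ x_fp8 * scale`) could look like the following; the actual helper may differ.

    import torch

    def to_float8(x, dtype=torch.float8_e4m3fn):
        # Hypothetical per-tensor fp8 quantizer; not the file's actual helper.
        finfo = torch.finfo(dtype)
        amax = x.abs().max().clamp(min=1e-12)
        scale = amax / finfo.max  # dequantization scale
        x_fp8 = (x / scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
        # x is approximately x_fp8 * scale, which is what the fake-quantized
        # reference (ref_q = q.bfloat16() * q_scale) relies on.
        return x_fp8, scale.item()

Scalar scales are assumed because the test later uses them in plain Python arithmetic (e.g. `v_scale / o_scale` and the `v_scale == o_scale == 1.0` check).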
@@ -171,20 +181,37 @@ def test_trtllm_batch_decode_fmha(
         ]
         block_id += num_blocks_needed
 
-    # Create interleaved KV cache
-    # kv_cache_shape = (block_id, 2, num_kv_heads, page_size, head_dim)
-    # Allocate more than needed blocks, block_id is just enough, to mimick real-world cases
-    kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
-    q_scale = k_scale = v_scale = 1.0
-    kv_cache = torch.randn(size=kv_cache_shape, device=device).to(dtype_map[q_dtype])
+    # Create separate K and V caches
+    kv_dtype = dtype_map[q_dtype] if q_dtype != "fp8" else torch.bfloat16
+    k_cache = torch.randn(
+        num_blocks, num_kv_heads, page_size, head_dim, dtype=kv_dtype, device=device
+    )
+    v_cache = torch.randn(
+        num_blocks, num_kv_heads, page_size, head_dim, dtype=kv_dtype, device=device
+    )
+    # Convert K and V separately to fp8 if needed
+    if kv_cache_dtype.startswith("fp8"):
+        k_cache, k_scale = to_float8(k_cache)
+        v_cache, v_scale = to_float8(v_cache)
+        # Use high precision for the reference kv_cache to avoid precision/functional issues
+        ref_kv_type = torch.bfloat16 if q_dtype == "fp8" else dtype_map[q_dtype]
+        ref_kv_cache = torch.stack(
+            [k_cache.to(ref_kv_type) * k_scale, v_cache.to(ref_kv_type) * v_scale],
+            dim=1,
+        )
+    else:
+        k_scale = v_scale = 1.0
+        ref_kv_cache = torch.stack([k_cache, v_cache], dim=1)
+
+    # Combine K and V into interleaved format for the API
+    kv_cache = torch.stack(
+        [k_cache, v_cache], dim=1
+    )  # Shape: (num_blocks, 2, num_kv_heads, page_size, head_dim)
 
     # Output type is fp8 when q is fp8, set scale for it.
     o_scale = (
         1.0 if q_dtype != "fp8" else torch.rand(1).item() * 0.5 + 0.5
     )  # Scale range: 0.5 ~ 1.0
-    if kv_cache_dtype.startswith("fp8") and q_dtype != "fp8":
-        kv_cache, _ = to_float8(kv_cache)
-
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
 
     output = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
@@ -199,9 +226,6 @@ def test_trtllm_batch_decode_fmha(
         window_left,  # window_left
     ).squeeze(1)
 
-    # Reference implementation have functional issue or low precision with fp8, use half instead.
-    ref_q = q.half() if q_dtype == "fp8" else q
-
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer, kv_layout, use_tensor_cores=True
     )
@@ -217,11 +241,6 @@ def test_trtllm_batch_decode_fmha(
     kv_last_page_len = seq_lens_tensor % page_size
     kv_last_page_len[kv_last_page_len == 0] = page_size
 
-    if kv_cache_dtype == "auto":
-        kv_compute_dtype = dtype_map[q_dtype]
-    elif kv_cache_dtype == "fp8":
-        kv_compute_dtype = torch.float8_e4m3fn
-
     wrapper.plan(
         kv_indptr,
         kv_indices,
@@ -232,18 +251,18 @@ def test_trtllm_batch_decode_fmha(
         page_size,
         pos_encoding_mode="NONE",
         window_left=window_left,
-        data_type=kv_compute_dtype,
+        data_type=ref_kv_cache.dtype,
         q_data_type=ref_q.dtype,
     )
 
-    output_ref = wrapper.run(
-        ref_q, kv_cache, q_scale=q_scale * k_scale, v_scale=v_scale / o_scale
-    )
+    output_ref = wrapper.run(ref_q, ref_kv_cache)
 
-    rtol, atol = (1e-2, 5e-2) if q_dtype != "fp8" else (1e-1, 1e-1)
+    rtol, atol = (1e-2, 5e-2) if q_dtype != "fp8" else (5e-2, 7e-2)
 
     # convert to float32 for fp8 is not supported by assert_close
-    torch.testing.assert_close(output.float(), output_ref.float(), rtol=rtol, atol=atol)
+    torch.testing.assert_close(
+        output.float() * o_scale, output_ref.float(), rtol=rtol, atol=atol
+    )
 
     # test wrapper with trtllm-gen backend
     wrapper2 = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
@@ -263,9 +282,17 @@ def test_trtllm_batch_decode_fmha(
         window_left=window_left,
     )
     output2 = wrapper2.run(
-        q.contiguous(), kv_cache, q_scale=q_scale * k_scale, v_scale=v_scale / o_scale
+        q.contiguous(),
+        kv_cache,
+        q_scale=q_scale,
+        k_scale=k_scale,
+        v_scale=v_scale / o_scale,
     )
-    torch.testing.assert_close(output2.float(), output.float(), rtol=rtol, atol=atol)
+    # Skip the comparison when v_scale / o_scale are used, since they are not supported in the wrapper API yet.
+    if v_scale == o_scale == 1.0:
+        torch.testing.assert_close(
+            output2.float(), output.float(), rtol=rtol, atol=atol
+        )
 
 
 @pytest.mark.parametrize(
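Why the first assertion multiplies the kernel output by `o_scale`: the fp8 path hands the kernel `v_scale / o_scale`, so its output comes back divided by `o_scale`, while the bfloat16 reference is computed from the fully dequantized `ref_q` and `ref_kv_cache` with no output scaling. A sketch of the assumed scale bookkeeping behind the tolerance check (not part of the diff):

    # kernel:     output     ≈ Attention(q * q_scale, K * k_scale, V * v_scale) / o_scale
    # reference:  output_ref ≈ Attention(ref_q, ref_K, ref_V)   # inputs already dequantized
    # therefore:  output * o_scale ≈ output_ref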