Commit e353f11

unittest: test qkvo quantization not equal to 1. (#1314)
1 parent 18fd91a commit e353f11

1 file changed: +55 −28 lines

tests/test_trtllm_gen_decode.py

Lines changed: 55 additions & 28 deletions
@@ -137,9 +137,19 @@ def test_trtllm_batch_decode_fmha(
     }
 
     sm_scale = float(1.0 / (head_dim**0.5))
-    q = torch.randn(batch_size, num_qo_heads, head_dim, device=device).to(
-        dtype_map[q_dtype]
-    )
+    if q_dtype == "fp8":
+        q = torch.randn(
+            batch_size, num_qo_heads, head_dim, dtype=torch.bfloat16, device=device
+        )
+        q, q_scale = to_float8(q)
+        # The reference implementation has functional issues or low precision with fp8; use bfloat16 and fake quantization instead.
+        ref_q = q.bfloat16() * q_scale
+    else:
+        q = torch.randn(
+            batch_size, num_qo_heads, head_dim, dtype=dtype_map[q_dtype], device=device
+        )
+        q_scale = 1.0
+        ref_q = q
 
     # Sequence lengths and block tables
     seq_lens = [torch.randint(1, MAX_SEQ_LEN, (1,)).item() for _ in range(batch_size)]
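Note: the fp8 branch above relies on a `to_float8` helper that returns the quantized tensor together with its dequantization scale; the helper is defined elsewhere in the test file and is not part of this diff. Below is a minimal sketch of what such a per-tensor symmetric quantizer typically looks like — the actual implementation in `tests/test_trtllm_gen_decode.py` may differ.

```python
import torch

def to_float8(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    """Sketch of a per-tensor fp8 quantizer with the (tensor, scale) return shape used above."""
    finfo = torch.finfo(dtype)
    # Map the tensor's absolute maximum onto the fp8 representable range.
    scale = x.abs().max().clamp(min=1e-12) / finfo.max
    x_fp8 = (x / scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_fp8, scale.float()
```

Dequantizing back with `ref_q = q.bfloat16() * q_scale` feeds the reference path the same rounded values the fp8 kernel sees, so the comparison measures kernel accuracy rather than quantization error.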
@@ -171,20 +181,37 @@ def test_trtllm_batch_decode_fmha(
         ]
         block_id += num_blocks_needed
 
-    # Create interleaved KV cache
-    # kv_cache_shape = (block_id, 2, num_kv_heads, page_size, head_dim)
-    # Allocate more than needed blocks, block_id is just enough, to mimick real-world cases
-    kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
-    q_scale = k_scale = v_scale = 1.0
-    kv_cache = torch.randn(size=kv_cache_shape, device=device).to(dtype_map[q_dtype])
+    # Create separate K and V caches
+    kv_dtype = dtype_map[q_dtype] if q_dtype != "fp8" else torch.bfloat16
+    k_cache = torch.randn(
+        num_blocks, num_kv_heads, page_size, head_dim, dtype=kv_dtype, device=device
+    )
+    v_cache = torch.randn(
+        num_blocks, num_kv_heads, page_size, head_dim, dtype=kv_dtype, device=device
+    )
+    # Convert K and V separately to fp8 if needed
+    if kv_cache_dtype.startswith("fp8"):
+        k_cache, k_scale = to_float8(k_cache)
+        v_cache, v_scale = to_float8(v_cache)
+        # Use high precision for the reference kv_cache to avoid precision/functional issues
+        ref_kv_type = torch.bfloat16 if q_dtype == "fp8" else dtype_map[q_dtype]
+        ref_kv_cache = torch.stack(
+            [k_cache.to(ref_kv_type) * k_scale, v_cache.to(ref_kv_type) * v_scale],
+            dim=1,
+        )
+    else:
+        k_scale = v_scale = 1.0
+        ref_kv_cache = torch.stack([k_cache, v_cache], dim=1)
+
+    # Combine K and V into interleaved format for the API
+    kv_cache = torch.stack(
+        [k_cache, v_cache], dim=1
+    )  # Shape: (num_blocks, 2, num_kv_heads, page_size, head_dim)
 
     # Output type is fp8 when q is fp8, set scale for it.
     o_scale = (
         1.0 if q_dtype != "fp8" else torch.rand(1).item() * 0.5 + 0.5
     )  # Scale range: 0.5 ~ 1.0
-    if kv_cache_dtype.startswith("fp8") and q_dtype != "fp8":
-        kv_cache, _ = to_float8(kv_cache)
-
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
 
     output = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
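For context on the cache built above: `(num_blocks, 2, num_kv_heads, page_size, head_dim)` is a paged layout in which dimension 1 interleaves K and V and each sequence reaches its pages through a block table. The sketch below only illustrates that addressing scheme; `lookup_kv` and its argument names are not part of the test.

```python
import torch

def lookup_kv(kv_cache: torch.Tensor, block_table: torch.Tensor,
              seq_idx: int, token_pos: int, kv_head: int, page_size: int):
    """Fetch one token's K and V vectors from an interleaved paged cache
    shaped (num_blocks, 2, num_kv_heads, page_size, head_dim)."""
    logical_page = token_pos // page_size              # which page of this sequence
    slot = token_pos % page_size                       # position inside the page
    physical_block = block_table[seq_idx, logical_page]
    k = kv_cache[physical_block, 0, kv_head, slot]     # index 0 along dim 1 holds K
    v = kv_cache[physical_block, 1, kv_head, slot]     # index 1 along dim 1 holds V
    return k, v
```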
@@ -199,9 +226,6 @@ def test_trtllm_batch_decode_fmha(
         window_left,  # window_left
     ).squeeze(1)
 
-    # Reference implementation have functional issue or low precision with fp8, use half instead.
-    ref_q = q.half() if q_dtype == "fp8" else q
-
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer, kv_layout, use_tensor_cores=True
     )
@@ -217,11 +241,6 @@ def test_trtllm_batch_decode_fmha(
     kv_last_page_len = seq_lens_tensor % page_size
     kv_last_page_len[kv_last_page_len == 0] = page_size
 
-    if kv_cache_dtype == "auto":
-        kv_compute_dtype = dtype_map[q_dtype]
-    elif kv_cache_dtype == "fp8":
-        kv_compute_dtype = torch.float8_e4m3fn
-
     wrapper.plan(
         kv_indptr,
         kv_indices,
@@ -232,18 +251,18 @@ def test_trtllm_batch_decode_fmha(
         page_size,
         pos_encoding_mode="NONE",
         window_left=window_left,
-        data_type=kv_compute_dtype,
+        data_type=ref_kv_cache.dtype,
         q_data_type=ref_q.dtype,
     )
 
-    output_ref = wrapper.run(
-        ref_q, kv_cache, q_scale=q_scale * k_scale, v_scale=v_scale / o_scale
-    )
+    output_ref = wrapper.run(ref_q, ref_kv_cache)
 
-    rtol, atol = (1e-2, 5e-2) if q_dtype != "fp8" else (1e-1, 1e-1)
+    rtol, atol = (1e-2, 5e-2) if q_dtype != "fp8" else (5e-2, 7e-2)
 
     # convert to float32 for fp8 is not supported by assert_close
-    torch.testing.assert_close(output.float(), output_ref.float(), rtol=rtol, atol=atol)
+    torch.testing.assert_close(
+        output.float() * o_scale, output_ref.float(), rtol=rtol, atol=atol
+    )
 
     # test wrapper with trtllm-gen backend
     wrapper2 = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
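The assertion above compares `output.float() * o_scale` against an unscaled reference because the reference wrapper now runs on dequantized bfloat16 inputs with no scales, while the trtllm-gen kernel folds `q_scale * k_scale` into the attention logits, `v_scale` into the value path, and stores its output divided by `o_scale`. The snippet below is a small numerical check of that scale identity under those assumptions (with `sm_scale` omitted for brevity); it is not the kernel's actual code path.

```python
import torch

torch.manual_seed(0)
q = torch.randn(4, 8, dtype=torch.float64)
k = torch.randn(16, 8, dtype=torch.float64)
v = torch.randn(16, 8, dtype=torch.float64)
q_scale, k_scale, v_scale, o_scale = 0.9, 1.1, 0.8, 0.7  # arbitrary example scales

# Reference path: attention on already-dequantized inputs, no extra scales.
ref = torch.softmax((q * q_scale) @ (k * k_scale).T, dim=-1) @ (v * v_scale)

# Kernel-style path: q_scale * k_scale folded into the logits, v_scale / o_scale
# folded into the values, so the raw output is stored in "o_scale units".
out = torch.softmax((q @ k.T) * (q_scale * k_scale), dim=-1) @ v * (v_scale / o_scale)

torch.testing.assert_close(out * o_scale, ref)  # mirrors the test's comparison
```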
@@ -263,9 +282,17 @@ def test_trtllm_batch_decode_fmha(
         window_left=window_left,
     )
     output2 = wrapper2.run(
-        q.contiguous(), kv_cache, q_scale=q_scale * k_scale, v_scale=v_scale / o_scale
+        q.contiguous(),
+        kv_cache,
+        q_scale=q_scale,
+        k_scale=k_scale,
+        v_scale=v_scale / o_scale,
     )
-    torch.testing.assert_close(output2.float(), output.float(), rtol=rtol, atol=atol)
+    # Skip the comparison: v_scale and o_scale are not yet supported in the wrapper API.
+    if v_scale == o_scale == 1.0:
+        torch.testing.assert_close(
+            output2.float(), output.float(), rtol=rtol, atol=atol
+        )
 
 
 @pytest.mark.parametrize(
