
Commit 209b7b1

bugfix: Fix missing v_scale for prefill wrapper. (#1416)
## 📌 Description

We need k/v scales != 1 for the Llama 3 FP4 model, so the prefill wrapper must apply `v_scale` to its output the same way the decode wrapper does.

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).
1 parent 04534c7 commit 209b7b1

File tree: 4 files changed (+19 −5 lines)

β€Žflashinfer/decode.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
canonicalize_torch_dtype,
5858
device_support_pdl,
5959
get_device_sm_count,
60+
is_float8,
6061
register_custom_op,
6162
register_fake_op,
6263
)
@@ -1318,8 +1319,8 @@ def run(
13181319
self._cached_module.run(*run_args)
13191320
if v_scale is not None:
13201321
# TODO(Zihao): fused into kernel
1321-
if out.itemsize == 1:
1322-
out = (out.to(float) * v_scale).to(out.dtype)
1322+
if is_float8(out):
1323+
out = (out.to(torch.float32) * v_scale).to(out.dtype)
13231324
else:
13241325
out *= v_scale
13251326
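The change matters for FP8 outputs because the old `out.itemsize == 1` test also matches non-FP8 one-byte dtypes, and `out.to(float)` upcasts to float64 rather than float32. A minimal sketch of the scaling pattern, assuming a PyTorch build with FP8 dtypes; the `is_float8` stand-in below is an assumption about what FlashInfer's imported helper checks:

```python
import torch


def is_float8(x: torch.Tensor) -> bool:
    # Stand-in for FlashInfer's helper: true for either FP8 dtype.
    return x.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)


def apply_v_scale(out: torch.Tensor, v_scale: float) -> torch.Tensor:
    """Apply v_scale to attention output after the kernel has run."""
    if is_float8(out):
        # FP8 tensors do not support direct arithmetic, so upcast to
        # float32, scale, and cast back to the original dtype.
        return (out.to(torch.float32) * v_scale).to(out.dtype)
    out *= v_scale  # wider dtypes can be scaled in place
    return out
```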

β€Žflashinfer/prefill.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2092,7 +2092,12 @@ def run(
20922092
]
20932093

20942094
self._cached_module.paged_run(*run_args)
2095-
2095+
if v_scale is not None:
2096+
# TODO(Zihao): fused into kernel
2097+
if is_float8(out):
2098+
out = (out.to(torch.float32) * v_scale).to(out.dtype)
2099+
else:
2100+
out *= v_scale
20962101
return (out, lse) if return_lse else out
20972102

20982103
run_return_lse = functools.partialmethod(run, return_lse=True)
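Applying `v_scale` after `paged_run` is equivalent to scaling V inside the kernel because attention output is linear in V: the softmax weights depend only on Q and K. A quick sanity check of that identity, with made-up shapes and an arbitrary scale:

```python
import torch

q, k, v = torch.randn(4, 8), torch.randn(6, 8), torch.randn(6, 8)
v_scale = 0.5  # arbitrary scale for the check

weights = torch.softmax(q @ k.T / 8**0.5, dim=-1)  # depends only on Q and K
fused = weights @ (v_scale * v)     # scale folded into V, as a kernel would do
emulated = v_scale * (weights @ v)  # scale applied to the output afterwards
assert torch.allclose(fused, emulated)
```

No such shortcut exists for `k_scale`, which enters through the softmax, so it still has to be handled before or inside the kernel.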

β€Žtests/test_trtllm_gen_context.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,10 @@ def test_trtllm_batch_prefill(
489489
k_scale=k_scale,
490490
v_scale=v_scale / o_scale,
491491
)
492-
# v_scale, o_scale is not supported in wrapper api yet.
492+
# v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel.
493493
if v_scale == o_scale == 1.0:
494494
assert (output2 == output).all()
495+
else:
496+
torch.testing.assert_close(
497+
output.float(), output2.float(), rtol=1e-1, atol=1e-1
498+
)
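Bitwise equality only holds when no post-kernel scaling happens; otherwise the wrapper's float32 multiply rounds differently from the fused reference path, so the test switches to `torch.testing.assert_close`, which accepts elementwise differences up to `atol + rtol * |expected|`. A tiny illustration of those semantics, with invented values:

```python
import torch

expected = torch.tensor([1.0, 10.0])
actual = torch.tensor([1.05, 10.5])

# Passes: |actual - expected| <= 1e-1 + 1e-1 * |expected| elementwise
# (0.05 <= 0.2 and 0.5 <= 1.1).
torch.testing.assert_close(actual, expected, rtol=1e-1, atol=1e-1)
```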

β€Žtests/test_trtllm_gen_decode.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,13 @@ def test_trtllm_batch_decode_fmha(
377377
k_scale=k_scale,
378378
v_scale=v_scale / o_scale,
379379
)
380-
# v_scale, o_scale is not supported in wrapper api yet.
380+
# v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel.
381381
if v_scale == o_scale == 1.0:
382382
assert (output2 == output).all()
383+
else:
384+
torch.testing.assert_close(
385+
output.float(), output2.float(), rtol=1e-1, atol=1e-1
386+
)
383387

384388

385389
@pytest.mark.parametrize(

0 commit comments
