Commit 0341e30

add fp8 attention and longbench task for llama4
Signed-off-by: Mengni Wang <mengni.wang@intel.com>
1 parent 1b867f0 commit 0341e30

5 files changed: +37 -1 lines changed

examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/README.md

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ CUDA_VISIBLE_DEVICES=0 bash run_quant.sh --topology=llama4_mxfp4 --input_model=L
 ```
 
 > Note: You can also enable static quantization for KV cache by adding `--static_kv_dtype fp8` argument to `main.py`, or `--static_kv_dtype=fp8` argument to `run_quant.sh` and `run_benchmark.sh`.
+>
+> You can also enable static quantization for attention by adding `--static_attention_dtype fp8` argument to `main.py`, or `--static_attention_dtype=fp8` argument to `run_quant.sh` and `run_benchmark.sh`. When enabled, it automatically sets KV cache dtype to fp8 as well.
 
 ## 2. Benchmark
 
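For example, extending the README's existing quantization command, fp8 attention can be enabled with something along these lines (illustrative; the model path is a placeholder):

```bash
CUDA_VISIBLE_DEVICES=0 bash run_quant.sh --topology=llama4_mxfp4 --input_model=<path-to-llama4-model> --static_attention_dtype=fp8
```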

examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/main.py

Lines changed: 8 additions & 1 deletion
@@ -80,7 +80,13 @@ def setup_parser():
         choices=["fp8", "float8_e4m3fn"],
         help="Data type for static quantize key and value."
     )
-
+    parser.add_argument(
+        "--static_attention_dtype",
+        default=None,
+        type=str,
+        choices=["fp8", "float8_e4m3fn"],
+        help="Data type for static quantize query, key and value."
+    )
     parser.add_argument(
         "--iters",
         "--iter",
@@ -122,6 +128,7 @@ def tune(args):
         output_dir=args.output_dir,
         processor=processor,
         static_kv_dtype=args.static_kv_dtype,
+        static_attention_dtype=args.static_attention_dtype,
         reloading=False,
     )
     model = prepare(model, qconfig)
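Since `--static_attention_dtype` is a plain argparse option, a quick way to confirm it is wired in is the help output (an illustrative command, run from the llama4 example directory):

```bash
python3 main.py --help | grep -A 1 static_attention_dtype
# should show --static_attention_dtype with choices {fp8,float8_e4m3fn}
```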

examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -2,3 +2,6 @@ lm-eval==0.4.9.1
 setuptools_scm
 torchao==0.12.0
 triton==3.3.1
+jieba
+fuzzywuzzy
+rouge
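The three new packages are presumably needed for longbench metric computation; after pulling this change, refresh the environment with the usual command:

```bash
pip install -r requirements.txt
```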

examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/run_benchmark.sh

Lines changed: 15 additions & 0 deletions
@@ -30,6 +30,9 @@ function init_params {
       --static_kv_dtype=*)
           kv_cache_dtype=$(echo $var |cut -f2 -d=)
       ;;
+      --static_attention_dtype=*)
+          attention_dtype=$(echo $var |cut -f2 -d=)
+      ;;
     esac
   done
 
@@ -41,6 +44,7 @@ function run_benchmark {
     extra_model_args=""
     extra_cmd=""
     kv_cache_dtype=${kv_cache_dtype:="auto"}
+    attention_dtype=${attention_dtype:="auto"}
     batch_size=${batch_size:=1}
 
     if [ "${topology}" = "llama4_mxfp4" ]; then
@@ -57,6 +61,10 @@ function run_benchmark {
     if [[ "${tasks}" == *"chartqa"* || "${tasks}" == *"mmmu_val"* ]]; then
         model="vllm-vlm"
         extra_cmd=${extra_cmd}" --apply_chat_template"
+    elif [[ "${tasks}" == *"longbench"* ]]; then
+        model="vllm"
+        extra_cmd="--seed 42 --apply_chat_template --gen_kwargs {\"temperature\":0.0} "
+        extra_model_args="max_model_len=66000,gpu_memory_utilization=0.7"
     else
         model="vllm"
     fi
@@ -67,6 +75,13 @@ function run_benchmark {
         echo "Using FP8 for KV cache"
     fi
 
+    if [[ "${attention_dtype}" == "fp8" ]]; then
+        export VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION=0
+        export VLLM_ATTENTION_BACKEND="FLASHINFER"
+        kv_cache_dtype="fp8"
+        echo "Using FP8 Attention"
+    fi
+
     NCCL_NVLS_ENABLE=0 VLLM_USE_STANDALONE_COMPILE=0 VLLM_WORKER_MULTIPROC_METHOD=spawn \
     lm_eval --model ${model} \
         --model_args pretrained=${input_model},tensor_parallel_size=${tp_size},${extra_model_args},enable_expert_parallel=True,kv_cache_dtype=${kv_cache_dtype} \
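Taken together, a longbench run with fp8 attention switches vLLM to the FlashInfer backend, leaves FlashInfer query quantization enabled, forces the fp8 KV cache, and caps context length and GPU memory utilization. Hand-expanding the script's variables gives roughly the invocation below (an illustrative sketch; `${input_model}` and `${tp_size}` are user inputs, the task name is whatever longbench task the user passes, and any trailing lm_eval flags from the script are omitted):

```bash
export VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION=0
export VLLM_ATTENTION_BACKEND="FLASHINFER"

NCCL_NVLS_ENABLE=0 VLLM_USE_STANDALONE_COMPILE=0 VLLM_WORKER_MULTIPROC_METHOD=spawn \
lm_eval --model vllm \
    --model_args pretrained=${input_model},tensor_parallel_size=${tp_size},max_model_len=66000,gpu_memory_utilization=0.7,enable_expert_parallel=True,kv_cache_dtype=fp8 \
    --tasks longbench \
    --seed 42 --apply_chat_template --gen_kwargs '{"temperature":0.0}'
```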

examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/run_quant.sh

Lines changed: 9 additions & 0 deletions
@@ -31,6 +31,9 @@ function init_params {
       --static_kv_dtype=*)
           kv_cache_dtype=$(echo $var |cut -f2 -d=)
       ;;
+      --static_attention_dtype=*)
+          attention_dtype=$(echo $var |cut -f2 -d=)
+      ;;
       *)
           echo "Error: No such parameter: ${var}"
           exit 1
@@ -46,6 +49,7 @@ function run_tuning {
     tuned_checkpoint=${tuned_checkpoint:="saved_results"}
     iters=${iters:=0}
     kv_cache_dtype=${kv_cache_dtype:="auto"}
+    attention_dtype=${attention_dtype:="auto"}
 
     if [ "${topology}" = "llama4_mxfp4" ]; then
         extra_cmd="--fp_layers lm-head,self_attn,router,vision_model,multi_modal_projector,shared_expert --scheme MXFP4 --export_format auto_round"
@@ -55,6 +59,11 @@ function run_tuning {
         extra_cmd=${extra_cmd}" --static_kv_dtype ${kv_cache_dtype}"
     fi
 
+    if [[ ! "${attention_dtype}" = "auto" ]]; then
+        extra_cmd=${extra_cmd}" --static_attention_dtype ${attention_dtype}"
+    fi
+
+
     python3 main.py \
         --model ${input_model} \
         --iters ${iters} \
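With `--topology=llama4_mxfp4 --static_attention_dtype=fp8`, run_tuning therefore ends up calling main.py with roughly the flags below (a hand-expanded sketch; iters defaults to 0, and any parts of the `python3 main.py` call not visible in this diff are omitted):

```bash
python3 main.py \
    --model ${input_model} \
    --iters 0 \
    --fp_layers lm-head,self_attn,router,vision_model,multi_modal_projector,shared_expert \
    --scheme MXFP4 \
    --export_format auto_round \
    --static_attention_dtype fp8
```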
