Skip to content

Commit 612206c

Browse files
authored
add longbench and static_attention_dtype (#2393)
* add longbench and static_attention_dtype Signed-off-by: He, Xin3 <xin3.he@intel.com>
1 parent 8ae3b1a commit 612206c

File tree

5 files changed

+95
-25
lines changed

5 files changed

+95
-25
lines changed

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/llama3/README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ In this example, you can verify the accuracy on HPU/CUDA device with emulation o
66

77
```bash
88
# neural-compressor-pt
9-
pip install neural-compressor-pt==3.7
9+
pip install neural-compressor-pt
1010
# auto-round
11-
pip install auto-round==0.9.3
11+
pip install auto-round
1212
# other requirements
1313
pip install -r requirements.txt
1414
```
@@ -79,6 +79,8 @@ Notes:
7979
Here we provide several recipes for Llama3 models. The relative accuracy loss of the quantized model should be less than 1%.
8080

8181
> Note: You can also enable static quantization for KV cache by adding `--static_kv_dtype fp8` argument to `quantize.py`, or `--static_kv_dtype=fp8` argument to `run_quant.sh` and `run_benchmark.sh`.
82+
>
83+
> You can also enable static quantization for attention by adding `--static_attention_dtype fp8` argument to `quantize.py`, or `--static_attention_dtype=fp8` argument to `run_quant.sh` and `run_benchmark.sh`. When enabled, it automatically sets KV cache dtype to fp8 as well.
8284
8385
#### Llama 3.1 8B MXFP8
8486

@@ -210,8 +212,10 @@ CUDA_VISIBLE_DEVICES=0,1 bash run_benchmark.sh --model_path=Llama-3.1-70B-MXFP8
210212

211213
The script automatically:
212214
- Detects available GPUs from `CUDA_VISIBLE_DEVICES` and sets `tensor_parallel_size` accordingly
213-
- Runs default tasks: `piqa,hellaswag,mmlu_llama,gsm8k_llama` with batch size 8
215+
- Runs default tasks: `piqa,hellaswag,mmlu_llama,gsm8k_llama` with batch size 64
214216
- Supports custom task selection and batch size adjustment
217+
- Handles special tasks like `mmlu_llama`, `gsm8k_llama` (with chat template) and `longbench` (with extended context length) automatically
218+
- For longbench dataset evaluation, use the `--tasks=longbench` parameter
215219

216220

217221
### NVFP4

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/llama3/quantize.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,13 @@ def get_accuracy(model_name_or_path, tokenizer=None, eval_tasks="mmlu", limit=No
169169
choices=["fp8", "float8_e4m3fn"],
170170
help="Data type for static quantize key and value.",
171171
)
172+
parser.add_argument(
173+
"--static_attention_dtype",
174+
default=None,
175+
type=str,
176+
choices=["fp8", "float8_e4m3fn"],
177+
help="Data type for static quantize key and value.",
178+
)
172179
parser.add_argument("--use_recipe", action="store_true", help="whether to use recipe to quantize model")
173180
parser.add_argument("--recipe_file", type=str, default="recipes/Meta-Llama-3.1-8B-Instruct_6bits.json", help="path of recipe file")
174181
parser.add_argument("--iters", default=200, type=int, help="iters for autoround.")
@@ -214,9 +221,23 @@ def get_accuracy(model_name_or_path, tokenizer=None, eval_tasks="mmlu", limit=No
214221
model, tokenizer = initialize_model_and_tokenizer(args.model_name_or_path)
215222

216223
if args.quantize:
217-
if args.dtype in ["uNVFP4", "NVFP4+"]:
218-
from auto_round.schemes import QuantizationScheme
224+
from auto_round.schemes import PRESET_SCHEMES, QuantizationScheme
219225

226+
# Check if RCEIL versions are available and use them instead
227+
use_rceil = "MXFP4_RCEIL" in PRESET_SCHEMES and "MXFP8_RCEIL" in PRESET_SCHEMES
228+
if use_rceil:
229+
# Replace dtype if it's MXFP4 or MXFP8
230+
if args.dtype == "MXFP4":
231+
args.dtype = "MXFP4_RCEIL"
232+
elif args.dtype == "MXFP8":
233+
args.dtype = "MXFP8_RCEIL"
234+
# Replace options list entries
235+
args.options = [
236+
"MXFP4_RCEIL" if opt == "MXFP4" else ("MXFP8_RCEIL" if opt == "MXFP8" else opt)
237+
for opt in args.options
238+
]
239+
240+
if args.dtype in ["uNVFP4", "NVFP4+"]:
220241
uNVFP4 = QuantizationScheme.from_dict(
221242
{
222243
"bits": 4,
@@ -256,6 +277,7 @@ def load_recipe_results(file_path):
256277
options=args.options,
257278
shared_layers=args.shared_layers,
258279
static_kv_dtype=args.static_kv_dtype,
280+
static_attention_dtype=args.static_attention_dtype,
259281
enable_torch_compile=args.enable_torch_compile,
260282
low_gpu_mem_usage=args.low_gpu_mem_usage,
261283
export_format=args.export_format,
Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
transformers==4.57.3
22
torch==2.9.0
33
torchvision==0.24.0
4-
lm_eval==0.4.9.2
4+
lm_eval==0.4.10
55
datasets==4.4.2
6-
auto-round==0.9.3
6+
auto-round>=0.9.3
77
neural-compressor-pt>=3.7
8+
jieba
9+
fuzzywuzzy
10+
rouge
11+
hf_transfer

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/llama3/run_benchmark.sh

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ TASKS="piqa,hellaswag,mmlu_llama,gsm8k_llama"
77
BATCH_SIZE=64
88
GPU_MEMORY_UTILIZATION=0.8
99
KV_CACHE_DTYPE="auto"
10+
ATTN_DTYPE="auto"
1011

1112
while [[ $# -gt 0 ]]; do
1213
case $1 in
@@ -30,6 +31,10 @@ while [[ $# -gt 0 ]]; do
3031
KV_CACHE_DTYPE="${1#*=}"
3132
shift
3233
;;
34+
--static_attention_dtype=*)
35+
ATTN_DTYPE="${1#*=}"
36+
shift
37+
;;
3338
*)
3439
echo "Unknown parameter: $1"
3540
exit 1
@@ -44,6 +49,14 @@ if [[ "$KV_CACHE_DTYPE" == "fp8" ]]; then
4449
echo "Using FP8 for KV cache"
4550
fi
4651

52+
# for fp8 attention cache
53+
if [[ "$ATTN_DTYPE" == "fp8" ]]; then
54+
export VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION=0
55+
export VLLM_ATTENTION_BACKEND="FLASHINFER"
56+
KV_CACHE_DTYPE="fp8"
57+
echo "Using FP8 Attention"
58+
fi
59+
4760
# Validate required parameters
4861
if [[ -z "$MODEL_PATH" ]]; then
4962
echo "Usage: bash run_benchmark.sh --model_path=<path_to_quantized_model> [--tasks=<tasks>] [--batch_size=<size>]"
@@ -103,10 +116,11 @@ run_evaluation() {
103116
}
104117

105118

106-
# Check if tasks contain gsm8k_llama or mmlu_llama
119+
# Check if tasks contain gsm8k_llama, mmlu_llama, or longbench
107120
NEED_SPLIT=false
108121
OTHER_TASKS="$TASKS"
109122
SPECIAL_TASKS=""
123+
LONGBENCH_TASK=""
110124

111125
if [[ "$TASKS" == *"gsm8k_llama"* ]]; then
112126
SPECIAL_TASKS="gsm8k_llama"
@@ -122,26 +136,24 @@ if [[ "$TASKS" == *"mmlu_llama"* ]]; then
122136
OTHER_TASKS=$(echo "$OTHER_TASKS" | sed 's/,*mmlu_llama,*//' | sed 's/^,//' | sed 's/,$//')
123137
NEED_SPLIT=true
124138
fi
139+
if [[ "$TASKS" == *"longbench"* ]]; then
140+
LONGBENCH_TASK="longbench"
141+
OTHER_TASKS=$(echo "$OTHER_TASKS" | sed 's/,*longbench,*//' | sed 's/^,//' | sed 's/,$//')
142+
NEED_SPLIT=true
143+
fi
125144

126145
if [[ "$NEED_SPLIT" == true ]]; then
127146
if [[ -n "$OTHER_TASKS" ]]; then
128147
echo "Running general tasks"
129148
run_evaluation "$OTHER_TASKS" true ""
130-
if [[ $? -eq 0 ]]; then
131-
IFS=',' read -ra SPECIAL_ARRAY <<< "$SPECIAL_TASKS"
132-
for special_task in "${SPECIAL_ARRAY[@]}"; do
133-
echo "Running $special_task with chat template"
134-
run_evaluation "$special_task" true "--apply_chat_template --fewshot_as_multiturn"
135-
if [[ $? -ne 0 ]]; then
136-
echo "Benchmark failed on $special_task!"
137-
exit 1
138-
fi
139-
done
140-
else
149+
if [[ $? -ne 0 ]]; then
141150
echo "Skipping special tasks due to previous failure"
142151
exit 1
143152
fi
144-
else
153+
fi
154+
155+
# Run special tasks (gsm8k_llama, mmlu_llama)
156+
if [[ -n "$SPECIAL_TASKS" ]]; then
145157
IFS=',' read -ra SPECIAL_ARRAY <<< "$SPECIAL_TASKS"
146158
for special_task in "${SPECIAL_ARRAY[@]}"; do
147159
echo "Running $special_task with chat template"
@@ -152,6 +164,26 @@ if [[ "$NEED_SPLIT" == true ]]; then
152164
fi
153165
done
154166
fi
167+
168+
# Run longbench task with special configuration
169+
if [[ -n "$LONGBENCH_TASK" ]]; then
170+
echo "Running longbench with special configuration"
171+
local longbench_cmd="lm_eval --model vllm --model_args pretrained=\"$MODEL_PATH\",trust_remote_code=True,dtype=bfloat16,max_model_len=66000,tensor_parallel_size=$TENSOR_PARALLEL_SIZE,gpu_memory_utilization=$GPU_MEMORY_UTILIZATION,enable_prefix_caching=False --tasks longbench --seed 42 --batch_size $BATCH_SIZE --apply_chat_template --gen_kwargs '{\"temperature\":0.0}'"
172+
echo "Executing command: $longbench_cmd"
173+
174+
lm_eval --model vllm \
175+
--model_args pretrained="$MODEL_PATH",trust_remote_code=True,dtype=bfloat16,max_model_len=66000,tensor_parallel_size=$TENSOR_PARALLEL_SIZE,gpu_memory_utilization=$GPU_MEMORY_UTILIZATION,enable_prefix_caching=False \
176+
--tasks longbench \
177+
--seed 42 \
178+
--batch_size $BATCH_SIZE \
179+
--apply_chat_template \
180+
--gen_kwargs '{"temperature":0.0}'
181+
182+
if [[ $? -ne 0 ]]; then
183+
echo "Benchmark failed on longbench!"
184+
exit 1
185+
fi
186+
fi
155187
else
156188
run_evaluation "$TASKS" true ""
157189
fi

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/llama3/run_quant.sh

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
# Parse command line arguments
66
KV_CACHE_DTYPE="auto"
7+
STATIC_ATTENTION_DTYPE="auto"
78
while [[ $# -gt 0 ]]; do
89
case $1 in
910
--topology=*)
@@ -26,6 +27,10 @@ while [[ $# -gt 0 ]]; do
2627
KV_CACHE_DTYPE="${1#*=}"
2728
shift
2829
;;
30+
--static_attention_dtype=*)
31+
STATIC_ATTENTION_DTYPE="${1#*=}"
32+
shift
33+
;;
2934
*)
3035
echo "Unknown parameter: $1"
3136
exit 1
@@ -48,10 +53,12 @@ echo " Input Model: $INPUT_MODEL"
4853
echo " Output Model: $OUTPUT_MODEL"
4954

5055
# Set common parameters
51-
if [ "$KV_CACHE_DTYPE" = "auto" ]; then
52-
COMMON_ARGS="--quantize --enable_torch_compile --low_gpu_mem_usage --export_format auto_round"
53-
else
54-
COMMON_ARGS="--quantize --enable_torch_compile --low_gpu_mem_usage --export_format auto_round --static_kv_dtype $KV_CACHE_DTYPE"
56+
COMMON_ARGS="--quantize --enable_torch_compile --low_gpu_mem_usage --export_format auto_round"
57+
if [ "$KV_CACHE_DTYPE" != "auto" ]; then
58+
COMMON_ARGS="$COMMON_ARGS --static_kv_dtype $KV_CACHE_DTYPE"
59+
fi
60+
if [ "$STATIC_ATTENTION_DTYPE" != "auto" ]; then
61+
COMMON_ARGS="$COMMON_ARGS --static_attention_dtype $STATIC_ATTENTION_DTYPE"
5562
fi
5663

5764
case "$TOPOLOGY" in
@@ -81,14 +88,15 @@ case "$TOPOLOGY" in
8188
;;
8289
"mxfp4_mixed")
8390
echo "Running Llama 3.1 8B MXFP4 (Mixed with MXFP8) quantization..."
84-
CMD="python quantize.py --model_name_or_path \"$INPUT_MODEL\" $COMMON_ARGS --target_bits 7.8 --options \"MXFP4\" \"MXFP8\" --shared_layers \"k_proj\" \"v_proj\" \"q_proj\" --shared_layers \"gate_proj\" \"up_proj\" --export_path \"$OUTPUT_MODEL\""
91+
CMD="python quantize.py --model_name_or_path \"$INPUT_MODEL\" $COMMON_ARGS --target_bits 7.8 --iters 0 --options \"MXFP4\" \"MXFP8\" --shared_layers \"k_proj\" \"v_proj\" \"q_proj\" --shared_layers \"gate_proj\" \"up_proj\" --export_path \"$OUTPUT_MODEL\""
8592
echo "Executing command: $CMD"
8693
python quantize.py \
8794
--model_name_or_path "$INPUT_MODEL" \
8895
$COMMON_ARGS \
8996
--target_bits 7.8 \
9097
--options "MXFP4" "MXFP8" \
9198
--shared_layers "k_proj" "v_proj" "q_proj" \
99+
--iters 0 \
92100
--shared_layers "gate_proj" "up_proj" \
93101
--export_path "$OUTPUT_MODEL"
94102
;;

0 commit comments

Comments
 (0)