forked from vllm-project/llm-compressor
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllama4_example.py
More file actions
95 lines (79 loc) · 2.71 KB
/
llama4_example.py
File metadata and controls
95 lines (79 loc) · 2.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
from datasets import load_dataset
from transformers import Llama4ForConditionalGeneration, Llama4Processor
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
# Select the checkpoint, then load the model and its processor from the same id.
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
processor = Llama4Processor.from_pretrained(model_id)
# MoE calibration is handled automatically by the pipeline: during calibration,
# `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`)
# replace the original `Llama4TextMoe` class from
# `transformers.models.llama4.modeling_llama4`, which enables proper expert
# calibration and vLLM compatibility.
# Calibration data settings: the dataset id, how many rows are taken from the
# start of the train split, and the maximum sequence length used elsewhere in
# this script for truncation and for the oneshot run.
DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 20
MAX_SEQUENCE_LENGTH = 8192
# Load only the first NUM_CALIBRATION_SAMPLES rows of the "LLM" configuration.
ds = load_dataset(DATASET_ID, name="LLM", split=f"train[:{NUM_CALIBRATION_SAMPLES}]")
def preprocess_function(example):
    """Tokenize one calibration example through the processor's chat template.

    Every entry of ``example["messages"]`` is re-wrapped so that its content
    becomes a one-element list holding a ``{"type": "text", "text": ...}``
    part; the role is carried over unchanged. The resulting conversation is
    handed to ``processor.apply_chat_template`` with tokenization enabled,
    truncation requested at ``MAX_SEQUENCE_LENGTH``, no padding, no added
    special tokens and no generation prompt.

    Returns:
        The value produced by ``processor.apply_chat_template`` for these
        arguments (requested with ``return_dict=True`` and
        ``return_tensors="pt"``).
    """
    conversation = [
        {
            "role": turn["role"],
            "content": [{"type": "text", "text": turn["content"]}],
        }
        for turn in example["messages"]
    ]

    # Collected separately so the call site below stays short; the values are
    # exactly the keyword arguments the template call needs.
    template_options = {
        "return_tensors": "pt",
        "padding": False,
        "truncation": True,
        "max_length": MAX_SEQUENCE_LENGTH,
        "tokenize": True,
        "add_special_tokens": False,
        "return_dict": True,
        "add_generation_prompt": False,
    }
    return processor.apply_chat_template(conversation, **template_options)
# Run preprocess_function on each example individually (batched=False) and
# drop all of the dataset's original columns from the result.
ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names)
def data_collator(batch):
    """Turn a batch holding exactly one sample into a dict of tensors.

    Args:
        batch: A sequence containing a single mapping from feature name to
            (nested) list data.

    Returns:
        A dict with the same keys as the sample. Every value is converted
        with ``torch.tensor``; the ``"pixel_values"`` entry is additionally
        created as ``torch.bfloat16`` and squeezed along dimension 0.

    Raises:
        ValueError: If ``batch`` does not contain exactly one sample.
    """
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would silently discard every sample but the first.
    if len(batch) != 1:
        raise ValueError(
            f"data_collator expects a batch of exactly 1 sample, got {len(batch)}"
        )

    sample = batch[0]
    return {
        key: (
            torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
            if key == "pixel_values"
            else torch.tensor(value)
        )
        for key, value in sample.items()
    }
# Configure the quantization algorithm to run.
# Every "Linear" target is quantized with the NVFP4 scheme, except modules
# matching the names, regexes or class names listed here.
IGNORED_MODULES = [
    "re:.*lm_head",
    "re:.*self_attn",
    "re:.*router",
    "re:.*vision_model.*",
    "re:.*multi_modal_projector.*",
    "Llama4TextAttention",
]
recipe = QuantizationModifier(
    targets="Linear",
    scheme="NVFP4",
    ignore=IGNORED_MODULES,
)

# Apply algorithms.
# Due to the large size of Llama4, we specify sequential targets such that
# only one MLP is loaded into GPU memory at a time.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    sequential_targets=["Llama4TextMLP"],
    data_collator=data_collator,
)

# Save to disk compressed.
# The output directory is the last path component of the model id plus a
# suffix naming the quantization scheme.
model_name = model_id.rstrip("/").rsplit("/", 1)[-1]
SAVE_DIR = model_name + "-NVFP4"
model.save_pretrained(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)