forked from vllm-project/llm-compressor
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path e2e_utils.py
More file actions
107 lines (88 loc) · 3.49 KB
/
e2e_utils.py
File metadata and controls
107 lines (88 loc) · 3.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import Callable
import torch
import transformers
from datasets import load_dataset
from loguru import logger
from transformers import AutoProcessor, DefaultDataCollator
from llmcompressor import oneshot
from llmcompressor.modifiers.gptq import GPTQModifier
from llmcompressor.modifiers.quantization import QuantizationModifier
from tests.test_timer.timer_utils import log_time
from tests.testing_utils import process_dataset
def load_model(model: str, model_class: str, device_map: str | None = None):
    """Instantiate a pretrained model via the transformers class named by ``model_class``.

    :param model: model id or local path forwarded to ``from_pretrained``
    :param model_class: attribute name looked up on the ``transformers`` module
        (e.g. ``"AutoModelForCausalLM"``)
    :param device_map: optional device placement forwarded to ``from_pretrained``
    :return: the loaded model instance
    """
    # Resolve the class dynamically so callers can pick any transformers
    # model class by name.
    model_cls = getattr(transformers, model_class)
    return model_cls.from_pretrained(model, dtype="auto", device_map=device_map)
@log_time
def _run_oneshot(**oneshot_kwargs):
    # Thin wrapper around ``oneshot`` whose only purpose is to attach the
    # ``log_time`` decorator, so the e2e quantization run is timed.
    oneshot(**oneshot_kwargs)
def run_oneshot_for_e2e_testing(
    model: str,
    model_class: str,
    num_calibration_samples: int,
    max_seq_length: int,
    dataset_id: str,
    recipe: str,
    dataset_split: str,
    dataset_config: str,
    scheme: str,
    quant_type: str,
    shuffle_calibration_samples: bool = True,
    # NOTE(review): mutable/stateful default argument is shared across calls;
    # DefaultDataCollator appears stateless so this is benign, but consider
    # defaulting to None and constructing it inside the function.
    data_collator: str | Callable = DefaultDataCollator(),
):
    """Run a oneshot quantization pass over ``model`` for e2e testing.

    Loads the model and its processor, optionally builds a calibration
    dataset from ``dataset_id``, selects (or builds) a quantization recipe,
    and invokes ``oneshot``.

    :param model: model id or path to load
    :param model_class: transformers class name used to load the model
    :param num_calibration_samples: number of calibration samples to select
    :param max_seq_length: max sequence length used during calibration
    :param dataset_id: HF dataset id; falsy to skip calibration data
    :param recipe: explicit recipe; falsy to build one from ``quant_type``/``scheme``
    :param dataset_split: split passed to ``load_dataset``
    :param dataset_config: config name passed to ``load_dataset``
    :param scheme: preset quantization scheme used when no recipe is given
    :param quant_type: "GPTQ" selects GPTQModifier, anything else QuantizationModifier
    :param shuffle_calibration_samples: forwarded to ``oneshot``
    :param data_collator: default collator; overridden for known multimodal datasets
    :return: tuple of (quantized model, processor)
    """
    # Load model and processor.
    oneshot_kwargs = {}
    oneshot_kwargs["data_collator"] = data_collator
    loaded_model = load_model(model=model, model_class=model_class)
    processor = AutoProcessor.from_pretrained(model)
    if dataset_id:
        ds = load_dataset(dataset_id, name=dataset_config, split=dataset_split)
        # Fixed seed keeps calibration sample selection reproducible across runs.
        ds = ds.shuffle(seed=42).select(range(num_calibration_samples))
        ds = process_dataset(ds, processor, max_seq_length)
        oneshot_kwargs["dataset"] = ds
        oneshot_kwargs["max_seq_length"] = max_seq_length
        oneshot_kwargs["num_calibration_samples"] = num_calibration_samples

        # Define a data collator for multimodal inputs; these datasets yield
        # pre-processed feature dicts that must be tensorized per-sample.
        if "flickr30k" in dataset_id:

            def data_collator(batch):
                assert len(batch) == 1
                return {key: torch.tensor(value) for key, value in batch[0].items()}

            oneshot_kwargs["data_collator"] = data_collator
        elif "calibration" in dataset_id:

            def data_collator(batch):
                assert len(batch) == 1
                return {
                    key: (
                        torch.tensor(value)
                        if key != "pixel_values"
                        # pixel_values get an extra leading dim from processing;
                        # squeeze it and cast to bfloat16 for the model.
                        else torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
                    )
                    for key, value in batch[0].items()
                }

            oneshot_kwargs["data_collator"] = data_collator

    oneshot_kwargs["model"] = loaded_model
    oneshot_kwargs["shuffle_calibration_samples"] = shuffle_calibration_samples
    if recipe:
        oneshot_kwargs["recipe"] = recipe
    else:
        # Test assumes that if a recipe was not provided, we are using
        # a compatible preset scheme.
        if quant_type == "GPTQ":
            oneshot_kwargs["recipe"] = GPTQModifier(
                targets="Linear",
                scheme=scheme,
                actorder=None,  # added for consistency with past testing configs
                ignore=["lm_head", "re:.*mlp.gate[.].*"],
            )
        else:
            oneshot_kwargs["recipe"] = QuantizationModifier(
                targets="Linear",
                scheme=scheme,
                ignore=["lm_head", "re:.*mlp.gate[.].*"],
            )

    # Apply quantization.
    # Fix: loguru formats with str.format(); the original
    # logger.info("ONESHOT KWARGS", oneshot_kwargs) had no placeholder, so the
    # kwargs were silently dropped from the log output.
    logger.info("ONESHOT KWARGS: {}", oneshot_kwargs)
    _run_oneshot(**oneshot_kwargs)
    return oneshot_kwargs["model"], processor