"""
Oneshot compression entrypoint for post-training model optimization.

Provides the main entry point for applying quantization, pruning, and other
compression techniques to pre-trained models without additional training.
Supports calibration-based compression with various pipeline configurations.
"""

from __future__ import annotations

import os
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Callable

from loguru import logger
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin

from llmcompressor.args import parse_args
from llmcompressor.core.session_functions import active_session
from llmcompressor.datasets import get_calibration_dataloader
from llmcompressor.entrypoints.utils import post_process, pre_process
from llmcompressor.modeling.moe_context import moe_calibration_context
from llmcompressor.pipelines import CalibrationPipeline

__all__ = ["Oneshot", "oneshot"]

if TYPE_CHECKING:
    from datasets import Dataset, DatasetDict

TOKENIZERS_PARALLELISM_ENV = "TOKENIZERS_PARALLELISM"


class Oneshot:
    """
    Class responsible for carrying out one-shot calibration on a pretrained model.

    This class handles the entire lifecycle of one-shot calibration, including
    preprocessing (model and tokenizer/processor initialization), model optimization
    (quantization or sparsification), and postprocessing (saving outputs). The
    instructions for model optimization can be specified by using a recipe.

    - **Input Keyword Arguments:**
        `kwargs` are parsed into:
            - `model_args`: Arguments for loading and configuring a pretrained model
              (e.g., `AutoModelForCausalLM`).
            - `dataset_args`: Arguments for dataset-related configurations, such as
              calibration dataloaders.
            - `recipe_args`: Arguments for defining and configuring recipes that
              specify optimization actions.
            - `output_dir`: Path to save the output model after calibration.
        Parsers are defined in `src/llmcompressor/args/`.
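
        A minimal sketch of the routing idea, assuming each group accepts only the
        keywords that match its own field names. `ToyModelArgs`, `ToyDatasetArgs`,
        `ToyRecipeArgs`, and `split_kwargs` are hypothetical stand-ins, not the
        actual `parse_args` implementation, which also returns `output_dir` and may
        validate or convert values differently.

        ```python
        from dataclasses import dataclass, fields
        from typing import Any, Optional, Tuple


        @dataclass
        class ToyModelArgs:  # hypothetical stand-in for model arguments
            model: str = ""
            precision: str = "auto"


        @dataclass
        class ToyDatasetArgs:  # hypothetical stand-in for dataset arguments
            dataset: Optional[str] = None
            num_calibration_samples: int = 512


        @dataclass
        class ToyRecipeArgs:  # hypothetical stand-in for recipe arguments
            recipe: Optional[str] = None
            stage: Optional[str] = None


        def split_kwargs(
            **kwargs: Any,
        ) -> Tuple[ToyModelArgs, ToyDatasetArgs, ToyRecipeArgs]:
            # Hand each group only the keywords matching its dataclass fields
            remaining = dict(kwargs)
            groups = []
            for cls in (ToyModelArgs, ToyDatasetArgs, ToyRecipeArgs):
                names = {f.name for f in fields(cls)}
                picked = {k: remaining.pop(k) for k in list(remaining) if k in names}
                groups.append(cls(**picked))
            # Anything left over does not belong to any group
            if remaining:
                raise TypeError(f"unknown arguments: {sorted(remaining)}")
            return groups[0], groups[1], groups[2]


        model_args, dataset_args, recipe_args = split_kwargs(
            model="my-model", recipe="recipe.yaml", num_calibration_samples=256
        )
        print(dataset_args.num_calibration_samples)  # 256
        ```

        Rejecting leftover keywords surfaces typos early instead of silently
        ignoring them.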
    - **Lifecycle Overview:**
        The oneshot calibration lifecycle consists of three steps:
        1. **Preprocessing**:
            - Instantiates a pretrained model and tokenizer/processor.
            - Ensures input and output embedding layers are untied if they share
              tensors.
            - Patches the model to include additional functionality for saving with
              quantization configurations.
        2. **Oneshot Calibration**:
            - Optimizes the model using a global `CompressionSession` and applies
              recipe-defined modifiers (e.g., `GPTQModifier`, `SparseGPTModifier`).
        3. **Postprocessing**:
            - Saves the model, tokenizer/processor, and configuration to the specified
              `output_dir`.

    - **Usage:**
        ```python
        oneshot = Oneshot(model=model, recipe=recipe, dataset=dataset)
        oneshot()

        # Access the processed components
        model = oneshot.model
        processor = oneshot.processor
        recipe = oneshot.recipe
        ```

    Methods:
        __init__(log_dir=None, **kwargs):
            Initializes the `Oneshot` object by parsing input arguments, performing
            preprocessing, and setting instance attributes.
        __call__():
            Performs the one-shot calibration process by preparing a calibration
            dataloader, applying recipe modifiers to the model, and executing
            postprocessing steps. Postprocessing saves the calibrated model and
            tokenizer/processor to `output_dir` and supports saving in compressed
            formats based on model arguments.
        apply_recipe_modifiers(calibration_dataloader, recipe_stage=None):
            Applies lifecycle actions (e.g., `initialize`, `finalize`) using modifiers
            defined in the recipe. Each action is executed via the global
            `CompressionSession`.
    """

    def __init__(
        self,
        log_dir: str | None = None,
        **kwargs,
    ):
        """
        Initializes the `Oneshot` class with provided arguments.

        Parses the input keyword arguments into `model_args`, `dataset_args`,
        `recipe_args`, and `output_dir`, then performs preprocessing to initialize
        the model and tokenizer/processor.

        :param kwargs: Keyword arguments parsed by `parse_args` into:
            - `model_args`: ModelArguments parameters, responsible for controlling
              model loading and saving logic
            - `dataset_args`: DatasetArguments parameters, responsible for
              controlling dataset loading, preprocessing and dataloader loading
            - `recipe_args`: RecipeArguments parameters, responsible for containing
              recipe-related parameters
            - `output_dir`: Path to save the output model after carrying out oneshot
        :param log_dir: Path to save logs during the oneshot run. Ignored when the
            `LLM_COMPRESSOR_LOG_FILE` environment variable is set. If neither is
            set, nothing is logged to file.
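
        The file-logging precedence implemented below can be summarized by the
        following pure-Python sketch. `resolve_log_path` is a hypothetical helper
        written for illustration: this method calls `logger.add` directly and also
        creates the target directory, whereas the sketch only computes which path,
        if any, would be used.

        ```python
        from datetime import datetime
        from pathlib import Path
        from typing import Mapping, Optional


        def resolve_log_path(
            env: Mapping[str, str], log_dir: Optional[str], now: datetime
        ) -> Optional[Path]:
            # 1) A non-blank LLM_COMPRESSOR_LOG_FILE takes precedence
            log_file = env.get("LLM_COMPRESSOR_LOG_FILE", "").strip()
            if log_file:
                return Path(log_file).expanduser()
            # 2) Otherwise an explicit log_dir gets a timestamped file
            if log_dir:
                date_str = now.strftime("%Y-%m-%d_%H-%M-%S")
                return Path(log_dir) / f"oneshot_{date_str}.log"
            # 3) No file logging at all
            return None


        stamp = datetime(2024, 1, 2, 3, 4, 5)
        print(resolve_log_path({}, "logs", stamp))
        ```

        Passing the environment and the clock in as arguments keeps the sketch
        deterministic and easy to test.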
        """
        # Disable tokenizer parallelism to prevent warning when using
        # multiprocessing for dataset preprocessing. The warning occurs because
        # FastTokenizer's internal threading conflicts with dataset.map's num_proc.
        # See: https://github.com/vllm-project/llm-compressor/issues/2007
        if TOKENIZERS_PARALLELISM_ENV not in os.environ:
            os.environ[TOKENIZERS_PARALLELISM_ENV] = "false"
            logger.warning(
                "Disabling tokenizer parallelism due to threading conflict between "
                "FastTokenizer and Datasets. Set "
                f"{TOKENIZERS_PARALLELISM_ENV}=false to "
                "suppress this warning."
            )

        # Set up file logging (no default files):
        # 1) If LLM_COMPRESSOR_LOG_FILE is set, log to that file.
        # 2) Else, if an explicit log_dir is provided, create a timestamped file there.
        log_file = os.environ.get("LLM_COMPRESSOR_LOG_FILE", "").strip()
        if log_file:
            p = Path(log_file).expanduser()
            p.parent.mkdir(parents=True, exist_ok=True)
            logger.add(
                str(p),
                level="DEBUG",
            )
        elif log_dir:
            os.makedirs(log_dir, exist_ok=True)
            date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            logger.add(
                f"{log_dir}/oneshot_{date_str}.log",
                level="DEBUG",
            )

        model_args, dataset_args, recipe_args, output_dir = parse_args(**kwargs)

        self.model_args = model_args
        self.dataset_args = dataset_args
        self.recipe_args = recipe_args
        self.output_dir = output_dir

        # initialize the model and processor
        pre_process(model_args, dataset_args, output_dir)

        # Set instance attributes
        self.model = self.model_args.model
        self.processor = self.model_args.processor
        self.recipe = self.recipe_args.recipe

    def __call__(self):
        """
        Performs one-shot calibration.

        This method prepares a calibration dataloader using dataset arguments and
        applies recipe-based modifiers to optimize the model. The lifecycle actions
        are executed sequentially, and the modified model is saved during
        postprocessing when an `output_dir` is provided.
        """
        calibration_dataloader = get_calibration_dataloader(
            self.dataset_args, self.processor
        )
        self.apply_recipe_modifiers(
            calibration_dataloader=calibration_dataloader,
            recipe_stage=self.recipe_args.stage,
        )
        post_process(
            model_args=self.model_args,
            recipe_args=self.recipe_args,
            output_dir=self.output_dir,
        )

    def apply_recipe_modifiers(
        self,
        calibration_dataloader: DataLoader | None,
        recipe_stage: str | None = None,
    ):
        """
        Applies recipe modifiers to the model during the lifecycle.

        The modifiers are defined in the recipe and executed via lifecycle actions
        (`initialize`, `finalize`) through the global `CompressionSession`.

        :param calibration_dataloader: Dataloader for calibration data.
        :param recipe_stage: Optional recipe stage to apply; forwarded to
            `session.initialize`.

        Raises:
            RuntimeError: If any modifier fails during execution.
        """
        session = active_session()
        session.reset()

        # (Helen INFERENG-661): validate recipe modifiers before initialization
        # Apply MoE calibration context for the entire calibration process
        with moe_calibration_context(
            self.model,
            calibrate_all_experts=self.dataset_args.moe_calibrate_all_experts,
        ):
            session.initialize(
                model=self.model,
                start=-1,
                recipe=self.recipe,
                recipe_stage=recipe_stage,
                recipe_args=self.recipe_args.recipe_args,
                calib_data=calibration_dataloader,
                sequential_targets=self.dataset_args.sequential_targets,
            )

            user_pipeline = self.dataset_args.pipeline
            pipeline = CalibrationPipeline.from_modifiers(
                session.lifecycle.recipe.modifiers, user=user_pipeline
            )
            pipeline(
                self.model,
                calibration_dataloader,
                self.dataset_args,
            )

        session.finalize()


def oneshot(
    # Model arguments
    model: str | PreTrainedModel,
    config_name: str | None = None,
    tokenizer: str | PreTrainedTokenizerBase | None = None,
    processor: str | ProcessorMixin | None = None,
    use_auth_token: bool = False,
    precision: str = "auto",
    tie_word_embeddings: bool = True,
    trust_remote_code_model: bool = False,
    save_compressed: bool = True,
    model_revision: str = "main",
    # Recipe arguments
    recipe: str | list[str] | None = None,
    recipe_args: list[str] | None = None,
    clear_sparse_session: bool = False,
    stage: str | None = None,
    # Dataset arguments
    dataset: str | Dataset | DatasetDict | DataLoader | None = None,
    dataset_config_name: str | None = None,
    dataset_path: str | None = None,
    splits: str | list[str] | dict[str, str] | None = None,
    batch_size: int = 1,
    data_collator: str | Callable = "truncation",
    num_calibration_samples: int = 512,
    shuffle_calibration_samples: bool = True,
    max_seq_length: int = 384,
    pad_to_max_length: bool = True,
    text_column: str = "text",
    concatenate_data: bool = False,
    streaming: bool = False,
    overwrite_cache: bool = False,
    preprocessing_num_workers: int | None = None,
    dataloader_num_workers: int = 0,
    min_tokens_per_module: float | None = None,
    moe_calibrate_all_experts: bool = True,
    pipeline: str | None = "independent",
    tracing_ignore: list[str] = [
        "_update_causal_mask",
        "create_causal_mask",
        "_update_mamba_mask",
        "make_causal_mask",
        "get_causal_mask",
        "mask_interface",
        "mask_function",
        "_prepare_4d_causal_attention_mask",
        "_prepare_fsmt_decoder_inputs",
        "_prepare_4d_causal_attention_mask_with_cache_position",
        "_update_linear_attn_mask",
        "project_per_layer_inputs",
    ],
    sequential_targets: list[str] | None = None,
    sequential_offload_device: str = "cpu",
    sequential_weight_offload_device: str = "cpu",
    quantization_aware_calibration: bool = True,
    sequential_prefetch: bool = False,
    # Miscellaneous arguments
    output_dir: str | None = None,
    log_dir: str | None = None,
    **kwargs,
) -> PreTrainedModel:
    """
    Performs oneshot calibration on a model.

    # Model arguments
    :param model: A pretrained model identifier from huggingface.co/models or a path
        to a local model. Required parameter.
    :param distill_teacher: Teacher model (a trained text generation model)
        for distillation.
    :param config_name: Pretrained config name or path if not the same as
        model_name.
    :param tokenizer: Pretrained tokenizer name or path if not the same as
        model_name.
    :param processor: Pretrained processor name or path if not the same as
        model_name.
    :param use_auth_token: Whether to use Hugging Face auth token for private
        models.
    :param precision: Precision to cast model weights to, defaults to auto.
    :param tie_word_embeddings: Whether the model's input and output word embeddings
        should be left tied if possible. False means always untie.
    :param trust_remote_code_model: Whether to allow for custom models to execute
        their own modeling files.
    :param save_compressed: Whether to compress sparse models during save.
    :param model_revision: The specific model version to use (can be branch name,
        tag, or commit id).
    # Recipe arguments
    :param recipe: Path to an LLM Compressor recipe, or a list of paths
        to multiple LLM Compressor recipes.
    :param recipe_args: List of recipe arguments to evaluate, in the
        format "key1=value1", "key2=value2".
    :param clear_sparse_session: Whether to clear CompressionSession/
        CompressionLifecycle data between runs.
    :param stage: The stage of the recipe to use for oneshot.
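
    To illustrate the `recipe_args` format, the hypothetical helper below turns
    entries such as "key1=value1" into a dictionary. It assumes each entry is
    split on its first `=` and keeps values as strings; how LLM Compressor
    actually parses and converts these values may differ.

    ```python
    from typing import Dict, List


    def parse_recipe_args(items: List[str]) -> Dict[str, str]:
        # Hypothetical helper: split each "key=value" entry on its first "="
        parsed: Dict[str, str] = {}
        for item in items:
            key, sep, value = item.partition("=")
            if not sep or not key.strip():
                raise ValueError(f"expected 'key=value', got {item!r}")
            parsed[key.strip()] = value.strip()
        return parsed


    print(parse_recipe_args(["sparsity=0.5", "scheme=W4A16"]))
    # {'sparsity': '0.5', 'scheme': 'W4A16'}
    ```

    Splitting on the first `=` only means a value may itself contain `=`.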
    # Dataset arguments
    :param dataset: The dataset to use for calibration. Can be a dataset name
        (str, via the datasets library), a HuggingFace Dataset or DatasetDict,
        or a pre-built PyTorch DataLoader. When a DataLoader is passed, the
        internal dataset-to-dataloader conversion is skipped.
    :param dataset_config_name: The configuration name of the dataset
        to use.
    :param dataset_path: Path to a custom dataset. Supports json, csv, dvc.
    :param splits: Optional percentages of each split to download.
    :param batch_size: Calibration dataset batch size. During calibration,
        LLM Compressor disables lm_head output computations to reduce memory
        usage from large calibration batch sizes. Large batch sizes may result in
        excess padding or truncation, depending on the `data_collator`.
    :param data_collator: The function to use to form a batch from the dataset. Can
        also specify 'truncation' or 'padding' to truncate or pad non-uniform sequence
        lengths in a batch. Defaults to 'truncation'.
    :param num_calibration_samples: Number of samples to use for one-shot
        calibration.
    :param shuffle_calibration_samples: Whether to shuffle the dataset before
        calibration.
    :param max_seq_length: Maximum total input sequence length after tokenization.
    :param pad_to_max_length: Whether to pad all samples to `max_seq_length`.
    :param text_column: Key to use as the `text` input to tokenizer/processor.
    :param concatenate_data: Whether to concatenate datapoints to fill
        max_seq_length.
    :param streaming: True to stream data from a cloud dataset.
    :param overwrite_cache: Whether to overwrite the cached preprocessed datasets.
    :param preprocessing_num_workers: Number of processes for dataset preprocessing.
    :param dataloader_num_workers: Number of worker processes for data loading. Default
        is 0 (safe for low CPU/GPU memory). Set to 2 or more for faster calibration if
        you have sufficient RAM. Custom data collators may not work with
        multiprocessing.
    :param min_tokens_per_module: Minimum percentage of tokens per
        module, relevant for MoE models.
    :param moe_calibrate_all_experts: Whether to calibrate all experts during MoE
        model calibration. When True, all experts will see all tokens during
        calibration, ensuring proper quantization statistics. When False, only
        routed experts will be used. Only relevant for MoE models. Default is True.
    :param pipeline: Calibration pipeline used to calibrate the model. Options:
        ['basic', 'datafree', 'sequential', 'independent']. Default is
        'independent'.
    :param tracing_ignore: List of functions to ignore during tracing, either
        {module}.{method_name} or {function_name}.
    :param sequential_targets: List of layer targets for the sequential pipeline.
        This is typically a single DecoderLayer. Not specifying this argument will
        cause the sequential pipeline to default to using the `_no_split_modules`
        specified by the HF model definition.
    :param sequential_offload_device: Device used to offload intermediate activations
        between sequential layers. It is recommended to use `cuda:1` if using more
        than one GPU. Default is cpu.
    :param sequential_weight_offload_device: Device used to offload model weights
        in the sequential pipeline. Set to `none` to disable weight offloading and
        keep weights on the main execution device. Default is cpu.
    :param quantization_aware_calibration: Whether to enable quantization-aware
        calibration in the sequential pipeline. When True, quantization is applied
        during the forward pass in calibration. When False, quantization is disabled
        during the forward pass in calibration. Default is True.
    :param sequential_prefetch: When using the sequential pipeline, prefetch the
        next batch in a background thread to overlap onload with forward. Default
        False; set True for faster calibration when GPU memory allows.
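
    The idea behind `sequential_prefetch`, loading the next batch while the
    current one is being processed, can be sketched with a worker thread and a
    bounded queue. `prefetch` is a hypothetical helper for illustration only, not
    the sequential pipeline's implementation; in the pipeline the load step is
    onloading a batch to the execution device, whereas here `load` is any callable.

    ```python
    import queue
    import threading
    from typing import Any, Callable, Iterable, Iterator


    def prefetch(
        items: Iterable[Any], load: Callable[[Any], Any], depth: int = 1
    ) -> Iterator[Any]:
        # A bounded queue lets the worker run at most `depth` items ahead
        buf: queue.Queue = queue.Queue(maxsize=depth)

        def worker() -> None:
            # Runs in the background: load items and hand them to the consumer
            try:
                for item in items:
                    buf.put((True, load(item)))
            except Exception as exc:  # forward loading errors to the consumer
                buf.put((False, exc))
                return
            buf.put((False, None))  # end-of-stream marker

        # daemon=True so an abandoned worker cannot keep the process alive
        threading.Thread(target=worker, daemon=True).start()
        while True:
            ok, payload = buf.get()
            if not ok:
                if payload is not None:
                    raise payload
                return
            # Meanwhile the worker may already be loading the next item
            yield payload


    print(list(prefetch(range(5), lambda i: i * 10)))  # [0, 10, 20, 30, 40]
    ```

    A larger `depth` hides more load latency at the cost of holding more loaded
    items in memory at once, which mirrors the "when GPU memory allows" caveat.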
    # Miscellaneous arguments
    :param output_dir: Path to save the output model after calibration.
        Nothing is saved if None.
    :param log_dir: Path to save logs during oneshot run.
        Nothing is logged to file if None.
    :return: The calibrated PreTrainedModel
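
    The function body forwards every named parameter to `Oneshot` by snapshotting
    `locals()` before any other local variable is created, then merging the extra
    `**kwargs`. The pattern in isolation, using a hypothetical `collect` function
    with a few illustrative parameters:

    ```python
    from typing import Any, Dict, Optional


    def collect(
        model: str,
        batch_size: int = 1,
        output_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        # locals() here holds only the parameters (including "kwargs"), since no
        # other local variable has been assigned yet
        local_args = {
            k: v for k, v in locals().items() if k not in ("local_args", "kwargs")
        }
        # Extra keywords are merged on top, as in Oneshot(**local_args, **kwargs)
        return {**local_args, **kwargs}


    print(collect("my-model", batch_size=4, stage="quant"))
    # {'model': 'my-model', 'batch_size': 4, 'output_dir': None, 'stage': 'quant'}
    ```

    The snapshot must be taken first: any local assigned earlier in the body would
    also be forwarded.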
    """
    # pass all args directly into Oneshot
    local_args = {
        k: v for k, v in locals().items() if k not in ("local_args", "kwargs")
    }
    one_shot = Oneshot(**local_args, **kwargs)
    one_shot()

    return one_shot.model