Commit c56a37b

Merge pull request #43 from tharapalanivel/opt_args
OptArguments
2 parents cea5bc7 + 725f0e0

6 files changed: +82 additions, -69 deletions
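In short, this PR renames `GPTQArgs` to `GPTQArguments` and `FP8Args` to `FP8Arguments`, and folds the ad-hoc `--quant_method` / `--output_dir` argparse flags into a new `OptArguments` dataclass. A minimal sketch of the resulting programmatic call, assuming `fms_mo` is installed; every value below is an illustrative placeholder, not taken from this diff:

```python
# Sketch only: assumes the dataclass defaults (dataset paths, etc.) are
# sufficient for your environment; adjust as needed.
from fms_mo.run_quant import quantize
from fms_mo.training_args import (
    DataArguments,
    GPTQArguments,
    ModelArguments,
    OptArguments,
)

# quant_method and output_dir now travel together in OptArguments
# instead of being passed to quantize() as loose strings.
opt_args = OptArguments(quant_method="gptq", output_dir="quantized_model")

quantize(
    model_args=ModelArguments(),  # defaults to facebook/opt-125m
    data_args=DataArguments(),
    opt_args=opt_args,
    gptq_args=GPTQArguments(),
)
```

Since `fms_mo_args`, `gptq_args`, and `fp8_args` now default to `None`, callers only need to supply the config object relevant to the chosen `quant_method`.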

.spellcheck-en-custom.txt

Lines changed: 2 additions & 1 deletion
````diff
@@ -26,10 +26,11 @@ eval
 fms
 fp
 FP
+FP8Arguments
 frac
 gptq
 GPTQ
-GPTQArgs
+GPTQArguments
 graphviz
 GPTQ
 hyperparameters
````

examples/FP8_QUANT/README.md

Lines changed: 2 additions & 2 deletions
````diff
@@ -27,7 +27,7 @@ This is an example of mature FP8, which under the hood leverages some functional
 ## QuickStart
 This end-to-end example utilizes the common set of interfaces provided by `fms_mo` for easily applying multiple quantization algorithms with FP8 being the focus of this example. The steps involved are:
 
-1. **FP8 quantization through CLI**. Other arguments could be found here [FP8Args](../../fms_mo/training_args.py#L84).
+1. **FP8 quantization through CLI**. Other arguments could be found here [FP8Arguments](../../fms_mo/training_args.py#L84).
 
 ```bash
 python -m fms_mo.run_quant \
@@ -100,7 +100,7 @@ This end-to-end example utilizes the common set of interfaces provided by `fms_m
 tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
 ```
 
-2. Quantization setting is provided using `QuantizationModifier`, additional settings can be found in [FP8Args](../../fms_mo/training_args.py#L84).
+2. Quantization setting is provided using `QuantizationModifier`, additional settings can be found in [FP8Arguments](../../fms_mo/training_args.py#L84).
 
 ```python
 recipe = QuantizationModifier(
````
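For context, the `QuantizationModifier` recipe that this README step refers to comes from llm-compressor. A hedged sketch of what such a recipe typically looks like; the `scheme` and `ignore` values are illustrative assumptions, not taken from this diff:

```python
# Sketch only: assumes llm-compressor is installed.
from llmcompressor.modifiers.quantization import QuantizationModifier

recipe = QuantizationModifier(
    targets="Linear",      # matches the FP8Arguments default in training_args.py
    scheme="FP8_DYNAMIC",  # assumed FP8 scheme (common in llm-compressor examples)
    ignore=["lm_head"],    # layers often excluded from quantization (assumption)
)
```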

examples/GPTQ/README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -32,7 +32,7 @@ This end-to-end example utilizes the common set of interfaces provided by `fms_m
 > - Tokenized data will be saved in `<path_to_save>_train` and `<path_to_save>_test`
 > - If you have trouble downloading Llama family of models from Hugging Face ([LLama models require access](https://www.llama.com/docs/getting-the-models/hugging-face/)), you can use `ibm-granite/granite-8b-code` instead
 
-2. **Quantize the model** using the data generated above, the following command will kick off the quantization job (by invoking `auto_gptq` under the hood.) Additional acceptable arguments can be found here in [GPTQArgs](../../fms_mo/training_args.py#L127).
+2. **Quantize the model** using the data generated above, the following command will kick off the quantization job (by invoking `auto_gptq` under the hood.) Additional acceptable arguments can be found here in [GPTQArguments](../../fms_mo/training_args.py#L127).
 
 ```bash
 python -m fms_mo.run_quant \
````

fms_mo/dq.py

Lines changed: 6 additions & 5 deletions
````diff
@@ -51,7 +51,7 @@
 logger = logging.getLogger(__name__)
 
 
-def run_dq(model_args, data_args, fms_mo_args, output_dir):
+def run_dq(model_args, data_args, opt_args, fms_mo_args):
     """
     For direct quantization LLMs without optimization:
     Models are directly quantized into INT8 or FP8 precisions using
@@ -63,8 +63,9 @@ def run_dq(model_args, data_args, fms_mo_args, output_dir):
             the model
         data_args (fms_mo.training_args.DataArguments): Data arguments to be used when loading the
             tokenized dataset
+        opt_args (fms_mo.training_args.OptArguments): Generic optimization arguments to be used
+            during DQ
         fms_mo_args (fms_mo.training_args.FMSMOArguments): Parameters to use for DQ quantization
-        output_dir (str) Output directory to write to
     """
     # for attention or kv-cache quantization, need to use eager attention
     attn_bits = [
@@ -218,9 +219,9 @@ def run_dq(model_args, data_args, fms_mo_args, output_dir):
     with patch_torch_bmm(qcfg):
         model(**data_mb)
 
-    logger.info(f"Saving quantized model and tokenizer to {output_dir}")
-    model.save_pretrained(output_dir, use_safetensors=True)
-    tokenizer.save_pretrained(output_dir)
+    logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
+    model.save_pretrained(opt_args.output_dir, use_safetensors=True)
+    tokenizer.save_pretrained(opt_args.output_dir)
 
     if fms_mo_args.eval_ppl:
         path_test = Path(data_args.test_data_path)
````
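The practical effect of the `run_dq` change: `opt_args` now precedes `fms_mo_args`, and the output directory is read from `opt_args.output_dir` rather than arriving as a bare string. A hedged sketch of the new call pattern, with placeholder values and the assumption that the dataclass defaults are workable:

```python
# Sketch only: all values are placeholders; DataArguments/FMSMOArguments
# defaults are assumed to be sufficient here.
from fms_mo.dq import run_dq
from fms_mo.training_args import (
    DataArguments,
    FMSMOArguments,
    ModelArguments,
    OptArguments,
)

opt_args = OptArguments(quant_method="dq", output_dir="dq_out")
run_dq(
    model_args=ModelArguments(),   # defaults to facebook/opt-125m
    data_args=DataArguments(),
    opt_args=opt_args,             # output directory now lives on opt_args
    fms_mo_args=FMSMOArguments(),
)
```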

fms_mo/run_quant.py

Lines changed: 46 additions & 53 deletions
````diff
@@ -39,9 +39,10 @@
 from fms_mo.training_args import (
     DataArguments,
     FMSMOArguments,
-    FP8Args,
-    GPTQArgs,
+    FP8Arguments,
+    GPTQArguments,
     ModelArguments,
+    OptArguments,
 )
 from fms_mo.utils.import_utils import available_packages
 
@@ -51,11 +52,10 @@
 def quantize(
     model_args: ModelArguments,
     data_args: DataArguments,
-    fms_mo_args: FMSMOArguments,
-    gptq_args: GPTQArgs,
-    fp8_args: FP8Args,
-    quant_method: str,
-    output_dir: str,
+    opt_args: OptArguments,
+    fms_mo_args: FMSMOArguments = None,
+    gptq_args: GPTQArguments = None,
+    fp8_args: FP8Arguments = None,
 ):
     """Main entry point to quantize a given model with a set of specified hyperparameters
 
@@ -64,23 +64,23 @@ def quantize(
             the model
         data_args (fms_mo.training_args.DataArguments): Data arguments to be used when loading the
             tokenized dataset
+        opt_args (fms_mo.training_args.OptArguments): Generic optimization related arguments
         fms_mo_args (fms_mo.training_args.FMSMOArguments): Parameters to use for PTQ quantization
-        gptq_args (fms_mo.training_args.GPTQArgs): Parameters to use for GPTQ quantization
-        fp8_args (fms_mo.training_args.FP8Args): Parameters to use for FP8 quantization
-        quant_method (str): Quantization technique, options are gptq and fp8
-        output_dir (str) Output directory to write to
+        gptq_args (fms_mo.training_args.GPTQArguments): Parameters to use for GPTQ quantization
+        fp8_args (fms_mo.training_args.FP8Arguments): Parameters to use for FP8 quantization
     """
 
-    logging.info(f"{fms_mo_args}\n{quant_method}\n")
-    if quant_method == "gptq":
+    logger.info(f"{fms_mo_args}\n{opt_args.quant_method}\n")
+
+    if opt_args.quant_method == "gptq":
         if not available_packages["auto_gptq"]:
             raise ImportError(
                 "Quantization method has been selected as gptq but unable to use external library, "
                 "auto_gptq module not found. For more instructions on installing the appropriate "
                 "package, see https://github.com/AutoGPTQ/AutoGPTQ?tab=readme-ov-file#installation"
             )
-        run_gptq(model_args, data_args, gptq_args, output_dir)
-    elif quant_method == "fp8":
+        run_gptq(model_args, data_args, opt_args, gptq_args)
+    elif opt_args.quant_method == "fp8":
         if not available_packages["llmcompressor"]:
             raise ImportError(
                 "Quantization method has been selected as fp8 but unable to use external library, "
@@ -89,25 +89,26 @@ def quantize(
                 "https://github.com/vllm-project/llm-compressor/tree/"
                 "main?tab=readme-ov-file#installation"
             )
-        run_fp8(model_args, data_args, fp8_args, output_dir)
-    elif quant_method == "dq":
-        run_dq(model_args, data_args, fms_mo_args, output_dir)
+        run_fp8(model_args, data_args, opt_args, fp8_args)
+    elif opt_args.quant_method == "dq":
+        run_dq(model_args, data_args, opt_args, fms_mo_args)
     else:
         raise ValueError(
-            "Not a valid quantization technique option. Please choose from: gptq, fp8, dq"
+            f"{opt_args.quant_method} is not a valid quantization technique option. \
+            Please choose from: gptq, fp8, dq"
         )
 
 
-def run_gptq(model_args, data_args, gptq_args, output_dir):
+def run_gptq(model_args, data_args, opt_args, gptq_args):
     """GPTQ quantizes a given model with a set of specified hyperparameters
 
     Args:
         model_args (fms_mo.training_args.ModelArguments): Model arguments to be used when loading
             the model
         data_args (fms_mo.training_args.DataArguments): Data arguments to be used when loading the
             tokenized dataset
-        gptq_args (fms_mo.training_args.GPTQArgs): Parameters to use for GPTQ quantization
-        output_dir (str) Output directory to write to
+        opt_args (fms_mo.training_args.OptArguments): Generic optimization related arguments
+        gptq_args (fms_mo.training_args.GPTQArguments): Parameters to use for GPTQ quantization
     """
 
     # Third Party
@@ -152,23 +153,25 @@ def run_gptq(model_args, data_args, gptq_args, output_dir):
         cache_examples_on_gpu=gptq_args.cache_examples_on_gpu,
     )
 
-    logger.info(f"Time to quantize model at {output_dir}: {time.time() - start_time}")
+    logger.info(
+        f"Time to quantize model at {opt_args.output_dir}: {time.time() - start_time}"
+    )
 
-    logger.info(f"Saving quantized model and tokenizer to {output_dir}")
-    model.save_quantized(output_dir, use_safetensors=True)
-    tokenizer.save_pretrained(output_dir)
+    logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
+    model.save_quantized(opt_args.output_dir, use_safetensors=True)
+    tokenizer.save_pretrained(opt_args.output_dir)
 
 
-def run_fp8(model_args, data_args, fp8_args, output_dir):
+def run_fp8(model_args, data_args, opt_args, fp8_args):
     """FP8 quantizes a given model with a set of specified hyperparameters
 
     Args:
         model_args (fms_mo.training_args.ModelArguments): Model arguments to be used when loading
             the model
         data_args (fms_mo.training_args.DataArguments): Data arguments to be used when loading the
             tokenized dataset
-        fp8_args (fms_mo.training_args.FP8Args): Parameters to use for FP8 quantization
-        output_dir (str) Output directory to write to
+        opt_args (fms_mo.training_args.OptArguments): Generic optimization related arguments
+        fp8_args (fms_mo.training_args.FP8Arguments): Parameters to use for FP8 quantization
     """
 
     # Third Party
@@ -192,11 +195,13 @@ def run_fp8(model_args, data_args, fp8_args, output_dir):
         max_seq_length=data_args.max_seq_length,
         num_calibration_samples=data_args.num_calibration_samples,
     )
-    logger.info(f"Time to quantize model at {output_dir}: {time.time() - start_time}")
+    logger.info(
+        f"Time to quantize model at {opt_args.output_dir}: {time.time() - start_time}"
+    )
 
-    logger.info(f"Saving quantized model and tokenizer to {output_dir}")
-    model.save_pretrained(output_dir)
-    tokenizer.save_pretrained(output_dir)
+    logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
+    model.save_pretrained(opt_args.output_dir)
+    tokenizer.save_pretrained(opt_args.output_dir)
 
 
 def main():
@@ -206,53 +211,41 @@ def main():
         dataclass_types=(
             ModelArguments,
             DataArguments,
+            OptArguments,
             FMSMOArguments,
-            GPTQArgs,
-            FP8Args,
+            GPTQArguments,
+            FP8Arguments,
         )
     )
 
-    parser.add_argument(
-        "--quant_method",
-        type=str.lower,
-        choices=["gptq", "fp8", None, "none", "dq"],
-        default="none",
-    )
-
-    parser.add_argument("--output_dir", type=str)
-
     (
         model_args,
         data_args,
+        opt_args,
         fms_mo_args,
         gptq_args,
         fp8_args,
-        additional,
         _,
     ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
-    quant_method = additional.quant_method
-    output_dir = additional.output_dir
 
     logger.debug(
-        "Input args parsed: \nmodel_args %s, data_args %s, fms_mo_args %s, "
-        "gptq_args %s, fp8_args %s, quant_method %s, output_dir %s",
+        "Input args parsed: \nmodel_args %s, data_args %s, opt_args %s, fms_mo_args %s, "
+        "gptq_args %s, fp8_args %s",
         model_args,
         data_args,
+        opt_args,
        fms_mo_args,
        gptq_args,
        fp8_args,
-        quant_method,
-        output_dir,
    )
 
    quantize(
        model_args=model_args,
        data_args=data_args,
+        opt_args=opt_args,
        fms_mo_args=fms_mo_args,
        gptq_args=gptq_args,
        fp8_args=fp8_args,
-        quant_method=quant_method,
-        output_dir=output_dir,
    )
````
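Dropping the two manual `parser.add_argument` calls works because `HfArgumentParser` generates CLI flags directly from dataclass fields, including the `choices` metadata on `quant_method`. A minimal standalone sketch of that mechanism, assuming `transformers` is installed; the argv list is illustrative:

```python
# Demonstrates dataclass-driven CLI parsing with HfArgumentParser.
from transformers import HfArgumentParser

from fms_mo.training_args import OptArguments

parser = HfArgumentParser(OptArguments)
(opt_args,) = parser.parse_args_into_dataclasses(
    args=["--quant_method", "gptq", "--output_dir", "out"]
)
print(opt_args.quant_method, opt_args.output_dir)  # -> gptq out
```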

fms_mo/training_args.py

Lines changed: 25 additions & 7 deletions
````diff
@@ -18,18 +18,18 @@
 
 # Standard
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import List, Optional, Union
+
+# Third Party
+import torch
 
 
 @dataclass
 class ModelArguments:
     """Dataclass for model related arguments."""
 
     model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
-    torch_dtype: Optional[str] = field(
-        default=None,
-        metadata={"help": ["bfloat16", "float16", "float", "auto"]},
-    )
+    torch_dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16
     use_fast_tokenizer: bool = field(
         default=True,
         metadata={
@@ -79,6 +79,24 @@ class DataArguments:
     num_calibration_samples: Optional[int] = field(default=512)
 
 
+@dataclass
+class OptArguments:
+    """Dataclass for optimization related arguments."""
+
+    quant_method: str = field(
+        metadata={"choices": ["gptq", "fp8", "dq"], "help": "Quantization technique"}
+    )
+    output_dir: str = field(
+        metadata={
+            "help": "Output directory to write quantized model artifacts and log files to"
+        }
+    )
+    log_level: str = field(
+        default="INFO",
+        metadata={"help": "The log level to adopt during optimization."},
+    )
+
+
 @dataclass
 class FMSMOArguments:
     """Dataclass arguments used by fms_mo native quantization functions."""
@@ -115,7 +133,7 @@ class FMSMOArguments:
 
 
 @dataclass
-class GPTQArgs:
+class GPTQArguments:
     """Dataclass for GPTQ related arguments that will be used by auto-gptq."""
 
     bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
@@ -133,7 +151,7 @@ class GPTQArgs:
 
 
 @dataclass
-class FP8Args:
+class FP8Arguments:
     """Dataclass for FP8 related arguments that will be used by llm-compressor."""
 
     targets: str = field(default="Linear")
````
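Because `quant_method` and `output_dir` are declared without defaults, both become required arguments, while `log_level` falls back to "INFO". A quick sketch:

```python
# quant_method and output_dir are required; log_level defaults to "INFO".
from fms_mo.training_args import OptArguments

opt_args = OptArguments(quant_method="fp8", output_dir="fp8_out")
assert opt_args.log_level == "INFO"
```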
