Commit d9d12ea

Move generic opt params to OptArgs
Signed-off-by: Thara Palanivel <[email protected]>
1 parent 33eec02 commit d9d12ea

File tree: 3 files changed (+65, -51 lines changed)

fms_mo/dq.py

Lines changed: 6 additions & 5 deletions
@@ -51,7 +51,7 @@
 logger = logging.getLogger(__name__)
 
 
-def run_dq(model_args, data_args, fms_mo_args, output_dir):
+def run_dq(model_args, data_args, opt_args, fms_mo_args):
     """
     For direct quantization LLMs without optimization:
     Models are directly quantized into INT8 or FP8 precisions using
@@ -63,8 +63,9 @@ def run_dq(model_args, data_args, fms_mo_args, output_dir):
             the model
         data_args (fms_mo.training_args.DataArguments): Data arguments to be used when loading the
             tokenized dataset
+        opt_args (fms_mo.training_args.OptArguments): Generic optimization arguments to be used
+            during DQ
         fms_mo_args (fms_mo.training_args.FMSMOArguments): Parameters to use for DQ quantization
-        output_dir (str) Output directory to write to
     """
     # for attention or kv-cache quantization, need to use eager attention
     attn_bits = [
@@ -218,9 +219,9 @@ def run_dq(model_args, data_args, fms_mo_args, output_dir):
         with patch_torch_bmm(qcfg):
             model(**data_mb)
 
-    logger.info(f"Saving quantized model and tokenizer to {output_dir}")
-    model.save_pretrained(output_dir, use_safetensors=True)
-    tokenizer.save_pretrained(output_dir)
+    logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
+    model.save_pretrained(opt_args.output_dir, use_safetensors=True)
+    tokenizer.save_pretrained(opt_args.output_dir)
 
     if fms_mo_args.eval_ppl:
         path_test = Path(data_args.test_data_path)
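The practical effect on dq.py: run_dq now reads its save location from opt_args.output_dir instead of a standalone parameter. A minimal sketch of the new call pattern, assuming the dataclass defaults shown in this commit suffice and that DataArguments locates a prepared tokenized dataset (the output path here is illustrative):

from fms_mo.dq import run_dq
from fms_mo.training_args import (
    DataArguments,
    FMSMOArguments,
    ModelArguments,
    OptArguments,
)

model_args = ModelArguments()  # defaults to facebook/opt-125m
data_args = DataArguments()  # assumption: defaults point at your tokenized data
fms_mo_args = FMSMOArguments()
# output_dir now travels inside OptArguments rather than as a separate argument
opt_args = OptArguments(quant_method="dq", output_dir="./dq-out")

run_dq(model_args, data_args, opt_args, fms_mo_args)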

fms_mo/run_quant.py

Lines changed: 36 additions & 41 deletions
@@ -42,6 +42,7 @@
     FP8Arguments,
     GPTQArguments,
     ModelArguments,
+    OptArguments,
 )
 from fms_mo.utils.import_utils import available_packages
 
@@ -51,11 +52,10 @@
 def quantize(
     model_args: ModelArguments,
     data_args: DataArguments,
-    fms_mo_args: FMSMOArguments,
-    gptq_args: GPTQArguments,
-    fp8_args: FP8Arguments,
-    quant_method: str,
-    output_dir: str,
+    opt_args: OptArguments,
+    fms_mo_args: FMSMOArguments = None,
+    gptq_args: GPTQArguments = None,
+    fp8_args: FP8Arguments = None,
 ):
     """Main entry point to quantize a given model with a set of specified hyperparameters
@@ -71,16 +71,17 @@ def quantize(
         output_dir (str) Output directory to write to
     """
 
-    logging.info(f"{fms_mo_args}\n{quant_method}\n")
-    if quant_method == "gptq":
+    logger.info(f"{fms_mo_args}\n{opt_args.quant_method}\n")
+
+    if opt_args.quant_method == "gptq":
         if not available_packages["auto_gptq"]:
             raise ImportError(
                 "Quantization method has been selected as gptq but unable to use external library, "
                 "auto_gptq module not found. For more instructions on installing the appropriate "
                 "package, see https://github.com/AutoGPTQ/AutoGPTQ?tab=readme-ov-file#installation"
             )
-        run_gptq(model_args, data_args, gptq_args, output_dir)
-    elif quant_method == "fp8":
+        run_gptq(model_args, data_args, opt_args, gptq_args)
+    elif opt_args.quant_method == "fp8":
         if not available_packages["llmcompressor"]:
             raise ImportError(
                 "Quantization method has been selected as fp8 but unable to use external library, "
@@ -89,16 +90,18 @@
                 "https://github.com/vllm-project/llm-compressor/tree/"
                 "main?tab=readme-ov-file#installation"
             )
-        run_fp8(model_args, data_args, fp8_args, output_dir)
-    elif quant_method == "dq":
-        run_dq(model_args, data_args, fms_mo_args, output_dir)
+        run_fp8(model_args, data_args, opt_args, fp8_args)
+    elif opt_args.quant_method == "dq":
+        run_dq(model_args, data_args, opt_args, fms_mo_args)
     else:
         raise ValueError(
-            "Not a valid quantization technique option. Please choose from: gptq, fp8, dq"
+            "{} is not a valid quantization technique option. Please choose from: gptq, fp8, dq".format(
+                opt_args.quant_method
+            )
         )
 
 
-def run_gptq(model_args, data_args, gptq_args, output_dir):
+def run_gptq(model_args, data_args, opt_args, gptq_args):
     """GPTQ quantizes a given model with a set of specified hyperparameters
 
     Args:
@@ -152,14 +155,16 @@ def run_gptq(model_args, data_args, gptq_args, output_dir):
         cache_examples_on_gpu=gptq_args.cache_examples_on_gpu,
     )
 
-    logger.info(f"Time to quantize model at {output_dir}: {time.time() - start_time}")
+    logger.info(
+        f"Time to quantize model at {opt_args.output_dir}: {time.time() - start_time}"
+    )
 
-    logger.info(f"Saving quantized model and tokenizer to {output_dir}")
-    model.save_quantized(output_dir, use_safetensors=True)
-    tokenizer.save_pretrained(output_dir)
+    logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
+    model.save_quantized(opt_args.output_dir, use_safetensors=True)
+    tokenizer.save_pretrained(opt_args.output_dir)
 
 
-def run_fp8(model_args, data_args, fp8_args, output_dir):
+def run_fp8(model_args, data_args, opt_args, fp8_args):
    """FP8 quantizes a given model with a set of specified hyperparameters
 
     Args:
@@ -192,11 +197,13 @@ def run_fp8(model_args, data_args, fp8_args, output_dir):
         max_seq_length=data_args.max_seq_length,
         num_calibration_samples=data_args.num_calibration_samples,
     )
-    logger.info(f"Time to quantize model at {output_dir}: {time.time() - start_time}")
+    logger.info(
+        f"Time to quantize model at {opt_args.output_dir}: {time.time() - start_time}"
+    )
 
-    logger.info(f"Saving quantized model and tokenizer to {output_dir}")
-    model.save_pretrained(output_dir)
-    tokenizer.save_pretrained(output_dir)
+    logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
+    model.save_pretrained(opt_args.output_dir)
+    tokenizer.save_pretrained(opt_args.output_dir)
 
 
 def main():
@@ -206,53 +213,41 @@ def main():
         dataclass_types=(
             ModelArguments,
             DataArguments,
+            OptArguments,
             FMSMOArguments,
             GPTQArguments,
             FP8Arguments,
         )
     )
 
-    parser.add_argument(
-        "--quant_method",
-        type=str.lower,
-        choices=["gptq", "fp8", None, "none", "dq"],
-        default="none",
-    )
-
-    parser.add_argument("--output_dir", type=str)
-
     (
         model_args,
         data_args,
+        opt_args,
         fms_mo_args,
         gptq_args,
         fp8_args,
-        additional,
         _,
     ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
-    quant_method = additional.quant_method
-    output_dir = additional.output_dir
 
     logger.debug(
-        "Input args parsed: \nmodel_args %s, data_args %s, fms_mo_args %s, "
-        "gptq_args %s, fp8_args %s, quant_method %s, output_dir %s",
+        "Input args parsed: \nmodel_args %s, data_args %s, opt_args %s, fms_mo_args %s, "
+        "gptq_args %s, fp8_args %s",
         model_args,
         data_args,
+        opt_args,
         fms_mo_args,
         gptq_args,
         fp8_args,
-        quant_method,
-        output_dir,
     )
 
     quantize(
         model_args=model_args,
         data_args=data_args,
+        opt_args=opt_args,
         fms_mo_args=fms_mo_args,
         gptq_args=gptq_args,
         fp8_args=fp8_args,
-        quant_method=quant_method,
-        output_dir=output_dir,
     )
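With OptArguments added to the HfArgumentParser dataclass tuple, --quant_method and --output_dir are generated and validated from dataclass fields instead of the hand-registered parser.add_argument calls removed above; as a side effect, quant_method becomes required and is restricted to gptq, fp8, or dq. A rough sketch of the resulting parsing behavior, assuming transformers' HfArgumentParser (the class implied by parse_args_into_dataclasses) and an illustrative argv:

from transformers import HfArgumentParser
from fms_mo.training_args import (
    DataArguments,
    FMSMOArguments,
    FP8Arguments,
    GPTQArguments,
    ModelArguments,
    OptArguments,
)

parser = HfArgumentParser(
    (ModelArguments, DataArguments, OptArguments, FMSMOArguments, GPTQArguments, FP8Arguments)
)

# Both flags now map onto required OptArguments fields, so omitting either one,
# or passing a quant_method outside {gptq, fp8, dq}, fails at parse time.
argv = ["--quant_method", "gptq", "--output_dir", "/tmp/quantized-model"]
(
    model_args,
    data_args,
    opt_args,
    fms_mo_args,
    gptq_args,
    fp8_args,
    extra,
) = parser.parse_args_into_dataclasses(args=argv, return_remaining_strings=True)

assert opt_args.quant_method == "gptq"
assert opt_args.output_dir == "/tmp/quantized-model"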

fms_mo/training_args.py

Lines changed: 23 additions & 5 deletions
@@ -18,18 +18,18 @@
 
 # Standard
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import List, Optional, Union
+
+# Third Party
+import torch
 
 
 @dataclass
 class ModelArguments:
     """Dataclass for model related arguments."""
 
     model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
-    torch_dtype: Optional[str] = field(
-        default=None,
-        metadata={"help": ["bfloat16", "float16", "float", "auto"]},
-    )
+    torch_dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16
     use_fast_tokenizer: bool = field(
         default=True,
         metadata={
@@ -79,6 +79,24 @@ class DataArguments:
     num_calibration_samples: Optional[int] = field(default=512)
 
 
+@dataclass
+class OptArguments:
+    """Dataclass for optimization related arguments."""
+
+    quant_method: str = field(
+        metadata={"choices": ["gptq", "fp8", "dq"], "help": "Quantization technique"}
+    )
+    output_dir: str = field(
+        metadata={
+            "help": "Output directory to write quantized model artifacts and log files to"
+        }
+    )
+    log_level: str = field(
+        default="INFO",
+        metadata={"help": "The log level to adopt during optimization."},
+    )
+
+
 @dataclass
 class FMSMOArguments:
     """Dataclass arguments used by fms_mo native quantization functions."""
