
Commit 6328fbd

Move aiu arguments to other dataclasses

Signed-off-by: Andrea Fasoli <[email protected]>

1 parent: 842aead

File tree: 4 files changed, +16 -31 lines changed

fms_mo/dq.py

Lines changed: 10 additions & 9 deletions

@@ -51,7 +51,7 @@
 logger = logging.getLogger(__name__)
 
 
-def run_dq(model_args, data_args, opt_args, fms_mo_args, aiu_args = None):
+def run_dq(model_args, data_args, opt_args, fms_mo_args):
     """
     For direct quantization LLMs without optimization:
     Models are directly quantized into INT8 or FP8 precisions using
@@ -66,8 +66,6 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args, aiu_args = None):
         opt_args (fms_mo.training_args.OptArguments): Generic optimization arguments to be used
             during DQ
         fms_mo_args (fms_mo.training_args.FMSMOArguments): Parameters to use for DQ quantization
-        aiu_args (fms_mo.training_args.AIUArguments): Parameters specific to AIU-compliant
-            checkpoint generation and saving
 
     NOTE:
         use dynamo tracing instead of torchscript by default. if torchscript is needed, change
@@ -175,7 +173,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args, aiu_args = None):
 
     qcfg["seq_len"] = block_size
     qcfg["model"] = model_args.model_name_or_path
-    qcfg["smoothq"] = True
+    qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0
     qcfg["plotsvg"] = False
 
     calibration_dataset = load_from_disk(data_args.training_data_path)
@@ -224,10 +222,13 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args, aiu_args = None):
         save_fname="dq",
     )
     logger.info(f"Quantized model {model}")
-    logger.info("Starting to apply smooth scale")
-    dq_llm(model, act_scales, qcfg)
-    logger.info("Finished applying smooth scale")
     logger.info("==" * 20)
+
+    if qcfg["smoothq"]:
+        logger.info("Starting to apply smooth scale")
+        dq_llm(model, act_scales, qcfg)
+        logger.info("Finished applying smooth scale")
+
     if qcfg["qmodel_calibration_new"] > 0:
         logger.info("Starting to calibrate activation clip_val")
         if qcfg["large_model"]:
@@ -244,9 +245,9 @@
         with patch_torch_bmm(qcfg):
             model(**data_mb)
 
-    if aiu_args is not None and aiu_args.save_ckpt_for_aiu:
+    if opt_args.save_ckpt_for_aiu:
         logger.info(
-            f"Saving model processed for AIU and tokenizer to {aiu_args.output_dir}"
+            f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}"
         )
         save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True)
     elif opt_args.save_ckpt:
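
The `qcfg["smoothq"]` change above also explains the new conditional block: SmoothQuant is no longer forced on, it is enabled only when `smoothq_alpha` is non-negative, and the smooth-scale pass is skipped otherwise. A minimal standalone sketch of that gating, with a hypothetical one-key `qcfg` standing in for the much larger config that `run_dq` builds:

# Hypothetical, minimal reproduction of the new SmoothQuant gating;
# the real qcfg is assembled by run_dq and carries many more keys.
qcfg = {"smoothq_alpha": 0.5}  # illustrative; a negative or absent alpha disables SmoothQuant

# smoothq is now derived from smoothq_alpha instead of being hard-coded to True:
qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0

if qcfg["smoothq"]:
    # In run_dq, this is where dq_llm(model, act_scales, qcfg) applies the smooth scales.
    print("SmoothQuant enabled, alpha =", qcfg["smoothq_alpha"])
else:
    print("SmoothQuant disabled; smooth-scale application skipped")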

fms_mo/run_quant.py

Lines changed: 2 additions & 10 deletions

@@ -43,7 +43,6 @@
 # Local
 from fms_mo.dq import run_dq
 from fms_mo.training_args import (
-    AIUArguments,
     DataArguments,
     FMSMOArguments,
     FP8Arguments,
@@ -68,7 +67,6 @@ def quantize(
     fms_mo_args: FMSMOArguments = None,
     gptq_args: GPTQArguments = None,
     fp8_args: FP8Arguments = None,
-    aiu_args: AIUArguments = None,
 ):
     """Main entry point to quantize a given model with a set of specified hyperparameters
 
@@ -107,7 +105,7 @@
         )
         run_fp8(model_args, data_args, opt_args, fp8_args)
     elif opt_args.quant_method == "dq":
-        run_dq(model_args, data_args, opt_args, fms_mo_args, aiu_args)
+        run_dq(model_args, data_args, opt_args, fms_mo_args)
     else:
         raise ValueError(
             f"{opt_args.quant_method} is not a valid quantization technique option. \
@@ -236,7 +234,6 @@ def get_parser():
             FMSMOArguments,
             GPTQArguments,
             FP8Arguments,
-            AIUArguments,
         )
     )
     return parser
@@ -273,7 +270,6 @@ def parse_arguments(parser, json_config=None):
             fms_mo_args,
             gptq_args,
             fp8_args,
-            aiu_args,
         ) = parser.parse_dict(json_config, allow_extra_keys=True)
     else:
         (
@@ -283,7 +279,6 @@
             fms_mo_args,
             gptq_args,
             fp8_args,
-            aiu_args,
             _,
         ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
 
@@ -298,7 +293,6 @@
         fms_mo_args,
         gptq_args,
         fp8_args,
-        aiu_args,
     )
 
 
@@ -317,15 +311,14 @@ def main():
             fms_mo_args,
             gptq_args,
             fp8_args,
-            aiu_args,
         ) = parse_arguments(parser, job_config)
 
         logger = set_log_level(opt_args.log_level, __name__)
 
         logger.debug(
             f"Input args parsed: \nmodel_args {model_args}, data_args {data_args}, "
             f"opt_args {opt_args}, fms_mo_args {fms_mo_args}, gptq_args {gptq_args}, "
-            f"fp8_args {fp8_args}, aiu_args {aiu_args}"
+            f"fp8_args {fp8_args}"
         )
     except Exception as e: # pylint: disable=broad-except
         logger.error(traceback.format_exc())
@@ -345,7 +338,6 @@
             fms_mo_args=fms_mo_args,
             gptq_args=gptq_args,
             fp8_args=fp8_args,
-            aiu_args=aiu_args,
         )
     except (MemoryError, OutOfMemoryError) as e:
         logger.error(traceback.format_exc())
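
Every tuple unpacked in `parse_arguments` and `main` loses one slot because `HfArgumentParser` returns exactly one dataclass instance per dataclass registered in `get_parser`; dropping `AIUArguments` from that registration shortens every unpacking by one (which is also why the tests below lose one `_` placeholder). A self-contained sketch of the mechanism, using toy stand-ins rather than the real fms_mo dataclasses:

# Toy demonstration: one registered dataclass = one element in the parsed tuple.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class OptArgs:  # stand-in for fms_mo.training_args.OptArguments
    save_ckpt_for_aiu: bool = field(default=False)


@dataclass
class FMSMOArgs:  # stand-in for fms_mo.training_args.FMSMOArguments
    recompute_narrow_weights: bool = field(default=False)


parser = HfArgumentParser((OptArgs, FMSMOArgs))
opt_args, fms_mo_args = parser.parse_dict(
    {"save_ckpt_for_aiu": True, "recompute_narrow_weights": True},
    allow_extra_keys=True,  # mirrors the json_config path in parse_arguments
)
print(opt_args.save_ckpt_for_aiu, fms_mo_args.recompute_narrow_weights)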

fms_mo/training_args.py

Lines changed: 4 additions & 10 deletions

@@ -142,16 +142,6 @@ class OptArguments(TypeChecker):
         default=True,
         metadata={"help": "Save quantized checkpoint."},
     )
-
-
-@dataclass
-class AIUArguments(TypeChecker):
-    """Dataclass for AIU-related arguments. Only apply to Direct Quantization runs."""
-
-    recompute_narrow_weights: bool = field(
-        default=False,
-        metadata={"help": "Apply recomputation during checkpoint saving."},
-    )
     save_ckpt_for_aiu: bool = field(
         default=False,
         metadata={"help": "Prepare and save AIU-compliant checkpoint."},
@@ -191,6 +181,10 @@ class FMSMOArguments(TypeChecker):
         default=2048, metadata={"help": "input sequence length after tokenization"}
     )
     eval_ppl: bool = field(default=False)
+    recompute_narrow_weights: bool = field(
+        default=False,
+        metadata={"help": "Apply recomputation during checkpoint saving for AIU."},
+    )
 
 
 @dataclass
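
The net effect on the argument schema: `save_ckpt_for_aiu` now lives on `OptArguments` and `recompute_narrow_weights` on `FMSMOArguments`, so no dedicated AIU dataclass remains. A condensed sketch of the resulting layout (simplified stand-ins; the real classes subclass `TypeChecker` and define many more fields):

from dataclasses import dataclass, field


@dataclass
class OptArguments:  # simplified; the real class subclasses TypeChecker
    save_ckpt: bool = field(
        default=True,
        metadata={"help": "Save quantized checkpoint."},
    )
    # Moved here from the deleted AIUArguments dataclass:
    save_ckpt_for_aiu: bool = field(
        default=False,
        metadata={"help": "Prepare and save AIU-compliant checkpoint."},
    )


@dataclass
class FMSMOArguments:  # simplified; the real class subclasses TypeChecker
    eval_ppl: bool = field(default=False)
    # Moved here from AIUArguments; consumed when saving AIU checkpoints:
    recompute_narrow_weights: bool = field(
        default=False,
        metadata={"help": "Apply recomputation during checkpoint saving for AIU."},
    )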

tests/test_run_quant.py

Lines changed: 0 additions & 2 deletions

@@ -110,7 +110,6 @@ def test_parse_arguments(job_config):
         _,
         _,
         _,
-        _,
     ) = parse_arguments(parser, job_config_copy)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert data_args.training_data_path == "data_train"
@@ -133,7 +132,6 @@ def test_parse_arguments_defaults(job_config):
         fms_mo_args,
         _,
         _,
-        _,
     ) = parse_arguments(parser, job_config_defaults)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert model_args.model_revision == "main"
