@@ -1,11 +1,11 @@
 # Copyright The FMS Model Optimizer Authors
-#
+
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-#
+
 #     http://www.apache.org/licenses/LICENSE-2.0
-#
+
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -34,6 +34,7 @@
 )
 import torch
 
+import os
 # Local
 from fms_mo import qconfig_init, qmodel_prep
 from fms_mo.custom_ext_kernels.utils import (
@@ -50,8 +51,11 @@
 from fms_mo.utils.dq_utils import config_quantize_smooth_layers
 from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU
 from fms_mo.utils.utils import patch_torch_bmm, prepare_input
-from fms_mo.utils.dq_inf import load_fp8_vllm, save_vllm_fp8
-from accelerate import load_checkpoint_and_dispatch
+from fms_mo.utils.dq_inf import (
+    save_vllm_fp8,
+    convert_fp8_vllm_to_fms_mo,
+    check_quantization_setting,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -129,18 +133,42 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         low_cpu_mem_usage=bool(model_args.device_map),
     )
 
+    inference = model.config.to_dict().get("quantization_config", None)
+
+    if inference:
+        quant_setting = check_quantization_setting(inference)
+        if quant_setting:
+            logger.info("Quantization config settings validated")
+            model = convert_fp8_vllm_to_fms_mo(model=model)
+        else:
+            raise ValueError("This quantization config is not supported")
+
+
     embedding_size = model.get_input_embeddings().weight.shape[0]
     if len(tokenizer) > embedding_size:
         model.resize_token_embeddings(len(tokenizer))
 
     logger.info(f"Initialized model is: \n{model}")
     logger.info(f"Model is at {model.device} after initialization")
     logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")
-
-    if not fms_mo_args.inference or fms_mo_args.vllm_fp8_load:
+
+    if not inference:
+        logger.info("Quantization mode activated, initializing the qcfg file")
         qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
     else:
-        qcfg = qconfig_init(recipe=opt_args.output_dir + "/qcfg")
+        logger.info("Inference mode activated")
+        if os.path.isfile(model_args.model_name_or_path + "/qcfg.json"):
+            if fms_mo_args.override_fms_args:
+                logger.info("qcfg file found, some parameters are being overwritten")
+                qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args)
+            else:
+                logger.info("qcfg file found, loading the qcfg file")
+                qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg")
+        else:
+            logger.info(
+                f"qcfg file not found in {model_args.model_name_or_path}, loading fms_mo_args and recipe"
+            )
+            qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
 
     model_size = model_size_Wb(model, unit="GB")
     gpu_mem_util_per = model_size / total_gpu_memory
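Note on the hunk above: the new inference path keys entirely on the quantization_config block that pre-quantized checkpoints (for example, FP8 models produced by llm-compressor for vLLM) carry in their Hugging Face config.json. A minimal sketch of the detection, assuming a hypothetical checkpoint path:

    # Sketch only (not part of the commit): how the `inference` value gets set.
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("my-org/llama-fp8")  # hypothetical checkpoint
    quant_cfg = cfg.to_dict().get("quantization_config", None)
    if quant_cfg:
        # Pre-quantized checkpoints record their scheme here, e.g.
        # {"quant_method": "compressed-tensors", ...}; run_dq() validates it with
        # check_quantization_setting() before converting any modules.
        print(quant_cfg.get("quant_method"))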
@@ -184,6 +212,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     qcfg["model"] = model_args.model_name_or_path
     qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0 and "mx_specs" not in qcfg
     qcfg["plotsvg"] = False
+    qcfg["output_folder"] = opt_args.output_dir
 
     calibration_dataset = load_from_disk(data_args.training_data_path)
     calibration_dataset = calibration_dataset.with_format("torch")
@@ -196,7 +225,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     )
 
     # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
-    if not fms_mo_args.inference and qcfg["smoothq"]:
+    if not inference and qcfg["smoothq"]:
         scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
         if qcfg.get("act_scale_path", None):
             # user provided a scale file (or a dir)
@@ -230,14 +259,12 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         use_layer_name_pattern_matching=use_layer_name_pattern_matching,
         use_dynamo=use_dynamo,
         dev=dev,
-        mode=fms_mo_args.inference,
+        mode=inference,
         save_fname="dq",
-        folder=opt_args.output_dir,
     )
     logger.info(f"Quantized model {model}")
     logger.info("==" * 20)
-
-    if not fms_mo_args.inference:
+    if not inference:
         if qcfg["smoothq"]:
             logger.info("Starting to apply smooth scale")
             dq_llm(model, act_scales, qcfg)
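Two related changes in the hunk above: the output folder no longer travels as a qmodel_prep(folder=...) argument but as qcfg["output_folder"] (set in an earlier hunk), and mode now receives the quantization_config dict (or None) rather than the old boolean fms_mo_args.inference. That substitution is safe as long as mode is only tested for truthiness, which this quick sketch illustrates:

    # Sketch: the dict-or-None `inference` value is truthy exactly when the old
    # boolean flag would have been True (assuming mode is only truthiness-tested).
    quant_cfg = {"quant_method": "compressed-tensors"}  # example contents
    assert bool(quant_cfg) is True   # quantization_config present -> inference mode
    assert bool(None) is False       # absent -> quantization (calibration) mode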
@@ -264,7 +291,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
             f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}"
         )
         save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True)
-    elif opt_args.save_ckpt_for_vllm:
+    elif not opt_args.save_ckpt:
         logger.info(
             f"Saving model processed for vLLM and tokenizer to {opt_args.output_dir}"
         )
@@ -287,19 +314,6 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
             clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8",
             # layer_to_exclude=["lm_head",]
         )
-    else:
-        if fms_mo_args.vllm_fp8_load:
-            logger.info("loading llmcompressor fp8 model saved_checkpoint")
-            model = load_fp8_vllm(model=model, checkpoint=opt_args.output_dir)
-
-        else:
-            logger.info("loading dq fms_mo fp8 model saved_checkpoint")
-            model = load_checkpoint_and_dispatch(
-                model,
-                checkpoint=opt_args.output_dir,
-                device_map=None,
-                no_split_module_classes=['Block']
-            )
 
     if fms_mo_args.eval_ppl:
         path_test = Path(data_args.test_data_path)
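The deleted else-branch above was the old inference entry point: it re-loaded weights from opt_args.output_dir via load_fp8_vllm() or accelerate's load_checkpoint_and_dispatch(). Both jobs are now handled up front by convert_fp8_vllm_to_fms_mo() right after from_pretrained(), so no re-dispatch is needed at this point. A minimal sketch of loading a pre-quantized checkpoint under the new flow, with a hypothetical checkpoint directory:

    # Sketch: loading a previously quantized FP8 checkpoint with the new helpers.
    from transformers import AutoModelForCausalLM

    from fms_mo.utils.dq_inf import check_quantization_setting, convert_fp8_vllm_to_fms_mo

    model = AutoModelForCausalLM.from_pretrained("fp8-ckpt/")  # hypothetical directory
    quant_cfg = model.config.to_dict().get("quantization_config", None)
    if quant_cfg and check_quantization_setting(quant_cfg):
        model = convert_fp8_vllm_to_fms_mo(model=model)  # swap in fms_mo quantized modules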