@@ -1126,7 +1126,6 @@ def squad_eval(model, keep_model_in_eval_mode=True):
11261126 from fms_mo import qconfig_init , qmodel_prep
11271127
11281128 if args .do_qat :
1129-
11301129 # create a config dict, if same item exists in both recipe and args, args has the priority.
11311130 qcfg = qconfig_init (recipe = "qat_int8" , args = args )
11321131
@@ -1141,7 +1140,6 @@ def squad_eval(model, keep_model_in_eval_mode=True):
11411140 qmodel_prep (model , exam_inp , qcfg , optimizer , use_dynamo = True )
11421141
11431142 if args .do_ptq :
1144-
11451143 # Local
11461144 from fms_mo .quant .ptq import calib_PTQ_lm
11471145
@@ -1216,17 +1214,16 @@ def speedtest(model, exam_inp, Ntest=100):
12161214 ("int8" , "ind" ),
12171215 ("int8" , "cugr" ),
12181216 ]:
1219-
12201217 logger .info (
12211218 f"\n { label } { 'with' if comp_mode else 'without' } torch.compile"
12221219 )
12231220 model_copy = deepcopy (model )
12241221
12251222 if label == "int8" :
12261223 qcfg = qconfig_init (recipe = "ptq_int8" , args = args )
1227- qcfg [
1228- "qmodel_calibration"
1229- ] = 0 # no need to run calibration or trained scales will be lost.
1224+ qcfg ["qmodel_calibration" ] = (
1225+ 0 # no need to run calibration or trained scales will be lost.
1226+ )
12301227 qmodel_prep (
12311228 model_copy ,
12321229 exam_inp ,
@@ -1479,9 +1476,9 @@ def speedtest(model, exam_inp, Ntest=100):
14791476 "step" : completed_steps ,
14801477 }
14811478 if args .do_predict :
1482- log [
1483- "squad_v2_predict" if args . version_2_with_negative else "squad_predict"
1484- ] = predict_metric
1479+ log ["squad_v2_predict" if args . version_2_with_negative else "squad_predict" ] = (
1480+ predict_metric
1481+ )
14851482
14861483 accelerator .log (log , step = completed_steps )
14871484
0 commit comments