1919# You can also adapt this script to your own question answering task. Pointers for this are left as comments.
2020
2121# Standard
22- from pathlib import Path
2322import argparse
2423import json
2524import logging
2625import math
2726import os
2827import random
2928import time
29+ from pathlib import Path
3030
3131# Third Party
32+ import datasets
33+ import evaluate
34+ import numpy as np
35+ import torch
36+ import transformers
3237from accelerate import Accelerator
3338from accelerate .logging import get_logger
3439from accelerate .utils import set_seed
3540from datasets import load_dataset
3641from huggingface_hub import HfApi
3742from torch .utils .data import DataLoader
3843from tqdm .auto import tqdm
39- from transformers import (
40- CONFIG_MAPPING ,
41- MODEL_MAPPING ,
42- AutoConfig ,
43- AutoModelForQuestionAnswering ,
44- AutoTokenizer ,
45- DataCollatorWithPadding ,
46- EvalPrediction ,
47- SchedulerType ,
48- default_data_collator ,
49- get_scheduler ,
50- )
44+ from transformers import (CONFIG_MAPPING , MODEL_MAPPING , AutoConfig ,
45+ AutoModelForQuestionAnswering , AutoTokenizer ,
46+ DataCollatorWithPadding , EvalPrediction ,
47+ SchedulerType , default_data_collator , get_scheduler )
5148from transformers .utils import check_min_version , send_example_telemetry
5249from transformers .utils .versions import require_version
5350from utils_qa import postprocess_qa_predictions
54- import datasets
55- import evaluate
56- import numpy as np
57- import torch
58- import transformers
5951
6052# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
6153check_min_version ("4.39.0.dev0" )
@@ -1122,11 +1114,10 @@ def squad_eval(model, keep_model_in_eval_mode=True):
11221114 return eval_metric
11231115
11241116 # ---- [fms_mo] the following code is added for qat/ptq ----
1125- # Local
1117+ # First Party
11261118 from fms_mo import qconfig_init , qmodel_prep
11271119
11281120 if args .do_qat :
1129-
11301121 # create a config dict; if the same item exists in both recipe and args, args takes priority.
11311122 qcfg = qconfig_init (recipe = "qat_int8" , args = args )
11321123
@@ -1141,8 +1132,7 @@ def squad_eval(model, keep_model_in_eval_mode=True):
11411132 qmodel_prep (model , exam_inp , qcfg , optimizer , use_dynamo = True )
11421133
11431134 if args .do_ptq :
1144-
1145- # Local
1135+ # First Party
11461136 from fms_mo .quant .ptq import calib_PTQ_lm
11471137
11481138 # create a config dict; if the same item exists in both recipe and args, args takes priority.
@@ -1177,10 +1167,10 @@ def squad_eval(model, keep_model_in_eval_mode=True):
11771167 from copy import deepcopy
11781168
11791169 # Third Party
1180- from torch .ao .quantization .utils import _parent_name
11811170 import pandas as pd
1171+ from torch .ao .quantization .utils import _parent_name
11821172
1183- # Local
1173+ # First Party
11841174 from fms_mo .modules .linear import QLinear , QLinearINT8Deploy
11851175
11861176 def speedtest (model , exam_inp , Ntest = 100 ):
@@ -1216,17 +1206,16 @@ def speedtest(model, exam_inp, Ntest=100):
12161206 ("int8" , "ind" ),
12171207 ("int8" , "cugr" ),
12181208 ]:
1219-
12201209 logger .info (
12211210 f"\n { label } { 'with' if comp_mode else 'without' } torch.compile"
12221211 )
12231212 model_copy = deepcopy (model )
12241213
12251214 if label == "int8" :
12261215 qcfg = qconfig_init (recipe = "ptq_int8" , args = args )
1227- qcfg [
1228- "qmodel_calibration"
1229- ] = 0 # no need to run calibration or trained scales will be lost.
1216+ qcfg ["qmodel_calibration" ] = (
1217+ 0 # no need to run calibration or trained scales will be lost.
1218+ )
12301219 qmodel_prep (
12311220 model_copy ,
12321221 exam_inp ,
@@ -1479,9 +1468,9 @@ def speedtest(model, exam_inp, Ntest=100):
14791468 "step" : completed_steps ,
14801469 }
14811470 if args .do_predict :
1482- log [
1483- "squad_v2_predict" if args . version_2_with_negative else "squad_predict"
1484- ] = predict_metric
1471+ log ["squad_v2_predict" if args . version_2_with_negative else "squad_predict" ] = (
1472+ predict_metric
1473+ )
14851474
14861475 accelerator .log (log , step = completed_steps )
14871476
0 commit comments