|
1 | | -import os |
2 | 1 | import argparse |
3 | 2 | import time |
4 | 3 | import json |
|
12 | 11 | from transformers import AutoConfig |
13 | 12 | from transformers import TextStreamer |
14 | 13 | import intel_extension_for_pytorch as ipex |
| 14 | +from intel_extension_for_pytorch.llm.utils import ( |
| 15 | + load_low_precision_checkpoint, |
| 16 | +) |
15 | 17 | from ast import literal_eval |
16 | 18 | import sys |
17 | 19 |
|
@@ -1183,75 +1185,8 @@ def calib_func(prepared_model): |
1183 | 1185 | ) |
1184 | 1186 | if args.low_precision_checkpoint != "": |
1185 | 1187 | pathname = args.low_precision_checkpoint |
1186 | | - assert os.path.exists(pathname), f"Checkpoint file does not exist: {pathname}" |
1187 | | - if os.path.isfile(pathname): |
1188 | | - low_precision_checkpoint = None |
1189 | | - if pathname.endswith((".pt", ".pth", ".bin")): |
1190 | | - low_precision_checkpoint = torch.load(pathname, weights_only=True) |
1191 | | - elif pathname.endswith(".safetensors"): |
1192 | | - try: |
1193 | | - import safetensors |
1194 | | - except ImportError: |
1195 | | - print( |
1196 | | - "Please install safetensors package to load safetensors checkpoint." |
1197 | | - ) |
1198 | | - exit(1) |
1199 | | - low_precision_checkpoint = safetensors.torch.load_file(pathname) |
1200 | | - assert ( |
1201 | | - low_precision_checkpoint is not None |
1202 | | - ), f"Invalid checkpoint file: {pathname}. Should be a .pt, .pth, .bin or .safetensors file." |
1203 | | - |
1204 | | - quant_method = {"quant_method": "gptq"} |
1205 | | - |
1206 | | - elif os.path.isdir(pathname): |
1207 | | - low_precision_checkpoint = {} |
1208 | | - for pattern in ["*.pt", "*.pth", "*.bin"]: |
1209 | | - files = list(pathlib.Path(pathname).glob(pattern)) |
1210 | | - if files: |
1211 | | - for f in files: |
1212 | | - data_f = torch.load(f, weights_only=True) |
1213 | | - low_precision_checkpoint.update(data_f) |
1214 | | - break |
1215 | | - if not low_precision_checkpoint: |
1216 | | - files = list(pathlib.Path(pathname).glob("*.safetensors")) |
1217 | | - if files: |
1218 | | - try: |
1219 | | - import safetensors |
1220 | | - except ImportError: |
1221 | | - print( |
1222 | | - "Please install safetensors package to load safetensors checkpoint." |
1223 | | - ) |
1224 | | - exit(1) |
1225 | | - for f in files: |
1226 | | - data_f = safetensors.torch.load_file(f) |
1227 | | - low_precision_checkpoint.update(data_f) |
1228 | | - assert ( |
1229 | | - len(low_precision_checkpoint) > 0 |
1230 | | - ), f"Cannot find checkpoint (.pt/.pth/.bin/.safetensors) files in path {pathname}." |
1231 | | - |
1232 | | - try: |
1233 | | - with open(pathname + "/config.json") as f: |
1234 | | - quant_model_config = json.load(f) |
1235 | | - quant_method = { |
1236 | | - "quant_method": quant_model_config["quantization_config"][ |
1237 | | - "quant_method" |
1238 | | - ] |
1239 | | - } |
1240 | | - except Exception as e: |
1241 | | - print( |
1242 | | - "warning: loading HF config.json to get `quant_method` failed, due to ", |
1243 | | - e, |
1244 | | - ) |
1245 | | - print("warning: specifying `quant_method` = `gptq` by default.") |
1246 | | - quant_method = {"quant_method": "gptq"} |
1247 | | - |
1248 | | - else: |
1249 | | - raise AssertionError( |
1250 | | - f"Invalid low-precision-checkpoint: {pathname}." |
1251 | | - " Should be a .pt/.pth/.safetensors file or a directory containing them." |
1252 | | - ) |
1253 | | - |
1254 | | - low_precision_checkpoint = (low_precision_checkpoint, quant_method) |
| 1188 | + low_precision_checkpoint, quant_config = load_low_precision_checkpoint(pathname) |
| 1189 | + low_precision_checkpoint = (low_precision_checkpoint, quant_config) |
1255 | 1190 |
|
1256 | 1191 | if args.gptq_legacy_format: |
1257 | 1192 | raise AssertionError( |
|
0 commit comments