Commit 6fc2ad8

Fix loading checkpoint in BKC run_quantization.py (#3482)
* Fix loading checkpoint in BKC run_quantization.py
* fix flake8 issue
1 parent e14b2a2 commit 6fc2ad8

1 file changed: +5 -70 lines changed

examples/cpu/llm/inference/single_instance/run_quantization.py

Lines changed: 5 additions & 70 deletions
@@ -1,4 +1,3 @@
-import os
 import argparse
 import time
 import json
@@ -12,6 +11,9 @@
 from transformers import AutoConfig
 from transformers import TextStreamer
 import intel_extension_for_pytorch as ipex
+from intel_extension_for_pytorch.llm.utils import (
+    load_low_precision_checkpoint,
+)
 from ast import literal_eval
 import sys
 
@@ -1183,75 +1185,8 @@ def calib_func(prepared_model):
         )
         if args.low_precision_checkpoint != "":
             pathname = args.low_precision_checkpoint
-            assert os.path.exists(pathname), f"Checkpoint file does not exist: {pathname}"
-            if os.path.isfile(pathname):
-                low_precision_checkpoint = None
-                if pathname.endswith((".pt", ".pth", ".bin")):
-                    low_precision_checkpoint = torch.load(pathname, weights_only=True)
-                elif pathname.endswith(".safetensors"):
-                    try:
-                        import safetensors
-                    except ImportError:
-                        print(
-                            "Please install safetensors package to load safetensors checkpoint."
-                        )
-                        exit(1)
-                    low_precision_checkpoint = safetensors.torch.load_file(pathname)
-                assert (
-                    low_precision_checkpoint is not None
-                ), f"Invalid checkpoint file: {pathname}. Should be a .pt, .pth, .bin or .safetensors file."
-
-                quant_method = {"quant_method": "gptq"}
-
-            elif os.path.isdir(pathname):
-                low_precision_checkpoint = {}
-                for pattern in ["*.pt", "*.pth", "*.bin"]:
-                    files = list(pathlib.Path(pathname).glob(pattern))
-                    if files:
-                        for f in files:
-                            data_f = torch.load(f, weights_only=True)
-                            low_precision_checkpoint.update(data_f)
-                        break
-                if not low_precision_checkpoint:
-                    files = list(pathlib.Path(pathname).glob("*.safetensors"))
-                    if files:
-                        try:
-                            import safetensors
-                        except ImportError:
-                            print(
-                                "Please install safetensors package to load safetensors checkpoint."
-                            )
-                            exit(1)
-                        for f in files:
-                            data_f = safetensors.torch.load_file(f)
-                            low_precision_checkpoint.update(data_f)
-                assert (
-                    len(low_precision_checkpoint) > 0
-                ), f"Cannot find checkpoint (.pt/.pth/.bin/.safetensors) files in path {pathname}."
-
-                try:
-                    with open(pathname + "/config.json") as f:
-                        quant_model_config = json.load(f)
-                    quant_method = {
-                        "quant_method": quant_model_config["quantization_config"][
-                            "quant_method"
-                        ]
-                    }
-                except Exception as e:
-                    print(
-                        "warning: loading HF config.json to get `quant_method` failed, due to ",
-                        e,
-                    )
-                    print("warning: specifying `quant_method` = `gptq` by default.")
-                    quant_method = {"quant_method": "gptq"}
-
-            else:
-                raise AssertionError(
-                    f"Invalid low-precision-checkpoint: {pathname}."
-                    " Should be a .pt/.pth/.safetensors file or a directory containing them."
-                )
-
-            low_precision_checkpoint = (low_precision_checkpoint, quant_method)
+            low_precision_checkpoint, quant_config = load_low_precision_checkpoint(pathname)
+            low_precision_checkpoint = (low_precision_checkpoint, quant_config)
 
         if args.gptq_legacy_format:
             raise AssertionError(
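
For reference, a minimal sketch (not part of the commit) of what the new loading path does, assuming intel_extension_for_pytorch is installed and the path points at a .pt/.pth/.bin/.safetensors checkpoint file or a directory containing one. Only load_low_precision_checkpoint and the final tuple packing come from the diff; the command-line wiring here is illustrative.

import sys

from intel_extension_for_pytorch.llm.utils import (
    load_low_precision_checkpoint,
)

# Stand-in for args.low_precision_checkpoint in run_quantization.py.
pathname = sys.argv[1]

# The helper consolidates the removed torch.load / safetensors handling and
# the manual config.json lookup: it returns the checkpoint state dict plus a
# quantization config that plays the role of the old quant_method dict.
low_precision_checkpoint, quant_config = load_low_precision_checkpoint(pathname)

# The script then packs both pieces into one tuple for the IPEX
# weight-only quantization path.
low_precision_checkpoint = (low_precision_checkpoint, quant_config)
print("quantization config:", quant_config)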
