Skip to content

Commit a24487d

Browse files
committed
update quant
1 parent ed43cca commit a24487d

File tree

1 file changed: 8 additions and 1 deletion

examples/models/llama/source_transformation/quantize.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ def quantize(  # noqa C901
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
     elif qmode.startswith("torchao:"):
+        import os
+        import glob
+        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/lib/libtorchao_ops_aten.*")))
+        assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
+        logging.info(f"Loading custom ops library: {libs[0]}")
+        torch.ops.load_library(libs[0])
+
         logging.warning(
             "When qmode is torchao, the groupsize is obtained from the qmode string with regex parse; blocksize is ignored."
         )
@@ -107,7 +114,7 @@ def quantize(  # noqa C901
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

         model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=group_size, bitwidth=4
+            precision=torch_dtype, groupsize=group_size
         ).quantize(model)

         if verbose:

Comments: 0 commit comments