Skip to content

Commit c82f77b

Browse files
committed
update quant
1 parent 91bab8b commit c82f77b

File tree

1 file changed

+8
-1
lines changed
  • examples/models/llama2/source_transformation

1 file changed

+8
-1
lines changed

examples/models/llama2/source_transformation/quantize.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@ def quantize( # noqa: C901
7171
# Add quantization mode options here: group size, bit width, etc.
7272
return WeightOnlyInt8QuantHandler(model).quantized_model()
7373
elif qmode.startswith("torchao:"):
74+
import os
75+
import glob
76+
libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/lib/libtorchao_ops_aten.*")))
77+
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
78+
logging.info(f"Loading custom ops library: {libs[0]}")
79+
torch.ops.load_library(libs[0])
80+
7481
logging.warning(
7582
"When qmode is torchao, the groupsize is obtained from the qmode string with regex parse; blocksize is ignored."
7683
)
@@ -105,7 +112,7 @@ def quantize( # noqa: C901
105112
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
106113

107114
model = Int8DynActInt4WeightQuantizer(
108-
precision=torch_dtype, groupsize=group_size, bitwidth=4
115+
precision=torch_dtype, groupsize=group_size
109116
).quantize(model)
110117

111118
if verbose:

0 commit comments

Comments (0)