Commit 9381798

Commit message: updates
1 parent ec73228 commit 9381798

File tree

2 files changed: +20 -3 lines changed

examples/models/llama/CMakeLists.txt

Lines changed: 6 additions & 0 deletions

@@ -121,6 +121,12 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
   list(APPEND link_libraries custom_ops)
 endif()
 
+if (EXECUTORCH_BUILD_TORCHAO)
+  set(torchao_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/torchao)
+  find_package(torchao CONFIG REQUIRED)
+  list(APPEND link_libraries ${TORCHAO_LIBRARIES})
+endif()
+
 set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)
 # Extra compile option and include dir for pthreadpool
 if(EXECUTORCH_BUILD_PTHREADPOOL)
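
This build flag has a Python-side counterpart: the new qmode path in quantize.py (below) imports the experimental torchao quantizer that these libraries back. As context only, a minimal sketch (not part of this commit; the helper name is hypothetical) of how one might probe for that dependency before requesting a torchao qmode:

    # Sketch only: check that the experimental torchao quantizer used by the new
    # "torchao:..." qmode is importable. The import path mirrors quantize.py below;
    # torchao_experimental_available is a hypothetical helper, not part of this commit.
    def torchao_experimental_available() -> bool:
        try:
            from torchao.experimental.quant_api import (  # noqa: F401
                Int8DynActIntxWeightQuantizer,
            )
            return True
        except ImportError:
            return False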

examples/models/llama/source_transformation/quantize.py

Lines changed: 14 additions & 3 deletions

@@ -70,15 +70,26 @@ def quantize( # noqa C901
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
+    elif qmode.startswith("torchao"):
+        # format is torchao:8daxw
+        bitwidth = int(qmode[len("torchao:8da")])
+        if group_size is None:
+            raise Exception(f"For {qmode} quantization, group size must be specified.")
+        from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer
+        model = Int8DynActIntxWeightQuantizer(
+            device="cpu",
+            precision=torch_dtype, groupsize=group_size, bitwidth=bitwidth, has_weight_zeros=False).quantize(model)
+        if verbose:
+            print("quantized model:", model)
+        return model
     elif qmode == "8da4w":
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
-        model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=group_size
-        ).quantize(model)
+        model = Int8DynActInt4WeightQuantizer(precision=torch_dtype, groupsize=group_size, bitwidth=4).quantize(model)
+
         if verbose:
             print("quantized model:", model)
         return model
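
For reference, a small sketch of how the new mode string decodes (illustration only, not part of the commit; parse_torchao_qmode is a hypothetical helper): the "torchao:8da" prefix means 8-bit dynamically quantized activations, and the single character before the trailing "w" gives the weight bit width.

    # Illustration of the "torchao:8daXw" qmode format parsed by the added branch above.
    def parse_torchao_qmode(qmode: str) -> int:
        prefix = "torchao:8da"
        if not (qmode.startswith(prefix) and qmode.endswith("w")):
            raise ValueError(f"unexpected qmode: {qmode}")
        return int(qmode[len(prefix)])  # e.g. "torchao:8da4w" -> 4, "torchao:8da3w" -> 3

Because only a single character is sliced, this scheme covers single-digit weight bit widths.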
