
Commit f777de7

updates
1 parent 4ea5ee6 commit f777de7

2 files changed: +20 −3 lines changed


examples/models/llama2/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
@@ -121,6 +121,12 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
   list(APPEND link_libraries custom_ops)
 endif()
 
+if (EXECUTORCH_BUILD_TORCHAO)
+  set(torchao_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/torchao)
+  find_package(torchao CONFIG REQUIRED)
+  list(APPEND link_libraries ${TORCHAO_LIBRARIES})
+endif()
+
 set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)
 # Extra compile option and include dir for pthreadpool
 if(EXECUTORCH_BUILD_PTHREADPOOL)
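
Note: the new block is gated on the EXECUTORCH_BUILD_TORCHAO CMake option, so builds that do not enable it are unaffected. When the option is on, find_package looks for a torchao CMake config package under the build tree's third-party/torchao directory and ${TORCHAO_LIBRARIES} is appended to the example's link_libraries list; building torchao into that location is assumed to happen separately and is not part of this commit.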

examples/models/llama2/source_transformation/quantize.py

Lines changed: 14 additions & 3 deletions
@@ -68,15 +68,26 @@ def quantize(
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
+    elif qmode.startswith("torchao"):
+        # format is torchao:8daxw
+        bitwidth = int(qmode[len("torchao:8da")])
+        if group_size is None:
+            raise Exception(f"For {qmode} quantization, group size must be specified.")
+        from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer
+        model = Int8DynActIntxWeightQuantizer(
+            device="cpu",
+            precision=torch_dtype, groupsize=group_size, bitwidth=bitwidth, has_weight_zeros=False).quantize(model)
+        if verbose:
+            print("quantized model:", model)
+        return model
     elif qmode == "8da4w":
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
-        model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=group_size
-        ).quantize(model)
+        model = Int8DynActInt4WeightQuantizer(precision=torch_dtype, groupsize=group_size, bitwidth=4).quantize(model)
+
         if verbose:
             print("quantized model:", model)
         return model
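
The new branch matches on the qmode prefix rather than an exact string: any mode of the form torchao:8daXw selects the experimental Int8DynActIntxWeightQuantizer, with the weight bit width read from the single character that follows the "torchao:8da" prefix (the "8da" part, going by the quantizer's name, meaning 8-bit dynamic activations). A minimal, self-contained sketch of that parsing logic, assuming single-digit bit widths; the helper name below is illustrative and not part of the commit:

    def parse_torchao_bitwidth(qmode: str) -> int:
        # Modes look like "torchao:8daXw", e.g. "torchao:8da4w" for 4-bit weights.
        if not qmode.startswith("torchao"):
            raise ValueError(f"not a torchao qmode: {qmode}")
        # Only the single character after the "torchao:8da" prefix is read,
        # so the scheme covers one-digit weight bit widths (1-9).
        return int(qmode[len("torchao:8da")])

    assert parse_torchao_bitwidth("torchao:8da4w") == 4
    assert parse_torchao_bitwidth("torchao:8da3w") == 3

As in the existing 8da4w path, a group_size must be supplied for these modes or the branch raises before constructing the quantizer.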
