Commit 0381a4e

Add torchao mps lowbit ops to llama runner
1 parent f477fd5

3 files changed: +33 -6

examples/models/llama/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -38,6 +38,7 @@ cmake_dependent_option(
 )
 
 option(EXECUTORCH_BUILD_TORCHAO "Build the torchao kernels" OFF)
+option(EXECUTORCH_BUILD_TORCHAO_MPS "Build the torchao mps kernels" OFF)
 
 if(NOT PYTHON_EXECUTABLE)
   set(PYTHON_EXECUTABLE python3)
@@ -130,6 +131,13 @@ if(EXECUTORCH_BUILD_TORCHAO)
   list(APPEND link_libraries torchao_ops_executorch)
 endif()
 
+if(EXECUTORCH_BUILD_TORCHAO_MPS)
+  set(TORCHAO_BUILD_EXECUTORCH_OPS ON)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental/ops/mps ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental/ops/mps)
+  target_link_options_shared_lib(torchao_ops_mps_executorch)
+  list(APPEND link_libraries torchao_ops_mps_executorch)
+endif()
+
 set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)
 # Extra compile option and include dir for pthreadpool
 if(EXECUTORCH_BUILD_PTHREADPOOL)
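
With this change the MPS lowbit kernels are opt-in at configure time. A hypothetical configure step with the new flag enabled (the flag name comes from the diff above; the directories, and driving CMake from a Python script at all, are illustrative assumptions only):

    # configure_llama_runner.py -- illustrative sketch, not from this commit
    import subprocess

    subprocess.check_call(
        [
            "cmake",
            "-DEXECUTORCH_BUILD_TORCHAO_MPS=ON",  # option added in this commit
            "-Bcmake-out/examples/models/llama",  # assumed build directory
            "examples/models/llama",              # assumed source directory
        ]
    )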

examples/models/llama/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
@@ -593,7 +593,7 @@ def get_quantizer_and_quant_params(args):
 
 def _qmode_type(value):
     choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
-    patterns = [r"torchao:8da(\d+)w"]
+    patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
 
     if value in choices:
         return value
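
For context, a standalone sketch of how the updated pattern list admits the new qmode strings (the rest of _qmode_type is not shown in the hunk, so the exact matching and rejection logic here is an assumption):

    import re

    # Patterns from the diff: "8da" = 8-bit dynamic activations, "fpa" =
    # float activations; the captured digits are the weight bitwidth.
    patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]

    def matches_qmode(value):
        # Return the captured bitwidth string if any pattern matches the
        # whole value, else None.
        for pattern in patterns:
            m = re.fullmatch(pattern, value)
            if m:
                return m.group(1)
        return None

    print(matches_qmode("torchao:fpa4w"))  # "4"  (new MPS scheme)
    print(matches_qmode("torchao:8da4w"))  # "4"  (existing scheme)
    print(matches_qmode("int4"))           # None (not a torchao qmode)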

examples/models/llama/source_transformation/quantize.py

Lines changed: 24 additions & 5 deletions
@@ -72,12 +72,31 @@ def quantize(  # noqa C901
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
-    elif qmode.startswith("torchao:"):
+    elif qmode.startswith("torchao:fpa"):
+        pattern = r"torchao:fpa(\d+)w"
+        matches = re.findall(pattern, qmode)
+        assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
+        bitwidth = int(matches[0][0])
+        _load_torchao_aten_lib(libname="libtorchao_ops_mps_linear_fp_act_xbit_weight_aten")
+        from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
+
+        with torch.no_grad():
+            model = UIntxWeightOnlyLinearQuantizer(
+                device="mps",
+                precision=torch.float32,
+                groupsize=group_size,
+                bitwidth=bitwidth
+            ).quantize(model)
+
+        if verbose:
+            print("quantized model:", model)
+        return model
+    elif qmode.startswith("torchao:8da"):
         pattern = r"torchao:8da(\d+)w"
         matches = re.findall(pattern, qmode)
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])
-        _load_torchao_ops_aten()
+        _load_torchao_aten_lib(libname="libtorchao_ops_aten")
         from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
 
         with torch.no_grad():
@@ -729,7 +748,7 @@ def get_quant_embedding_transform(args):
         bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
         group_size = int(group_size)
         bitwidth = int(bitwidth)
-        _load_torchao_ops_aten()
+        _load_torchao_aten_lib(libname="libtorchao_ops_aten")
         from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
 
         def _torchao_embedding_quantizer(model):
@@ -785,15 +804,15 @@ def get_quant_weight_transform(args, dtype_override, verbose):
     )
 
 
-def _load_torchao_ops_aten():
+def _load_torchao_aten_lib(libname):
     import glob
     import os
 
     libs = glob.glob(
         os.path.abspath(
             os.path.join(
                 os.environ.get("CMAKE_INSTALL_PREFIX", ""),
-                "lib/libtorchao_ops_aten.*",
+                f"lib/{libname}.*",
             )
         )
     )
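
The last hunk is truncated before the end of _load_torchao_aten_lib, so the code that consumes the globbed paths is not visible. A minimal sketch of the whole refactored helper under that caveat; the assert and the final torch.ops.load_library call are assumptions based on the function's purpose (registering the custom ops with PyTorch):

    import glob
    import os

    import torch

    def _load_torchao_aten_lib(libname):
        # Locate the installed shared library under $CMAKE_INSTALL_PREFIX/lib,
        # e.g. libtorchao_ops_aten.* or
        # libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.*.
        libs = glob.glob(
            os.path.abspath(
                os.path.join(
                    os.environ.get("CMAKE_INSTALL_PREFIX", ""),
                    f"lib/{libname}.*",
                )
            )
        )
        assert len(libs) == 1, f"Expected exactly one library, got {len(libs)}"
        # Assumed final step: register the custom ops with PyTorch.
        torch.ops.load_library(libs[0])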
