diff --git a/examples/models/llama/CMakeLists.txt b/examples/models/llama/CMakeLists.txt
index 6a4aee11d22..5f49581ea25 100644
--- a/examples/models/llama/CMakeLists.txt
+++ b/examples/models/llama/CMakeLists.txt
@@ -128,6 +128,13 @@ if(EXECUTORCH_BUILD_TORCHAO)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental)
   target_link_options_shared_lib(torchao_ops_executorch)
   list(APPEND link_libraries torchao_ops_executorch)
+  if(CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+    add_subdirectory(
+      ${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental/ops/mps
+      ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental/ops/mps)
+    target_link_options_shared_lib(torchao_ops_mps_executorch)
+    list(APPEND link_libraries torchao_ops_mps_executorch)
+  endif()
 endif()
 
 set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index ea4296cc52c..3e1ccb13121 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -600,7 +600,7 @@ def get_quantizer_and_quant_params(args):
 
 def _qmode_type(value):
     choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
-    patterns = [r"torchao:8da(\d+)w"]
+    patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
 
     if value in choices:
         return value
diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index f8952ad0e53..4828a56d07c 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -72,12 +72,35 @@ def quantize(  # noqa C901
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
-    elif qmode.startswith("torchao:"):
+    elif qmode.startswith("torchao:fpa"):
+        pattern = r"torchao:fpa(\d+)w"
+        matches = re.findall(pattern, qmode)
+        assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
+        bitwidth = int(matches[0])
+        _load_torchao_aten_lib(libname="libtorchao_ops_mps_aten")
+        from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
+
+        with torch.no_grad():
+            model = (
+                UIntxWeightOnlyLinearQuantizer(
+                    device="mps",
+                    precision=torch.float32,
+                    groupsize=group_size,
+                    bitwidth=bitwidth,
+                )
+                .quantize(model)
+                .to("cpu")
+            )
+
+        if verbose:
+            print("quantized model:", model)
+        return model
+    elif qmode.startswith("torchao:8da"):
         pattern = r"torchao:8da(\d+)w"
         matches = re.findall(pattern, qmode)
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])
-        _load_torchao_ops_aten()
+        _load_torchao_aten_lib(libname="libtorchao_ops_aten")
         from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
 
         with torch.no_grad():
@@ -729,7 +752,7 @@ def get_quant_embedding_transform(args):
         bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
         group_size = int(group_size)
         bitwidth = int(bitwidth)
-        _load_torchao_ops_aten()
+        _load_torchao_aten_lib(libname="libtorchao_ops_aten")
         from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
 
         def _torchao_embedding_quantizer(model):
@@ -785,7 +808,7 @@ def get_quant_weight_transform(args, dtype_override, verbose):
     )
 
 
-def _load_torchao_ops_aten():
+def _load_torchao_aten_lib(libname):
     import glob
     import os
 
@@ -793,7 +816,7 @@ def _load_torchao_aten_lib(libname):
         os.path.abspath(
             os.path.join(
                 os.environ.get("CMAKE_INSTALL_PREFIX", ""),
-                "lib/libtorchao_ops_aten.*",
+                f"lib/{libname}.*",
             )
         )
     )
diff --git a/third-party/ao b/third-party/ao
index 75d06933aac..ebc43034e66 160000
--- a/third-party/ao
+++ b/third-party/ao
@@ -1 +1 @@
-Subproject commit 75d06933aace9d1ce803158e52910e4c9fc60981
+Subproject commit ebc43034e665bcda759cf9ef9c2c207057c5eeb1
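
For reference, a minimal sketch (not part of the patch) of how the two torchao qmode families are expected to parse after this change. The helper name parse_qmode_bitwidth is hypothetical; the patch inlines equivalent per-branch logic in quantize().

import re

def parse_qmode_bitwidth(qmode: str) -> int:
    # Both torchao qmode families encode the weight bitwidth in the string:
    # "torchao:8da<bits>w" (8-bit dynamic activations) and the new
    # "torchao:fpa<bits>w" (float activations, MPS weight-only kernels).
    for pattern in (r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"):
        matches = re.findall(pattern, qmode)
        if matches:
            assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
            # Convert the full captured group so a multi-digit bitwidth
            # would survive intact.
            return int(matches[0])
    raise ValueError(f"Unrecognized qmode: {qmode}")

assert parse_qmode_bitwidth("torchao:fpa4w") == 4  # new MPS weight-only path
assert parse_qmode_bitwidth("torchao:8da4w") == 4  # existing CPU lowbit path

Note that the fpa branch loads libtorchao_ops_mps_aten and runs UIntxWeightOnlyLinearQuantizer on the "mps" device before moving the model back to "cpu", which is why the CMake side only adds torchao_ops_mps_executorch under a Darwin/arm64 guard: the MPS lowbit kernels are only available on Apple Silicon.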