
Commit ed43cca

updates
1 parent 11e2e5c commit ed43cca

4 files changed: +64 -32 lines

CMakeLists.txt

Lines changed: 0 additions & 8 deletions
@@ -836,13 +836,5 @@ if(EXECUTORCH_BUILD_VULKAN)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/vulkan)
 endif()
 
-# if(EXECUTORCH_BUILD_TORCHAO)
-#   add_compile_options("-frtti")
-#   set(EXECUTORCH_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/..)
-#   set(EXECUTORCH_LIBRARIES executorch extension_threadpool) # cpuinfo pthreadpool)
-#   set(TORCHAO_BUILD_EXECUTORCH_OPS ON)
-#   add_subdirectory(third-party/ao/torchao/experimental)
-# endif()
-
 # Print all summary
 executorch_print_configuration_summary()

examples/models/llama/CMakeLists.txt

Lines changed: 3 additions & 11 deletions
@@ -37,6 +37,8 @@ cmake_dependent_option(
   "NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF
 )
 
+option(EXECUTORCH_BUILD_TORCHAO "Build the torchao kernels" OFF)
+
 if(NOT PYTHON_EXECUTABLE)
   set(PYTHON_EXECUTABLE python3)
 endif()
@@ -122,17 +124,7 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
 endif()
 
 if(EXECUTORCH_BUILD_TORCHAO)
-  # Method1: torchao has a config
-  # set(torchao_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../lib/cmake/torchao)
-  # find_package(torchao REQUIRED)
-  # target_link_options_shared_lib(torchao::torchao_ops_executorch)
-  # list(APPEND link_libraries torchao::torchao_ops_executorch)
-
-  # Method2: torchao is built at top-level CMakeLists.txt
-  # list(APPEND link_libraries "$<LINK_LIBRARY:WHOLE_ARCHIVE,${CMAKE_CURRENT_BINARY_DIR}/../../../lib/libtorchao_ops_executorch.a>")
-  # list(APPEND link_libraries "${CMAKE_CURRENT_BINARY_DIR}/../../../lib/libtorchao_kernels_aarch64.a")
-
-  # Method3: submodule
+  set(TORCHAO_BUILD_EXECUTORCH_OPS ON)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental)
   target_link_options_shared_lib(torchao_ops_executorch)
   list(APPEND link_libraries torchao_ops_executorch)
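Taken together with the top-level change above, this settles on a single integration path: the three previously commented-out alternatives (a find_package config, linking prebuilt static archives from the top-level build, and the submodule) are replaced by the submodule approach alone, now opt-in per example via -DEXECUTORCH_BUILD_TORCHAO=ON.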

examples/models/llama/export_llama_lib.py

Lines changed: 31 additions & 3 deletions
@@ -12,14 +12,14 @@
 import copy
 import json
 import logging
+import re
 import shlex
 from enum import Enum
 from json import JSONDecodeError
 from pathlib import Path
 from typing import Callable, List, Optional, Union
 
 import pkg_resources
-
 import torch
 
 from executorch.devtools.etrecord import generate_etrecord
@@ -153,12 +153,40 @@ def build_args_parser() -> argparse.ArgumentParser:
         ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
+
+    def _is_valid_torchao_qmode_type(value):
+        if not value.startswith("torchao:"):
+            return False
+
+        linear_pattern = r"lin.8da(\d+)b(\d+)gw"
+        linear_matches = re.findall(linear_pattern, value)
+        print("LINEAR MATCHES", linear_matches)
+
+        if len(linear_matches) > 1:
+            return False
+
+        embedding_pattern = r"emb.(\d+)b(\d+)gw"
+        embedding_matches = re.findall(embedding_pattern, value)
+        print("EMBEDDING MATCHES", embedding_matches)
+        if len(embedding_matches) > 1:
+            return False
+        if len(linear_matches) + len(embedding_matches) == 0:
+            return False
+        return True
+
+    def _qmode_type(value):
+        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
+        if not (value in choices or _is_valid_torchao_qmode_type(value)):
+            raise argparse.ArgumentTypeError(
+                f"Value must be one of: {choices} or a valid torchao regex"
+            )
+        return value
+
     parser.add_argument(
         "-qmode",
         "--quantization_mode",
-        type=str,
+        type=_qmode_type,
         default=None,
-        choices=["int8", "8da4w", "8da4w-gptq", "vulkan_4w"],
         help="type of quantization",
     )
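For reference, the grammar these two regexes accept is easiest to see by running them standalone. Below is a minimal sketch with the patterns copied from the diff; the sample strings are illustrative assumptions, not a documented format (note the "." in each pattern matches any single separator character):

import re

LINEAR_PATTERN = r"lin.8da(\d+)b(\d+)gw"  # captures: weight bitwidth, group size
EMBEDDING_PATTERN = r"emb.(\d+)b(\d+)gw"  # captures: embedding bitwidth, group size

def is_valid_torchao_qmode(value: str) -> bool:
    # Same rules as _is_valid_torchao_qmode_type: a "torchao:" prefix,
    # at most one linear spec, at most one embedding spec,
    # and at least one of the two must be present.
    if not value.startswith("torchao:"):
        return False
    linear_matches = re.findall(LINEAR_PATTERN, value)
    embedding_matches = re.findall(EMBEDDING_PATTERN, value)
    if len(linear_matches) > 1 or len(embedding_matches) > 1:
        return False
    return len(linear_matches) + len(embedding_matches) >= 1

# Hypothetical spellings, for illustration only:
print(is_valid_torchao_qmode("torchao:lin_8da3b32gw"))             # True: one linear spec
print(is_valid_torchao_qmode("torchao:emb_4b32gw_lin_8da3b32gw"))  # True: embedding + linear
print(is_valid_torchao_qmode("torchao:8da4w"))                     # False: matches neither pattern
print(is_valid_torchao_qmode("8da4w"))                             # False: no "torchao:" prefix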

examples/models/llama/source_transformation/quantize.py

Lines changed: 30 additions & 10 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
+import re
 from functools import partial
 from pathlib import Path
 from typing import Any, Dict, Optional
@@ -70,25 +72,43 @@ def quantize( # noqa C901
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
-    elif qmode.startswith("torchao"):
-        # format is torchao:8daxw
-        bitwidth = int(qmode[len("torchao:8da")])
-        if group_size is None:
-            raise Exception(f"For {qmode} quantization, group size must be specified.")
-        from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer
-        model = Int8DynActIntxWeightQuantizer(
-            device="cpu",
-            precision=torch_dtype, groupsize=group_size, bitwidth=bitwidth, has_weight_zeros=False).quantize(model)
+    elif qmode.startswith("torchao:"):
+        logging.warning(
+            "When qmode is torchao, the groupsize is obtained from the qmode string with regex parse; blocksize is ignored."
+        )
+        linear_pattern = r"lin.8da(\d+)b(\d+)gw"
+        linear_matches = re.findall(linear_pattern, qmode)
+        if linear_matches:
+            bitwidth = int(linear_matches[0][0])
+            group_size = int(linear_matches[0][1])
+            from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer
+
+            model = Int8DynActIntxWeightQuantizer(
+                device="cpu",
+                precision=torch_dtype,
+                groupsize=group_size,
+                bitwidth=bitwidth,
+                has_weight_zeros=False,
+            ).quantize(model)
+
+        embedding_pattern = r"emb.(\d+)b(\d+)gw"
+        embedding_matches = re.findall(embedding_pattern, qmode)
+        if embedding_matches:
+            pass  # TODO: add when embedding PR lands in torchao
+
         if verbose:
             print("quantized model:", model)
+
         return model
     elif qmode == "8da4w":
        # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
-        model = Int8DynActInt4WeightQuantizer(precision=torch_dtype, groupsize=group_size, bitwidth=4).quantize(model)
+        model = Int8DynActInt4WeightQuantizer(
+            precision=torch_dtype, groupsize=group_size, bitwidth=4
+        ).quantize(model)
 
         if verbose:
             print("quantized model:", model)
