Commit 853be03
updates
1 parent fbf5b6e commit 853be03

4 files changed: +66 -33 lines changed

CMakeLists.txt

Lines changed: 0 additions & 8 deletions
@@ -833,13 +833,5 @@ if(EXECUTORCH_BUILD_VULKAN)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/vulkan)
 endif()
 
-# if(EXECUTORCH_BUILD_TORCHAO)
-# add_compile_options("-frtti")
-# set(EXECUTORCH_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/..)
-# set(EXECUTORCH_LIBRARIES executorch extension_threadpool) # cpuinfo pthreadpool)
-# set(TORCHAO_BUILD_EXECUTORCH_OPS ON)
-# add_subdirectory(third-party/ao/torchao/experimental)
-# endif()
-
 # Print all summary
 executorch_print_configuration_summary()

examples/models/llama2/CMakeLists.txt

Lines changed: 3 additions & 11 deletions
@@ -37,6 +37,8 @@ cmake_dependent_option(
   "NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF
 )
 
+option(EXECUTORCH_BUILD_TORCHAO "Build the torchao kernels" OFF)
+
 if(NOT PYTHON_EXECUTABLE)
   set(PYTHON_EXECUTABLE python3)
 endif()
@@ -122,17 +124,7 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
 endif()
 
 if(EXECUTORCH_BUILD_TORCHAO)
-  # Method1: torchao has a config
-  # set(torchao_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../lib/cmake/torchao)
-  # find_package(torchao REQUIRED)
-  # target_link_options_shared_lib(torchao::torchao_ops_executorch)
-  # list(APPEND link_libraries torchao::torchao_ops_executorch)
-
-  # Method2: torchao is built at top-level CMakeLists.txt
-  # list(APPEND link_libraries "$<LINK_LIBRARY:WHOLE_ARCHIVE,${CMAKE_CURRENT_BINARY_DIR}/../../../lib/libtorchao_ops_executorch.a>")
-  # list(APPEND link_libraries "${CMAKE_CURRENT_BINARY_DIR}/../../../lib/libtorchao_kernels_aarch64.a")
-
-  # Method3: submodule
+  set(TORCHAO_BUILD_EXECUTORCH_OPS ON)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental)
   target_link_options_shared_lib(torchao_ops_executorch)
   list(APPEND link_libraries torchao_ops_executorch)

examples/models/llama2/export_llama_lib.py

Lines changed: 32 additions & 3 deletions
@@ -12,14 +12,14 @@
 import copy
 import json
 import logging
+import re
 import shlex
 from enum import Enum
 from json import JSONDecodeError
 from pathlib import Path
 from typing import Callable, List, Optional, Union
 
 import pkg_resources
-
 import torch
 
 from executorch.devtools.etrecord import generate_etrecord
@@ -152,12 +152,41 @@ def build_args_parser() -> argparse.ArgumentParser:
         ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
+
+    def _is_valid_torchao_qmode_type(value):
+        if not value.startswith("torchao:"):
+            return False
+
+        linear_pattern = r"lin.8da(\d+)b(\d+)gw"
+        linear_matches = re.findall(linear_pattern, value)
+        print("LINEAR MATCHES", linear_matches)
+
+        if len(linear_matches) > 1:
+            return False
+
+        embedding_pattern = r"emb.(\d+)b(\d+)gw"
+        embedding_matches = re.findall(embedding_pattern, value)
+        print("EMBEDDING MATCHES", embedding_matches)
+        if len(embedding_matches) > 1:
+            return False
+        if len(linear_matches) + len(embedding_matches) == 0:
+            return False
+        return True
+
+    def _qmode_type(value):
+        choices = ["int8", "8da4w", "8da4w-gptq"]
+        if not (value in choices or _is_valid_torchao_qmode_type(value)):
+            raise argparse.ArgumentTypeError(
+                f"Value must be one of: {choices} or a valid torchao regex"
+            )
+        return value
+
     parser.add_argument(
         "-qmode",
         "--quantization_mode",
-        type=str,
+        type=_qmode_type,
         default=None,
-        choices=["int8", "8da4w", "8da4w-gptq"],
+        # choices=["int8", "8da4w", "8da4w-gptq"] + [f"torchao:8da{x}w" for x in range(1, 9)],
         help="type of quantization",
     )
 
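For context, a minimal sketch of which strings the new _qmode_type validator accepts. The helper below is a condensed restatement of the committed logic, and the qmode values are hypothetical examples constructed to match the lin.8da(\d+)b(\d+)gw and emb.(\d+)b(\d+)gw patterns above, not strings taken from the commit:

import re

def is_valid_torchao_qmode(value):
    # Condensed restatement of _is_valid_torchao_qmode_type above:
    # require the "torchao:" prefix and at most one linear spec and
    # at most one embedding spec, with at least one of the two present.
    if not value.startswith("torchao:"):
        return False
    linear_matches = re.findall(r"lin.8da(\d+)b(\d+)gw", value)
    embedding_matches = re.findall(r"emb.(\d+)b(\d+)gw", value)
    if len(linear_matches) > 1 or len(embedding_matches) > 1:
        return False
    return len(linear_matches) + len(embedding_matches) > 0

# Hypothetical qmode strings:
assert is_valid_torchao_qmode("torchao:lin.8da4b32gw")  # 4-bit linear weights, group size 32
assert is_valid_torchao_qmode("torchao:emb.4b32gw")     # 4-bit embedding weights, group size 32
assert not is_valid_torchao_qmode("8da4w")              # missing the "torchao:" prefix
assert not is_valid_torchao_qmode("torchao:unknown")    # no recognized pattern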

examples/models/llama2/source_transformation/quantize.py

Lines changed: 31 additions & 11 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
+import re
 from functools import partial
 from pathlib import Path
 from typing import Any, Dict, Optional
@@ -31,7 +33,7 @@
 fsLinear = nn.Linear
 
 
-def quantize(
+def quantize(  # noqa: C901
     model: torch.nn.Module,
     qmode: str,
     activation_dtype: Optional[DType],
@@ -68,25 +70,43 @@ def quantize(
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
-    elif qmode.startswith("torchao"):
-        # format is torchao:8daxw
-        bitwidth = int(qmode[len("torchao:8da")])
-        if group_size is None:
-            raise Exception(f"For {qmode} quantization, group size must be specified.")
-        from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer
-        model = Int8DynActIntxWeightQuantizer(
-            device="cpu",
-            precision=torch_dtype, groupsize=group_size, bitwidth=bitwidth, has_weight_zeros=False).quantize(model)
+    elif qmode.startswith("torchao:"):
+        logging.warning(
+            "When qmode is torchao, the groupsize is obtained from the qmode string with regex parse; blocksize is ignored."
+        )
+        linear_pattern = r"lin.8da(\d+)b(\d+)gw"
+        linear_matches = re.findall(linear_pattern, qmode)
+        if linear_matches:
+            bitwidth = int(linear_matches[0][0])
+            group_size = int(linear_matches[0][1])
+            from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer
+
+            model = Int8DynActIntxWeightQuantizer(
+                device="cpu",
+                precision=torch_dtype,
+                groupsize=group_size,
+                bitwidth=bitwidth,
+                has_weight_zeros=False,
+            ).quantize(model)
+
+        embedding_pattern = r"emb.(\d+)b(\d+)gw"
+        embedding_matches = re.findall(embedding_pattern, qmode)
+        if embedding_matches:
+            pass  # TODO: add when embedding PR lands in torchao
+
         if verbose:
             print("quantized model:", model)
+
         return model
     elif qmode == "8da4w":
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
-        model = Int8DynActInt4WeightQuantizer(precision=torch_dtype, groupsize=group_size, bitwidth=4).quantize(model)
+        model = Int8DynActInt4WeightQuantizer(
+            precision=torch_dtype, groupsize=group_size, bitwidth=4
+        ).quantize(model)
 
         if verbose:
             print("quantized model:", model)
