
Commit 9ebd761

lints
1 parent d4f6cd2 commit 9ebd761

2 files changed (+23, -8 lines)


examples/models/llama/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ def build_args_parser() -> argparse.ArgumentParser:
 def _is_valid_torchao_qmode_type(value):
     if not value.startswith("torchao:"):
         return False
-
+
     patterns = [
         r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
         r"emb.(\d+),(\d+)",

(The only change in this hunk is whitespace on the blank line before the patterns list.)

examples/models/llama/source_transformation/quantize.py

Lines changed: 22 additions & 7 deletions
@@ -19,7 +19,6 @@
 from executorch.extension.llm.export.builder import DType

 from sentencepiece import SentencePieceProcessor
-from torch.nn.modules import linear

 try:
     from fairseq2.nn.embedding import (
@@ -74,9 +73,17 @@ def quantize( # noqa C901
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
     elif qmode.startswith("torchao:"):
-        import os
         import glob
-        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*")))
+        import os
+
+        libs = glob.glob(
+            os.path.abspath(
+                os.path.join(
+                    os.path.dirname(__file__),
+                    "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*",
+                )
+            )
+        )
         assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
         logging.info(f"Loading custom ops library: {libs[0]}")
         torch.ops.load_library(libs[0])
@@ -89,24 +96,32 @@ def quantize( # noqa C901

         linear_matches = re.findall(linear_pattern, qmode)
         if linear_matches:
-            assert len(linear_matches) == 1, f"Expected 1 match but got {len(linear_matches)}"
+            assert (
+                len(linear_matches) == 1
+            ), f"Expected 1 match but got {len(linear_matches)}"
             bitwidth = int(linear_matches[0][0])
             groupsize = int(linear_matches[0][1])
-            from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
+            from torchao.experimental.quant_api import (
+                Int8DynActIntxWeightLinearQuantizer,
+            )
+
             model = Int8DynActIntxWeightLinearQuantizer(
                 device="cpu",
                 precision=torch_dtype,
                 groupsize=groupsize,
                 bitwidth=bitwidth,
                 has_weight_zeros=False,
             ).quantize(model)
-
+
         embedding_matches = re.findall(embedding_pattern, qmode)
         if embedding_matches:
-            assert len(embedding_matches) == 1, f"Expected 1 match but got {len(embedding_matches)}"
+            assert (
+                len(embedding_matches) == 1
+            ), f"Expected 1 match but got {len(embedding_matches)}"
             bitwidth = int(embedding_matches[0][0])
             groupsize = int(embedding_matches[0][1])
             from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
+
             model = IntxWeightEmbeddingQuantizer(
                 device="cpu",
                 precision=torch_dtype,
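For reference, here is a hedged sketch of how this branch extracts bit width and group size from a qmode string. linear_pattern and embedding_pattern are defined outside the lines shown, so the regexes below are assumptions modeled on the validator patterns in export_llama_lib.py, and the qmode value is purely illustrative.

import re

# Assumed regexes, modeled on the validator patterns; the real linear_pattern /
# embedding_pattern live outside the lines shown in this hunk.
linear_pattern = r"lin8da.(\d+),(\d+)"
embedding_pattern = r"emb.(\d+),(\d+)"

qmode = "torchao:emb.4,32&lin8da.8,128"  # illustrative value

linear_matches = re.findall(linear_pattern, qmode)
if linear_matches:
    # With the string above: bitwidth=8, groupsize=128 for the linear quantizer.
    bitwidth, groupsize = int(linear_matches[0][0]), int(linear_matches[0][1])
    print("linear:", bitwidth, groupsize)

embedding_matches = re.findall(embedding_pattern, qmode)
if embedding_matches:
    # With the string above: bitwidth=4, groupsize=32 for the embedding quantizer.
    bitwidth, groupsize = int(embedding_matches[0][0]), int(embedding_matches[0][1])
    print("embedding:", bitwidth, groupsize)

The parsed values are then passed to Int8DynActIntxWeightLinearQuantizer and IntxWeightEmbeddingQuantizer respectively, as the hunk above shows.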
