Commit ea358d5

lints
1 parent 52f7a1f commit ea358d5

File tree

2 files changed: +23 −8 lines changed


examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
@@ -156,7 +156,7 @@ def build_args_parser() -> argparse.ArgumentParser:
 def _is_valid_torchao_qmode_type(value):
     if not value.startswith("torchao:"):
         return False
-
+
     patterns = [
         r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
         r"emb.(\d+),(\d+)",

examples/models/llama2/source_transformation/quantize.py

Lines changed: 22 additions & 7 deletions
@@ -17,7 +17,6 @@
 from executorch.extension.llm.export.builder import DType

 from sentencepiece import SentencePieceProcessor
-from torch.nn.modules import linear

 try:
     from fairseq2.nn.embedding import (
@@ -72,9 +71,17 @@ def quantize(  # noqa: C901
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
     elif qmode.startswith("torchao:"):
-        import os
         import glob
-        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*")))
+        import os
+
+        libs = glob.glob(
+            os.path.abspath(
+                os.path.join(
+                    os.path.dirname(__file__),
+                    "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*",
+                )
+            )
+        )
         assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
         logging.info(f"Loading custom ops library: {libs[0]}")
         torch.ops.load_library(libs[0])
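
For context, here is the same library-loading sequence as a standalone, commented sketch. It assumes the torchao custom-ops library has already been built under the cmake-out path shown in the diff; only glob, os, logging, and torch.ops.load_library are used, all of which the hunk itself relies on.

import glob
import logging
import os

import torch

# Resolve the built torchao custom-ops shared library relative to this file
# (assumes the ExecuTorch cmake-out layout shown in the diff).
libs = glob.glob(
    os.path.abspath(
        os.path.join(
            os.path.dirname(__file__),
            "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*",
        )
    )
)

# Fail loudly if the build produced zero or multiple candidates.
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"

# Register the library's operators with PyTorch so the quantizers can call them.
logging.info(f"Loading custom ops library: {libs[0]}")
torch.ops.load_library(libs[0])
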
@@ -87,24 +94,32 @@ def quantize(  # noqa: C901

         linear_matches = re.findall(linear_pattern, qmode)
         if linear_matches:
-            assert len(linear_matches) == 1, f"Expected 1 match but got {len(linear_matches)}"
+            assert (
+                len(linear_matches) == 1
+            ), f"Expected 1 match but got {len(linear_matches)}"
             bitwidth = int(linear_matches[0][0])
             groupsize = int(linear_matches[0][1])
-            from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
+            from torchao.experimental.quant_api import (
+                Int8DynActIntxWeightLinearQuantizer,
+            )
+
             model = Int8DynActIntxWeightLinearQuantizer(
                 device="cpu",
                 precision=torch_dtype,
                 groupsize=groupsize,
                 bitwidth=bitwidth,
                 has_weight_zeros=False,
             ).quantize(model)
-
+
         embedding_matches = re.findall(embedding_pattern, qmode)
         if embedding_matches:
-            assert len(embedding_matches) == 1, f"Expected 1 match but got {len(embedding_matches)}"
+            assert (
+                len(embedding_matches) == 1
+            ), f"Expected 1 match but got {len(embedding_matches)}"
             bitwidth = int(embedding_matches[0][0])
             groupsize = int(embedding_matches[0][1])
             from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
+
             model = IntxWeightEmbeddingQuantizer(
                 device="cpu",
                 precision=torch_dtype,
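
A hedged example of the parsing step in this hunk: assuming linear_pattern and embedding_pattern are the lin8da/emb regexes from export_llama_lib.py (an assumption; the diff does not show their definitions), a qmode such as torchao:emb.4,32&lin8da.8,128 yields exactly one match per pattern, from which bitwidth and groupsize are read.

import re

# Assumed pattern definitions; the actual ones live in quantize.py and are
# not shown in this diff.
linear_pattern = r"lin8da.(\d+),(\d+)"
embedding_pattern = r"emb.(\d+),(\d+)"

qmode = "torchao:emb.4,32&lin8da.8,128"

linear_matches = re.findall(linear_pattern, qmode)
if linear_matches:
    assert len(linear_matches) == 1, f"Expected 1 match but got {len(linear_matches)}"
    lin_bitwidth = int(linear_matches[0][0])   # 8
    lin_groupsize = int(linear_matches[0][1])  # 128

embedding_matches = re.findall(embedding_pattern, qmode)
if embedding_matches:
    assert len(embedding_matches) == 1, f"Expected 1 match but got {len(embedding_matches)}"
    emb_bitwidth = int(embedding_matches[0][0])   # 4
    emb_groupsize = int(embedding_matches[0][1])  # 32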
