Commit 3cc1d26

Commit message: up
1 parent 9ebd761 commit 3cc1d26

File tree: 1 file changed, +27 -25 lines changed

examples/models/llama/source_transformation/quantize.py

Lines changed: 27 additions & 25 deletions
@@ -94,40 +94,42 @@ def quantize(  # noqa C901
     embedding_pattern = r"emb.(\d+),(\d+)"
     linear_pattern = r"lin8da.(\d+),(\d+)"
 
-    linear_matches = re.findall(linear_pattern, qmode)
-    if linear_matches:
+    matches = re.findall(linear_pattern, qmode)
+    if matches:
         assert (
-            len(linear_matches) == 1
-        ), f"Expected 1 match but got {len(linear_matches)}"
-        bitwidth = int(linear_matches[0][0])
-        groupsize = int(linear_matches[0][1])
+            len(matches) == 1
+        ), f"Expected 1 match for linear_pattern but got {len(matches)}"
+        bitwidth = int(matches[0][0])
+        groupsize = int(matches[0][1])
         from torchao.experimental.quant_api import (
             Int8DynActIntxWeightLinearQuantizer,
         )
 
-        model = Int8DynActIntxWeightLinearQuantizer(
-            device="cpu",
-            precision=torch_dtype,
-            groupsize=groupsize,
-            bitwidth=bitwidth,
-            has_weight_zeros=False,
-        ).quantize(model)
+        with torch.no_grad():
+            model = Int8DynActIntxWeightLinearQuantizer(
+                device="cpu",
+                precision=torch_dtype,
+                groupsize=groupsize,
+                bitwidth=bitwidth,
+                has_weight_zeros=False,
+            ).quantize(model)
 
-    embedding_matches = re.findall(embedding_pattern, qmode)
-    if embedding_matches:
+    matches = re.findall(embedding_pattern, qmode)
+    if matches:
         assert (
-            len(embedding_matches) == 1
-        ), f"Expected 1 match but got {len(embedding_matches)}"
-        bitwidth = int(embedding_matches[0][0])
-        groupsize = int(embedding_matches[0][1])
+            len(matches) == 1
+        ), f"Expected 1 match for embedding_pattern but got {len(matches)}"
+        bitwidth = int(matches[0][0])
+        groupsize = int(matches[0][1])
         from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
 
-        model = IntxWeightEmbeddingQuantizer(
-            device="cpu",
-            precision=torch_dtype,
-            bitwidth=bitwidth,
-            groupsize=groupsize,
-        ).quantize(model)
+        with torch.no_grad():
+            model = IntxWeightEmbeddingQuantizer(
+                device="cpu",
+                precision=torch_dtype,
+                bitwidth=bitwidth,
+                groupsize=groupsize,
+            ).quantize(model)
 
     if verbose:
         print("quantized model:", model)

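The functional change in this commit is wrapping the quantize(model) calls in torch.no_grad(). A minimal sketch of that pattern, using a plain nn.Linear as a stand-in for the torchao quantizers, assuming the point is that the in-place weight rewrite should not be tracked by autograd:

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)

    # Stand-in for Int8DynActIntxWeightLinearQuantizer(...).quantize(model):
    # rewrite the weights under no_grad so autograd does not record the mutation.
    with torch.no_grad():
        model.weight.copy_(torch.round(model.weight * 16) / 16)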