Commit 939fe29

Commit message: up
1 parent: 65d71ba

File tree: 6 files changed (+63, -70 lines)

.ci/docker/ci_commit_pins/torchao.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

examples/models/llama/export_llama_lib.py

Lines changed: 7 additions & 3 deletions
@@ -570,7 +570,7 @@ def get_quantizer_and_quant_params(args):
 
 def _qmode_type(value):
     choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
-    patterns = [r"torchao:8da{\d+}w"]
+    patterns = [r"torchao:8da(\d+)w"]
 
     if value in choices:
         return value
@@ -579,10 +579,12 @@ def _qmode_type(value):
         matches = re.findall(pattern, value)
         if len(matches) == 1:
             return value
+
     raise argparse.ArgumentTypeError(
-        f"Got qmode {value}, but expected one of {choices}, or one of the regex patterns {patterns}."
+        f"Got qmode {value}, but expected one of {choices}, or one of the regex patterns {patterns}."
     )
 
+
 def _validate_args(args):
     """
     TODO: Combine all the backends under --backend args
@@ -596,7 +598,9 @@ def _validate_args(args):
     if args.num_sharding > 0 and not args.qnn:
         raise ValueError("Model shard is only supported with qnn backend now.")
 
-    if args.quantization_mode.startswith("torchao:") or args.embedding_quantize.startswith("torchao:"):
+    if args.quantization_mode.startswith(
+        "torchao:"
+    ) or args.embedding_quantize.startswith("torchao:"):
         if args.enable_dynamic_shape:
             raise ValueError(
                 "Dynamic shape is not currently supported with torchao ops. Please use --disable_dynamic_shape."

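Context for the _qmode_type change above: Python's re module treats a "{" that does not begin a valid quantifier as a literal character, so the old pattern only matched strings of the form "torchao:8da{4}w", while the corrected "(\d+)" capture group matches the intended "torchao:8da4w" form. A minimal standalone sketch (not part of the commit) illustrating the difference:

import re

# Sketch only: compare the old and new qmode patterns on an example value.
old_pattern = r"torchao:8da{\d+}w"   # "{" and "}" end up matched literally
new_pattern = r"torchao:8da(\d+)w"   # capture group extracts the bit width

value = "torchao:8da4w"              # hypothetical qmode value
print(re.findall(old_pattern, value))  # [] -- the old pattern rejects this form
print(re.findall(new_pattern, value))  # ['4']
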
examples/models/llama/install_requirements.sh

Lines changed: 1 addition & 2 deletions
@@ -10,8 +10,7 @@
 pip install snakeviz sentencepiece
 
 # Install torchao.
-TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt)
-pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}"
+pip install "$(dirname "$0")/../../../third-party/ao"
 
 # Install lm-eval for Model Evaluation with lm-evalution-harness
 # Install tiktoken for tokenizer

examples/models/llama/source_transformation/quantize.py

Lines changed: 53 additions & 60 deletions
@@ -73,69 +73,24 @@ def quantize( # noqa C901
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
     elif qmode.startswith("torchao:"):
-        import glob
-        import os
-
-        libs = glob.glob(
-            os.path.abspath(
-                os.path.join(
-                    os.environ.get("CMAKE_INSTALL_PREFIX", ""),
-                    "lib/libtorchao_ops_aten.*",
-                )
-            )
-        )
-        assert (
-            len(libs) == 1
-        ), f"Expected 1 library but got {len(libs)}. If you installed the torchao ops in a non-standard location, please set CMAKE_INSTALL_PREFIX correctly."
-        logging.info(f"Loading custom ops library: {libs[0]}")
-        torch.ops.load_library(libs[0])
-
-        logging.warning(
-            "When qmode is torchao, the groupsize is obtained from the qmode string with regex parse; blocksize is ignored."
-        )
-        embedding_pattern = r"emb.(\d+),(\d+)"
-        linear_pattern = r"lin8da.(\d+),(\d+)"
-
-        matches = re.findall(linear_pattern, qmode)
-        if matches:
-            assert (
-                len(matches) == 1
-            ), f"Expected 1 match for linear_pattern but got {len(matches)}"
-            bitwidth = int(matches[0][0])
-            groupsize = int(matches[0][1])
-            from torchao.experimental.quant_api import (
-                Int8DynActIntxWeightLinearQuantizer,
-            )
-
-            with torch.no_grad():
-                model = Int8DynActIntxWeightLinearQuantizer(
-                    device="cpu",
-                    precision=torch_dtype,
-                    groupsize=groupsize,
-                    bitwidth=bitwidth,
-                    has_weight_zeros=False,
-                ).quantize(model)
-
-        matches = re.findall(embedding_pattern, qmode)
-        if matches:
-            assert (
-                len(matches) == 1
-            ), f"Expected 1 match for embedding_pattern but got {len(matches)}"
-            bitwidth = int(matches[0][0])
-            groupsize = int(matches[0][1])
-            from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
-
-            with torch.no_grad():
-                model = IntxWeightEmbeddingQuantizer(
-                    device="cpu",
-                    precision=torch_dtype,
-                    bitwidth=bitwidth,
-                    groupsize=groupsize,
-                ).quantize(model)
+        pattern = r"torchao:8da(\d+)w"
+        matches = re.findall(pattern, qmode)
+        assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
+        bitwidth = int(matches[0][0])
+        _load_torchao_ops_aten()
+        from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
+
+        with torch.no_grad():
+            model = Int8DynActIntxWeightLinearQuantizer(
+                device="cpu",
+                precision=torch.float32,
+                groupsize=group_size,
+                bitwidth=bitwidth,
+                has_weight_zeros=False,
+            ).quantize(model)
 
         if verbose:
             print("quantized model:", model)
-
         return model
     elif qmode == "8da4w":
         # Check for required args
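
With this change the qmode string carries only the weight bit width ("torchao:8daNw"); the group size now comes from the separate group_size argument instead of being parsed out of qmode. A rough parsing sketch (hypothetical value, not code from the commit):

import re

# Sketch only: extract the bit width from a torchao qmode string.
qmode = "torchao:8da3w"  # hypothetical example
matches = re.findall(r"torchao:8da(\d+)w", qmode)
assert len(matches) == 1, f"Expected 1 match but got {len(matches)}"
bitwidth = int(matches[0])  # 3; the group size is supplied separately
print(bitwidth)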
@@ -760,6 +715,25 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 
 
 def get_quant_embedding_transform(args):
+    if args.embedding_quantize.startswith("torchao:"):
+        bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
+        group_size = int(group_size)
+        bitwidth = int(bitwidth)
+        _load_torchao_ops_aten()
+        from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
+
+        def _torchao_embedding_quantizer(model):
+            with torch.no_grad():
+                model = IntxWeightEmbeddingQuantizer(
+                    device="cpu",
+                    precision=torch.float32,
+                    bitwidth=bitwidth,
+                    groupsize=group_size,
+                ).quantize(model)
+            return model
+
+        return _torchao_embedding_quantizer
+
     bitwidth, group_size = args.embedding_quantize.split(",")
     if group_size == "none" or group_size == "None" or group_size == "0":
         group_size = None
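
The new branch above assumes an embedding_quantize value of the form "torchao:bitwidth,groupsize". A small standalone sketch (hypothetical value, not from the commit) of that parse:

# Sketch only: split a "torchao:bitwidth,groupsize" embedding_quantize value.
embedding_quantize = "torchao:4,32"  # hypothetical example
bitwidth, group_size = embedding_quantize.split(":")[1].split(",")
bitwidth, group_size = int(bitwidth), int(group_size)
print(bitwidth, group_size)  # 4 32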
@@ -801,4 +775,23 @@ def get_quant_weight_transform(args, dtype_override, verbose):
     )
 
 
+def _load_torchao_ops_aten():
+    import glob
+    import os
+
+    libs = glob.glob(
+        os.path.abspath(
+            os.path.join(
+                os.environ.get("CMAKE_INSTALL_PREFIX", ""),
+                "lib/libtorchao_ops_aten.*",
+            )
+        )
+    )
+    assert (
+        len(libs) == 1
+    ), f"Expected 1 library but got {len(libs)}. If you installed the torchao ops in a non-standard location, please set CMAKE_INSTALL_PREFIX correctly."
+    logging.info(f"Loading custom ops library: {libs[0]}")
+    torch.ops.load_library(libs[0])
+
+
 ############################ Source Transform End #######################
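
The _load_torchao_ops_aten helper added above locates the torchao ATen ops shared library under CMAKE_INSTALL_PREFIX and registers it via torch.ops.load_library. A standalone sketch of just the lookup step (no torchao install required to run; it simply prints an empty list if the library has not been built):

import glob
import os

# Sketch only: show which path the helper searches for the ops library.
prefix = os.environ.get("CMAKE_INSTALL_PREFIX", "")
pattern = os.path.abspath(os.path.join(prefix, "lib/libtorchao_ops_aten.*"))
print(glob.glob(pattern))  # expected: exactly one match (e.g. a .so or .dylib)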

examples/models/llama3_2_vision/install_requirements.sh

Lines changed: 1 addition & 2 deletions
@@ -9,5 +9,4 @@
 pip install --pre torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu --no-cache-dir
 
 # Install torchao.
-TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt)
-pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}"
+pip install "$(dirname "$0")/../../../third-party/ao"

examples/models/phi-3-mini-lora/install_requirements.sh

Lines changed: 1 addition & 2 deletions
@@ -10,5 +10,4 @@ pip install torchtune
 pip install tiktoken
 
 # Install torchao.
-TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt)
-pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}"
+pip install "$(dirname "$0")/../../../third-party/ao"
