
Commit 4565a0b
Commit message: up
1 parent 84809fa
File tree: 1 file changed (+11 −31 lines)

examples/models/llama/source_transformation/quantize.py

Lines changed: 11 additions & 31 deletions
@@ -16,8 +16,6 @@
 
 from executorch.extension.llm.export.builder import DType
 
-from sentencepiece import SentencePieceProcessor
-
 from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import (
@@ -516,56 +514,37 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 ############################ Source Transform Start #######################
 
 
-def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
+def get_quant_embedding_transform(
+    args, use_shared_embedding: bool = False, dtype_override: Optional[DType] = None
+):
     use_torchao = args.embedding_quantize.startswith("torchao:")
     if use_torchao:
         quant_args = args.embedding_quantize.split(":")[1].split(",")
     else:
         quant_args = args.embedding_quantize.split(",")
+    assert len(quant_args) in [
+        2,
+        3,
+    ], f"Expected 2 or 3 embedding quant_args, but got: {quant_args}"
 
     bitwidth = int(quant_args[0])
     group_size = quant_args[1]
     if group_size in ["none", "None", "0"]:
         group_size = 0
     group_size = int(group_size)
-    is_symmetric = bool(quant_args[2]) if len(quant_args) > 2 else True
+    is_symmetric = (
+        bool(quant_args[2].lower() == "true") if len(quant_args) > 2 else True
+    )
 
     weight_dtype = getattr(torch, f"int{bitwidth}")
     granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
     mapping_type = MappingType.SYMMETRIC if is_symmetric else MappingType.ASYMMETRIC
 
     if use_torchao:
-        def get_quant_embedding_transform(
-            embedding_quantize: str,
-            use_shared_embedding: bool = False,
-            dtype_override: Optional[DType] = None,
-        ):
-            if embedding_quantize.startswith("torchao:"):
         from torchao.experimental.quant_api import (
             EmbeddingQuantizer,
             SharedEmbeddingQuantizer,
         )
-        from torchao.quantization.granularity import PerAxis, PerGroup
-        from torchao.quantization.quant_api import MappingType
-
-        quant_args = embedding_quantize.split(":")[1].split(",")
-        if len(quant_args) == 2:
-            bitwidth, group_size = quant_args
-            is_asymmetric = True
-        else:
-            bitwidth, group_size, is_asymmetric = quant_args
-
-        if group_size in ["none", "None", "0"]:
-            group_size = 0
-
-        group_size = int(group_size)
-        bitwidth = int(bitwidth)
-        is_asymmetric = bool(is_asymmetric)
-        weight_dtype = getattr(torch, f"int{bitwidth}")
-        granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
-        mapping_type = (
-            MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC
-        )
 
         def _torchao_embedding_quantizer(model):
             with torch.no_grad():
@@ -599,6 +578,7 @@ def _quantize_embedding(model):
                 granularity=granularity,
                 mapping_type=mapping_type,
             ),
+            lambda m, fqn: isinstance(m, nn.Embedding),
         )
         return model
 
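Note on the second hunk: the reworked get_quant_embedding_transform parses the embedding_quantize string in one place for both the torchao and non-torchao paths. The sketch below mirrors that parsing in a standalone form; the helper name parse_embedding_quantize and the example strings are illustrative only, while the accepted format (bitwidth,group_size[,is_symmetric], optionally prefixed with "torchao:") comes from the diff above. In the actual function these values are then mapped onto torchao's PerAxis/PerGroup granularity and MappingType, as shown in the unchanged context lines.

# Hypothetical standalone helper mirroring the parsing in the reworked
# get_quant_embedding_transform; it is not part of the commit itself.
def parse_embedding_quantize(embedding_quantize: str):
    use_torchao = embedding_quantize.startswith("torchao:")
    if use_torchao:
        quant_args = embedding_quantize.split(":")[1].split(",")
    else:
        quant_args = embedding_quantize.split(",")
    assert len(quant_args) in [
        2,
        3,
    ], f"Expected 2 or 3 embedding quant_args, but got: {quant_args}"

    bitwidth = int(quant_args[0])
    group_size = quant_args[1]
    if group_size in ["none", "None", "0"]:
        group_size = 0
    group_size = int(group_size)
    # Optional third field: case-insensitive "true" selects symmetric mapping.
    is_symmetric = quant_args[2].lower() == "true" if len(quant_args) > 2 else True
    return use_torchao, bitwidth, group_size, is_symmetric


print(parse_embedding_quantize("4,32"))               # (False, 4, 32, True)
print(parse_embedding_quantize("torchao:4,0,false"))  # (True, 4, 0, False)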
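The last hunk adds a filter predicate to the quantization call inside _quantize_embedding; in torchao this is the filter_fn passed to quantize_, which receives each (module, fully-qualified name) pair and decides whether that module is quantized, so here the embedding config is applied to nn.Embedding modules only. A minimal sketch of how such a predicate selects modules follows; the toy model is hypothetical, and only the lambda m, fqn: isinstance(m, nn.Embedding) predicate comes from the diff.

import torch.nn as nn

# Hypothetical toy model; only the filter predicate itself comes from the diff.
model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 4))

# The predicate added in the hunk: called with (module, fully-qualified name),
# it returns True only for embedding modules, so other layers are left alone.
filter_fn = lambda m, fqn: isinstance(m, nn.Embedding)

selected = [fqn for fqn, m in model.named_modules() if filter_fn(m, fqn)]
print(selected)  # ['0'] -- only the nn.Embedding submodule is selected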