16 | 16 |
17 | 17 | from executorch.extension.llm.export.builder import DType |
18 | 18 |
19 | | -from sentencepiece import SentencePieceProcessor |
20 | | - |
21 | 19 | from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout |
22 | 20 | from torchao.quantization.granularity import PerAxis, PerGroup |
23 | 21 | from torchao.quantization.quant_api import ( |
@@ -516,56 +514,37 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: |
516 | 514 | ############################ Source Transform Start ####################### |
517 | 515 |
518 | 516 |
519 | | -def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None): |
| 517 | +def get_quant_embedding_transform( |
| 518 | + args, use_shared_embedding: bool = False, dtype_override: Optional[DType] = None |
| 519 | +): |
520 | 520 | use_torchao = args.embedding_quantize.startswith("torchao:") |
521 | 521 | if use_torchao: |
522 | 522 | quant_args = args.embedding_quantize.split(":")[1].split(",") |
523 | 523 | else: |
524 | 524 | quant_args = args.embedding_quantize.split(",") |
| 525 | + assert len(quant_args) in [ |
| 526 | + 2, |
| 527 | + 3, |
| 528 | + ], f"Expected 2 or 3 embedding quant_args, but got: {quant_args}" |
525 | 529 |
526 | 530 | bitwidth = int(quant_args[0]) |
527 | 531 | group_size = quant_args[1] |
528 | 532 | if group_size in ["none", "None", "0"]: |
529 | 533 | group_size = 0 |
530 | 534 | group_size = int(group_size) |
531 | | - is_symmetric = bool(quant_args[3]) if len(quant_args) > 2 else True |
| 535 | + is_symmetric = ( |
| 536 | + bool(quant_args[2].lower() == "true") if len(quant_args) > 2 else True |
| 537 | + ) |
532 | 538 |
533 | 539 | weight_dtype = getattr(torch, f"int{bitwidth}") |
534 | 540 | granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size) |
535 | 541 | mapping_type = MappingType.SYMMETRIC if is_symmetric else MappingType.ASYMMETRIC |
536 | 542 |
537 | 543 | if use_torchao: |
538 | | -def get_quant_embedding_transform( |
539 | | - embedding_quantize: str, |
540 | | - use_shared_embedding: bool = False, |
541 | | - dtype_override: Optional[DType] = None, |
542 | | -): |
543 | | - if embedding_quantize.startswith("torchao:"): |
544 | 544 | from torchao.experimental.quant_api import ( |
545 | 545 | EmbeddingQuantizer, |
546 | 546 | SharedEmbeddingQuantizer, |
547 | 547 | ) |
548 | | - from torchao.quantization.granularity import PerAxis, PerGroup |
549 | | - from torchao.quantization.quant_api import MappingType |
550 | | - |
551 | | - quant_args = embedding_quantize.split(":")[1].split(",") |
552 | | - if len(quant_args) == 2: |
553 | | - bitwidth, group_size = quant_args |
554 | | - is_asymmetric = True |
555 | | - else: |
556 | | - bitwidth, group_size, is_asymmetric = quant_args |
557 | | - |
558 | | - if group_size in ["none", "None", "0"]: |
559 | | - group_size = 0 |
560 | | - |
561 | | - group_size = int(group_size) |
562 | | - bitwidth = int(bitwidth) |
563 | | - is_asymmetric = bool(is_asymmetric) |
564 | | - weight_dtype = getattr(torch, f"int{bitwidth}") |
565 | | - granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size) |
566 | | - mapping_type = ( |
567 | | - MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC |
568 | | - ) |
569 | 548 |
570 | 549 | def _torchao_embedding_quantizer(model): |
571 | 550 | with torch.no_grad(): |
@@ -599,6 +578,7 @@ def _quantize_embedding(model): |
599 | 578 | granularity=granularity, |
600 | 579 | mapping_type=mapping_type, |
601 | 580 | ), |
| 581 | + lambda m, fqn: isinstance(m, nn.Embedding), |
602 | 582 | ) |
603 | 583 | return model |
604 | 584 |
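For context on how the refactored transform would be driven, here is a minimal usage sketch. It assumes `args` is the usual argparse-style namespace from the llama export flow, that `embedding_quantize` takes the form `"bitwidth,group_size[,is_symmetric]"` (optionally prefixed with `"torchao:"`), and that `get_quant_embedding_transform` returns a callable that quantizes a model's `nn.Embedding` modules and hands the model back, as the inner `_torchao_embedding_quantizer` / `_quantize_embedding` helpers above suggest. The toy model and argument values are invented for illustration.

```python
# Illustrative sketch only; the toy model and argument strings are assumptions.
from argparse import Namespace

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(4096, 64)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.tok_embeddings(tokens)


# Non-torchao path: "bitwidth,group_size" -> int4 weights, groups of 32.
# The quantization is restricted to nn.Embedding modules via the
# filter function added in the second hunk above.
args = Namespace(embedding_quantize="4,32")
transform = get_quant_embedding_transform(args)
model = transform(ToyModel())  # assumes the callable mutates the model and returns it

# torchao path: same layout after the "torchao:" prefix; the optional third
# field selects symmetric ("true") vs. asymmetric quantization.
args = Namespace(embedding_quantize="torchao:4,32,true")
transform = get_quant_embedding_transform(args, use_shared_embedding=False)
model = transform(ToyModel())
```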