Commit 404aacf

Authored by pytorchbot, metascroy, and GregoryComer

Embedding quant unification (#14706)

Differential Revision: D83318725

Co-authored-by: Scott Roy <[email protected]>
Co-authored-by: Gregory Comer <[email protected]>

1 parent f06ad29 · commit 404aacf

File tree: 6 files changed, +277 −126 lines

.ci/scripts/test_llama.sh
Lines changed: 1 addition & 1 deletion

@@ -236,7 +236,7 @@ if [[ "${CUSTOM}" == "ON" ]]; then
   EXPORT_ARGS="${EXPORT_ARGS} model.use_sdpa_with_kv_cache=true"
 fi
 if [[ "${QE}" == "ON" ]]; then
-  EXPORT_ARGS="${EXPORT_ARGS} quantization.embedding_quantize=\"8,1024\""
+  EXPORT_ARGS="${EXPORT_ARGS} quantization.embedding_quantize=\"8,768\""
 fi
 if [[ "${MPS}" == "ON" ]]; then
   EXPORT_ARGS="${EXPORT_ARGS} backend.mps.enabled=true model.enable_dynamic_shape=false debug.verbose=true"

examples/apple/coreml/llama/export.py
Lines changed: 1 addition & 4 deletions

@@ -23,7 +23,6 @@
 from executorch.exir.backend.utils import format_delegated_graph
 from executorch.exir.capture._config import ExecutorchBackendConfig
 from executorch.exir.passes import MemoryPlanningPass
-from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.extension.export_util.utils import save_pte_program
 
@@ -211,9 +210,7 @@ def main() -> None:
     executorch_program = edge_manager.to_executorch(
         ExecutorchBackendConfig(
             extract_delegate_segments=True,
-            passes=[
-                QuantFusionPass(),
-            ],
+            do_quant_fusion_and_const_prop=True,
             memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
             sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
         )
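Note: the explicit QuantFusionPass entry is subsumed by a single backend-config flag whose name indicates it also constant-propagates the fused quantized weights. A minimal sketch of the resulting config, using only names that appear in this diff:

from executorch.exir.capture._config import ExecutorchBackendConfig

# One flag now covers what passes=[QuantFusionPass()] did, plus constant
# propagation of the fused quantized weights.
config = ExecutorchBackendConfig(
    extract_delegate_segments=True,
    do_quant_fusion_and_const_prop=True,
)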

examples/models/llama/source_transformation/quantize.py
Lines changed: 22 additions & 21 deletions

@@ -595,19 +595,16 @@ def __init__(
 
     @torch.no_grad()
     def create_quantized_state_dict(self, packed=False) -> Dict:
+        from torchao.quantization.granularity import PerAxis, PerGroup
+        from torchao.quantization.quant_api import (
+            IntxWeightOnlyConfig,
+            MappingType,
+            quantize_,
+        )
+
         cur_state_dict = self.mod.state_dict()
 
-        if self.bitwidth == 2:
-            range_min = -2
-            range_max = 1
-        elif self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
+        assert self.bitwidth in [2, 4, 8], f"Unsupported bitwidth {self.bitwidth}"
 
         for fqn, mod in self.mod.named_modules():
             if isinstance(mod, nn.Embedding):
@@ -619,18 +616,22 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                 print(
                     f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
-                weight, scales, _ = dynamically_quantize_per_channel(
-                    (
-                        mod.weight.to(dtype=self.precision)
-                        if self.precision
-                        else mod.weight
+                tmp_model = nn.Embedding(mod.weight.shape[0], mod.weight.shape[1])
+                if self.precision:
+                    tmp_model = tmp_model.to(dtype=self.precision)
+                tmp_model.weight = nn.Parameter(mod.weight)
+                config = IntxWeightOnlyConfig(
+                    weight_dtype=getattr(torch, f"int{self.bitwidth}"),
+                    granularity=(
+                        PerAxis(0)
+                        if (self.group_size is None or self.group_size == 0)
+                        else PerGroup(self.group_size)
                     ),
-                    range_min,
-                    range_max,
-                    torch.int8,
-                    self.group_size,
-                    scales_dtype=mod.weight.dtype,
+                    mapping_type=MappingType.SYMMETRIC,
                 )
+                quantize_(tmp_model, config, lambda m, fqn: isinstance(m, nn.Embedding))
+                weight = tmp_model.weight.qdata  # pyre-ignore[16]
+                scales = tmp_model.weight.scale  # pyre-ignore[16]
 
                 if packed:
                     if self.bitwidth == 2:
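Note: the hand-rolled quantization ranges and dynamically_quantize_per_channel call are replaced by torchao's quantize_ API; the integer weights and scales are then read back off the quantized tensor. A standalone sketch of that flow under the torchao APIs imported above (the embedding shape and group size are illustrative, and the qdata/scale attribute names follow the diff):

import torch
import torch.nn as nn
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    IntxWeightOnlyConfig,
    MappingType,
    quantize_,
)

emb = nn.Embedding(1000, 768)  # illustrative vocab size and embedding dim
config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,             # 2-, 4-, or 8-bit, per the assert above
    granularity=PerGroup(32),            # one scale per 32-element group
    mapping_type=MappingType.SYMMETRIC,  # symmetric (zero-point-free) mapping
)
quantize_(emb, config, lambda m, fqn: isinstance(m, nn.Embedding))

int_weight = emb.weight.qdata  # quantized integer values
scales = emb.weight.scale      # per-group dequantization scales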
