4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/replace_inf_buffer.py
@@ -18,8 +18,8 @@ def __init__(self):
     def call(self, graph_module: torch.fx.GraphModule):
         for buf_name, tensor in graph_module.named_buffers():
             if tensor.is_floating_point():
-                tensor[tensor == float("inf")] = 255
-                tensor[tensor == float("-inf")] = -255
+                tensor[tensor == float("inf")] = 65535
+                tensor[tensor == float("-inf")] = -65535
                 setattr(graph_module, buf_name, tensor)

         graph_module.recompile()
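The sentinel used for ±inf buffers moves from ±255 to ±65535, the maximum of an unsigned 16-bit integer, which lines up with the 16-bit activation configs touched elsewhere in this PR; presumably the larger magnitude keeps -inf attention-mask entries strongly negative after 16-bit quantization. A minimal standalone sketch of the substitution this pass performs on a floating-point buffer (toy tensor, not the actual pass):

import torch

mask = torch.zeros(1, 4)
mask[0, 2:] = float("-inf")           # e.g. a causal attention mask
mask[mask == float("inf")] = 65535    # same rewrite the pass applies per buffer
mask[mask == float("-inf")] = -65535
print(mask)                           # tensor([[0., 0., -65535., -65535.]])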
5 changes: 3 additions & 2 deletions backends/qualcomm/quantizer/quantizer.py
@@ -19,6 +19,7 @@
 )

 from torch._ops import OpOverload
+from torch.ao.quantization.observer import MinMaxObserver
 from torch.ao.quantization.quantizer import Quantizer
 from torch.fx import GraphModule

@@ -75,7 +76,7 @@ class QuantDtype(IntEnum):
     ),
     (QuantDtype.use_16a4w, False): (
         get_16a4w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(torch.uint16, "int4"),
+        get_ptq_per_channel_quant_config(torch.uint16, "int4", act_observer=MinMaxObserver),
     ),
     (QuantDtype.use_8a8w, False): (
         get_8a8w_qnn_ptq_config,
@@ -84,7 +85,7 @@ class QuantDtype(IntEnum):
     # QAT,
     (QuantDtype.use_16a4w, True): (
         get_16a4w_qnn_qat_config,
-        get_qat_per_channel_quant_config(torch.uint16, "int4"),
+        get_qat_per_channel_quant_config(torch.uint16, "int4", act_observer=MinMaxObserver),
     ),
     (QuantDtype.use_8a8w, True): (
         get_8a8w_qnn_qat_config,
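The 16a4w PTQ and QAT table entries now pin the activation observer to MinMaxObserver. For reference, a hedged standalone sketch of what that observer does — it tracks the global min/max seen during calibration and derives qparams from that range (the dtype shown here is illustrative, not the config produced by get_ptq_per_channel_quant_config):

import torch
from torch.ao.quantization.observer import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8)   # records the running min/max it sees
obs(torch.randn(8, 16))                    # first calibration batch
obs(torch.randn(8, 16) * 4.0)              # a wider batch expands the tracked range
scale, zero_point = obs.calculate_qparams()
print(scale.item(), zero_point.item())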
2 changes: 1 addition & 1 deletion examples/models/llama/export_llama_lib.py
@@ -574,7 +574,7 @@ def get_quantizer_and_quant_params(args):
     if args.qnn and args.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            args.pt2e_quantize, args.quantization_mode
+            args.pt2e_quantize, args.quantization_mode, is_qat=True
        )
         quantizers.append(qnn_quantizer)
     if args.coreml and args.pt2e_quantize:
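This draft hardwires is_qat=True on the QNN path, pairing with the (QuantDtype, QAT flag) keyed table in quantizer.py above. A purely illustrative sketch of that kind of selection (pick_quant_config and config_table are hypothetical names, not functions from the repo):

def pick_quant_config(config_table, quant_dtype, is_qat):
    # config_table maps (QuantDtype, is_qat) -> (act-config factory, per-channel
    # weight config), mirroring the table above; is_qat=True selects the QAT rows.
    act_config_factory, per_channel_weight_config = config_table[(quant_dtype, is_qat)]
    return act_config_factory(), per_channel_weight_config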
50 changes: 44 additions & 6 deletions extension/llm/export/builder.py
@@ -31,7 +31,12 @@

 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.fake_quantize import (
+    disable_fake_quant,
+    disable_observer,
+    enable_fake_quant,
+)
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.export import export_for_training
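The new imports pull in the fake-quant toggles alongside prepare_qat_pt2e. A hedged sketch of what those toggles do, shown on a single FakeQuantize module rather than the builder's prepared graph:

import torch
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    disable_fake_quant,
    disable_observer,
    enable_fake_quant,
)

fq = FakeQuantize()            # observer + simulated quantization in one module
fq.apply(disable_fake_quant)   # calibration mode: record ranges, no quant noise
_ = fq(torch.randn(4))
fq.apply(disable_observer)     # freeze the collected range
fq.apply(enable_fake_quant)    # now inject quantization error with frozen qparams
_ = fq(torch.randn(4))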
@@ -305,7 +310,7 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
             assert (
                 self.pre_autograd_graph_module is not None
             ), "Please run export() first"
-            m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+            m = prepare_qat_pt2e(self.pre_autograd_graph_module, composed_quantizer)
             logging.info(
                 f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
             )
@@ -329,11 +334,21 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
                     tokenizer_path=self.tokenizer_path,
                 )
             else:
-                logging.info(
-                    "No calibration provided, using dummy input to calibrate..."
-                )
-                m(*self.example_inputs)
+                m.apply(disable_fake_quant)
+                logging.info("Calibrating model")
+                generate(self.tokenizer_path, m)
+
+                logging.info("Testing with fake-quant enabled")
+                m.apply(disable_observer)
+                m.apply(enable_fake_quant)
+                generate(self.tokenizer_path, m)
+                logging.info(f"Model after Calibration and Testing: {m}")
+                m.apply(disable_fake_quant)
+
             m = convert_pt2e(m)
+            logging.info("Running model after convert step")
+            generate(self.tokenizer_path, m)
+            logging.info(f"Model after convert step: {m}")
             DuplicateDynamicQuantChainPass()(m)
             self.pre_autograd_graph_module = m
             return self
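Taken together with the prepare_qat_pt2e switch above, the calibration branch now follows the standard PT2E QAT recipe: observe with fake-quant off, then freeze observers and re-run with fake-quant on, then convert. A hedged, self-contained sketch of that flow on a toy model (XNNPACKQuantizer stands in for the QNN quantizer purely for illustration):

import torch
from torch.ao.quantization.fake_quantize import (
    disable_fake_quant,
    disable_observer,
    enable_fake_quant,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.export import export_for_training

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
example_inputs = (torch.randn(1, 8),)

exported = export_for_training(model, example_inputs).module()
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
m = prepare_qat_pt2e(exported, quantizer)

m.apply(disable_fake_quant)   # calibration: observers on, quant noise off
m(*example_inputs)
m.apply(disable_observer)     # freeze the collected ranges
m.apply(enable_fake_quant)    # sanity-check with simulated quantization
m(*example_inputs)
m = convert_pt2e(m)           # lower to the reference quantized graph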
@@ -438,3 +453,26 @@ def get_saved_pte_filename(self) -> Optional[str]:
         Return the filename of the most recently saved .pte file. Return None if the model is not saved.
         """
         return self._saved_pte_filename
+
+
+def generate(tokenizer_path, m):
+    tokenizer = get_tokenizer(tokenizer_path)
+    calib_str = "Once upon a time, there was a little girl named Alice. She"
+    tokens = tokenizer.encode(calib_str, bos=False, eos=False)
+    logging.info("Running model with the prompt: " + calib_str)
+    logging.info("Tokens: " + str(tokens))
+    generations = []
+    pred = 0
+    prompt_len = len(tokens)
+    # Prefill: feed the prompt one token at a time with its position index.
+    for i, token in enumerate(tokens):
+        outputs = m(torch.tensor([[token]], dtype=torch.long), torch.tensor([i], dtype=torch.long))
+        pred = outputs[0].argmax().item()
+        generations.append(pred)
+
+    # Decode: greedily extend the sequence by 30 tokens.
+    for i in range(prompt_len, prompt_len + 30):
+        outputs = m(torch.tensor([[pred]], dtype=torch.long), torch.tensor([i], dtype=torch.long))
+        pred = outputs[0].argmax().item()
+        generations.append(pred)
+
+    logging.info("Generated tokens: " + str(generations))
+    logging.info("Generated string: " + tokenizer.decode(generations))
+    return generations