diff --git a/backends/qualcomm/_passes/replace_inf_buffer.py b/backends/qualcomm/_passes/replace_inf_buffer.py
index 776bc9beeba..123cddc7007 100644
--- a/backends/qualcomm/_passes/replace_inf_buffer.py
+++ b/backends/qualcomm/_passes/replace_inf_buffer.py
@@ -18,8 +18,8 @@ def __init__(self):
 
     def call(self, graph_module: torch.fx.GraphModule):
         for buf_name, tensor in graph_module.named_buffers():
             if tensor.is_floating_point():
-                tensor[tensor == float("inf")] = 255
-                tensor[tensor == float("-inf")] = -255
+                tensor[tensor == float("inf")] = 65535
+                tensor[tensor == float("-inf")] = -65535
                 setattr(graph_module, buf_name, tensor)
         graph_module.recompile()
diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py
index da7b0174c02..52eed11fb4b 100644
--- a/backends/qualcomm/quantizer/quantizer.py
+++ b/backends/qualcomm/quantizer/quantizer.py
@@ -19,6 +19,7 @@
 )
 
 from torch._ops import OpOverload
+from torch.ao.quantization.observer import MinMaxObserver
 from torch.ao.quantization.quantizer import Quantizer
 from torch.fx import GraphModule
 
@@ -75,7 +76,7 @@ class QuantDtype(IntEnum):
     ),
     (QuantDtype.use_16a4w, False): (
         get_16a4w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(torch.uint16, "int4"),
+        get_ptq_per_channel_quant_config(torch.uint16, "int4", act_observer=MinMaxObserver),
     ),
     (QuantDtype.use_8a8w, False): (
         get_8a8w_qnn_ptq_config,
@@ -84,7 +85,7 @@ class QuantDtype(IntEnum):
     # QAT,
     (QuantDtype.use_16a4w, True): (
         get_16a4w_qnn_qat_config,
-        get_qat_per_channel_quant_config(torch.uint16, "int4"),
+        get_qat_per_channel_quant_config(torch.uint16, "int4", act_observer=MinMaxObserver),
     ),
     (QuantDtype.use_8a8w, True): (
         get_8a8w_qnn_qat_config,
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 7ebdf95418d..603b7f2f153 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -574,7 +574,7 @@ def get_quantizer_and_quant_params(args):
     if args.qnn and args.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            args.pt2e_quantize, args.quantization_mode
+            args.pt2e_quantize, args.quantization_mode, is_qat=True
        )
         quantizers.append(qnn_quantizer)
     if args.coreml and args.pt2e_quantize:
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index ebc7f02ee1a..e9181db94ba 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -31,7 +31,12 @@
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.fake_quantize import (
+    disable_fake_quant,
+    disable_observer,
+    enable_fake_quant,
+)
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.export import export_for_training
@@ -305,7 +310,7 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
             assert (
                 self.pre_autograd_graph_module is not None
             ), "Please run export() first"
-            m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+            m = prepare_qat_pt2e(self.pre_autograd_graph_module, composed_quantizer)
             logging.info(
                 f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
             )
@@ -329,11 +334,21 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
                     tokenizer_path=self.tokenizer_path,
                 )
             else:
-                logging.info(
-                    "No calibration provided, using dummy input to calibrate..."
-                )
-                m(*self.example_inputs)
+                m.apply(disable_fake_quant)
+                logging.info("Calibrating model")
+                generate(self.tokenizer_path, m)
+
+                logging.info("Testing with fake-quant enabled")
+                m.apply(disable_observer)
+                m.apply(enable_fake_quant)
+                generate(self.tokenizer_path, m)
+                logging.info(f"Model after Calibration and Testing: {m}")
+                m.apply(disable_fake_quant)
+
             m = convert_pt2e(m)
+            logging.info("Running model after convert step")
+            generate(self.tokenizer_path, m)
+            logging.info(f"Model after convert step: {m}")
             DuplicateDynamicQuantChainPass()(m)
             self.pre_autograd_graph_module = m
             return self
@@ -438,3 +453,26 @@ def get_saved_pte_filename(self) -> Optional[str]:
         Return the filename of the most recenet saved .pte file. Return None if the model is not saved.
         """
         return self._saved_pte_filename
+
+def generate(tokenizer_path, m):
+    tokenizer = get_tokenizer(tokenizer_path)
+    calib_str = "Once upon a time, there was a little girl named Alice. She"
+    tokens = tokenizer.encode(calib_str, bos=False, eos=False)
+    logging.info("Running model with the prompt: " + calib_str)
+    logging.info("Tokens: " + str(tokens))
+    generations = []
+    pred = 0
+    prompt_len = len(tokens)
+    for i, token in enumerate(tokens):
+        outputs = m(torch.tensor([[token]], dtype=torch.long), torch.tensor([i], dtype=torch.long))
+        pred = outputs[0].argmax().item()
+        generations.append(pred)
+
+    for i in range(prompt_len, prompt_len + 30):
+        outputs = m(torch.tensor([[pred]], dtype=torch.long), torch.tensor([i], dtype=torch.long))
+        pred = outputs[0].argmax().item()
+        generations.append(pred)
+
+    logging.info("Generated tokens: " + str(generations))
+    logging.info("Generated string: " + tokenizer.decode(generations))
+    return generations
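
For reference, below is a minimal, self-contained sketch of the prepare_qat_pt2e flow that builder.py switches to in this patch: observers collect float statistics while fake-quant is disabled, fake-quant is then enabled (with observers frozen) to preview quantized behavior, and convert_pt2e produces the final quantized graph. The toy module, XNNPACKQuantizer, and random example input are illustrative stand-ins only, not part of this change; the real export path uses the Llama module, the QnnQuantizer configured with is_qat=True, and the token-by-token generate() calibration shown in the diff above.

import torch
from torch.ao.quantization.fake_quantize import (
    disable_fake_quant,
    disable_observer,
    enable_fake_quant,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.export import export_for_training


class Toy(torch.nn.Module):  # stand-in for the real Llama module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 16),)
m = export_for_training(Toy(), example_inputs).module()

# Stand-in quantizer; the actual PR composes a QnnQuantizer with a QAT config.
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
m = prepare_qat_pt2e(m, quantizer)

# 1) Calibrate: fake-quant off so the observers see the float activation ranges.
m.apply(disable_fake_quant)
m(*example_inputs)

# 2) Preview: freeze observers, enable fake-quant, and run again to check
#    behavior under simulated quantization.
m.apply(disable_observer)
m.apply(enable_fake_quant)
m(*example_inputs)

# 3) Convert: lower the prepared graph to the actual quantized ops.
m = convert_pt2e(m)
m(*example_inputs)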