From b48a57ab6624bf2aee3b6522a7b77243b9249be6 Mon Sep 17 00:00:00 2001 From: lucylq Date: Fri, 21 Nov 2025 16:46:57 -0800 Subject: [PATCH] decompose after export in export_llama --- examples/models/llama/source_transformation/quantize.py | 3 --- extension/llm/export/builder.py | 8 ++------ 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py index a9412d513c7..24fbd8c8111 100644 --- a/examples/models/llama/source_transformation/quantize.py +++ b/examples/models/llama/source_transformation/quantize.py @@ -194,9 +194,6 @@ def filter_fn(m, fqn): ), filter_fn=filter_fn, ) - - model = unwrap_tensor_subclass(model) - # TODO: deal with checkpoint / computation dtype decoupling. if verbose: diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index f8c556f351c..2bb70bff263 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -38,7 +38,6 @@ from torch.nn.attention import SDPBackend from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer -from torchao.utils import unwrap_tensor_subclass FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -203,11 +202,6 @@ def _get_edge_config(self) -> EdgeCompileConfig: return edge_config def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram: - if module is not None: - unwrap_tensor_subclass(module) - else: - unwrap_tensor_subclass(self.model) - dynamic_shape = self._get_dynamic_shape() # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up) @@ -226,6 +220,8 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram: dynamic_shapes=dynamic_shape, strict=True, ) + # Functionalize the graph, and decompose subclasses from torchao quantize. + exported_module = exported_module.run_decompositions({}) return exported_module def export(self) -> "LLMEdgeManager":