Skip to content

Commit 22a9e78

Browse files
lucylq authored and facebook-github-bot committed
Decompose after export in export_llama (#15951)
Summary: `unwrap_tensor_subclass` was not unwrapping nested lora linears. This meant qdata/scale/zero were bundled together in the subclass, and separated when running decompositions inside to_edge_transform_and_lower. This happens after nodes are tagged, meaning that the scales were not tagged, and remained in the PTE file after the rest of the weights were moved to a PTD file. It's recommended to move away from `unwrap_tensor_subclass` and rely on export + decomps. This PR adds a decomp after exporting in export_llama, and removes cases of `unwrap_tensor_subclass`. TODO: remove all cases of `unwrap_tensor_subclass` in ET. Test Plan: Add a check that quantized weights are in the PTD file (not the PTE file) after quantization. This is a simple check; nested linears seem to be the real issue that decomposing resolves. TODO to add a test for that (probably an e2e test with stories in a subsequent PR) ``` python -m unittest executorch.backends.xnnpack.test.passes.test_propagate_custom_meta_pass ``` Reviewed By: metascroy Differential Revision: D87826410 Pulled By: lucylq
1 parent 7fa93a7 commit 22a9e78

File tree

3 files changed

+34
-13
lines changed

3 files changed

+34
-13
lines changed

backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,15 @@
2020
)
2121
from executorch.backends.xnnpack.test.tester import Quantize as XNNPackQuantize, Tester
2222
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
23+
24+
from executorch.exir import ExecutorchProgramManager
25+
from executorch.exir._serialize import _deserialize_pte_binary
2326
from executorch.exir.passes.external_constants_pass import (
2427
delegate_external_constants_pass_unlifted,
2528
)
29+
from executorch.extension.flat_tensor.serialize.serialize import (
30+
_deserialize_to_flat_tensor,
31+
)
2632

2733
from torchao.quantization.granularity import PerGroup
2834
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
@@ -87,7 +93,7 @@ def _test_linear(
8793
self,
8894
partitioner: XnnpackPartitioner,
8995
quantization_stage: Union[BaseStages.Quantize, BaseStages.Quantize_],
90-
):
96+
) -> ExecutorchProgramManager:
9197
eager_model = self.ModuleLinear(
9298
in_size=1,
9399
input_channels=32,
@@ -106,7 +112,7 @@ def _test_linear(
106112
exec = tester.get_artifact()
107113
program_buffer = exec.buffer
108114
self.assertEqual(len(exec._tensor_data), 1)
109-
data_buffer = bytes(exec._tensor_data.pop("model"))
115+
data_buffer = bytes(exec._tensor_data["model"])
110116
self.assertTrue(len(data_buffer) > 200)
111117
from executorch.extension.pybindings import portable_lib as runtime
112118

@@ -122,6 +128,8 @@ def _test_linear(
122128
# test_inputs
123129
# )
124130

131+
return exec
132+
125133
def test_quantize_(self):
126134
# Quantize with torchao quantize_ API.
127135
DynamicallyQuantizedPartitioner = XnnpackPartitioner(
@@ -132,9 +140,16 @@ def test_quantize_(self):
132140
weight_dtype=torch.int4,
133141
weight_granularity=PerGroup(32),
134142
)
135-
self._test_linear(
143+
exec = self._test_linear(
136144
DynamicallyQuantizedPartitioner, BaseStages.Quantize_(config=linear_config)
137145
)
146+
# PTE file has no named data.
147+
pte_file = _deserialize_pte_binary(exec.buffer)
148+
self.assertEqual(pte_file.named_data, None)
149+
150+
# PTD file contains quantized weight and scale.
151+
ptd_file = _deserialize_to_flat_tensor(bytes(exec._tensor_data["model"]))
152+
self.assertEqual(len(ptd_file.named_data), 2)
138153

139154
def test_pt2e_quantize(self):
140155
# Quantize with pt2e quantize.
@@ -156,6 +171,15 @@ def test_pt2e_quantize(self):
156171
partitioner = XnnpackPartitioner(
157172
config_precisions=precision, per_op_mode=per_op_mode
158173
)
159-
self._test_linear(
174+
exec = self._test_linear(
160175
partitioner, XNNPackQuantize(quantization_config=quant_config)
161176
)
177+
# PTE file has no named data.
178+
pte_file = _deserialize_pte_binary(exec.buffer)
179+
self.assertEqual(pte_file.named_data, None)
180+
181+
# PTD file contains quantized weight, and potentially scale.
182+
ptd_file = _deserialize_to_flat_tensor(
183+
bytes(exec._tensor_data["model"])
184+
)
185+
self.assertTrue(len(ptd_file.named_data) >= 1)

examples/models/llama/source_transformation/quantize.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,6 @@ def filter_fn(m, fqn):
194194
),
195195
filter_fn=filter_fn,
196196
)
197-
198-
model = unwrap_tensor_subclass(model)
199-
200197
# TODO: deal with checkpoint / computation dtype decoupling.
201198

202199
if verbose:

extension/llm/export/builder.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from torch.nn.attention import SDPBackend
3939
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
4040
from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer
41-
from torchao.utils import unwrap_tensor_subclass
4241

4342
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
4443
logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -203,11 +202,6 @@ def _get_edge_config(self) -> EdgeCompileConfig:
203202
return edge_config
204203

205204
def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
206-
if module is not None:
207-
unwrap_tensor_subclass(module)
208-
else:
209-
unwrap_tensor_subclass(self.model)
210-
211205
dynamic_shape = self._get_dynamic_shape()
212206
# 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
213207
# 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -226,6 +220,12 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
226220
dynamic_shapes=dynamic_shape,
227221
strict=True,
228222
)
223+
# Functionalize the graph, and decompose subclasses from torchao quantize.
224+
from executorch.exir.tracer import _default_decomposition_table
225+
226+
exported_module = exported_module.run_decompositions(
227+
_default_decomposition_table()
228+
)
229229
return exported_module
230230

231231
def export(self) -> "LLMEdgeManager":

0 commit comments

Comments (0)