|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import unittest |
| 8 | + |
| 9 | +import torch |
| 10 | + |
| 11 | +from executorch.backends.xnnpack.partition.config.xnnpack_config import ( |
| 12 | + ConfigPrecisionType, |
| 13 | +) |
| 14 | +from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( |
| 15 | + XnnpackFloatingPointPartitioner, |
| 16 | + XnnpackPartitioner, |
| 17 | +) |
| 18 | +from executorch.exir import ( |
| 19 | + EdgeCompileConfig, |
| 20 | + ExecutorchBackendConfig, |
| 21 | + to_edge_transform_and_lower, |
| 22 | +) |
| 23 | + |
| 24 | +from executorch.exir.passes.external_constants_pass import ( |
| 25 | + delegate_external_constants_pass_unlifted, |
| 26 | +) |
| 27 | +from torch.export import export, ExportedProgram |
| 28 | + |
| 29 | +from torchao.quantization.granularity import PerAxis, PerGroup |
| 30 | +from torchao.quantization.quant_api import ( |
| 31 | + Int8DynamicActivationIntxWeightConfig, |
| 32 | + IntxWeightOnlyConfig, |
| 33 | + quantize_, |
| 34 | +) |
| 35 | +from torchao.utils import unwrap_tensor_subclass |
| 36 | + |
| 37 | + |
| 38 | +class TestPropagateCustomMetaPass(unittest.TestCase): |
| 39 | + class ModuleLinear(torch.nn.Module): |
| 40 | + def __init__( |
| 41 | + self, |
| 42 | + in_size: int = 2, |
| 43 | + input_channels: int = 4, |
| 44 | + output_channels: int = 4, |
| 45 | + dtype: torch.dtype = torch.float, |
| 46 | + use_bias: bool = False, |
| 47 | + ): |
| 48 | + super().__init__() |
| 49 | + self.linear = torch.nn.Linear( |
| 50 | + input_channels, output_channels, bias=use_bias |
| 51 | + ).to(dtype=dtype) |
| 52 | + |
| 53 | + self.ic = input_channels |
| 54 | + self.oc = output_channels |
| 55 | + assert dtype in [torch.float, torch.half], "Unsupported op dtype" |
| 56 | + self.op_dtype = dtype |
| 57 | + self.in_size = in_size |
| 58 | + |
| 59 | + def forward(self, x: torch.Tensor): |
| 60 | + return self.linear(x) |
| 61 | + |
| 62 | + def get_random_inputs(self): |
| 63 | + inp = torch.randn(self.in_size, self.ic).to(self.op_dtype) |
| 64 | + return (inp,) |
| 65 | + |
| 66 | + def test_propagate_custom_meta_pass(self): |
| 67 | + eager_model = self.ModuleLinear( |
| 68 | + in_size=1, |
| 69 | + input_channels=32, |
| 70 | + output_channels=2, |
| 71 | + ) |
| 72 | + test_inputs = eager_model.get_random_inputs() |
| 73 | + eager_result = eager_model(*test_inputs) |
| 74 | + |
| 75 | + # Quantize with torchao quantize_ API. |
| 76 | + linear_config = Int8DynamicActivationIntxWeightConfig( |
| 77 | + weight_dtype=torch.int4, |
| 78 | + weight_granularity=PerGroup(32), |
| 79 | + ) |
| 80 | + quantize_(eager_model, linear_config) |
| 81 | + quantized_result = eager_model(*test_inputs) |
| 82 | + unwrap_tensor_subclass(eager_model) |
| 83 | + |
| 84 | + # Tag the unlifted ep.module(). |
| 85 | + tagged_module = export( |
| 86 | + eager_model, test_inputs, dynamic_shapes=None, strict=True |
| 87 | + ).module() |
| 88 | + delegate_external_constants_pass_unlifted( |
| 89 | + module=tagged_module, |
| 90 | + gen_tag_fn=lambda x: "model", # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd" |
| 91 | + ) |
| 92 | + |
| 93 | + ep = export(tagged_module, test_inputs, dynamic_shapes=None, strict=True) |
| 94 | + DynamicallyQuantizedPartitioner = XnnpackPartitioner( |
| 95 | + config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, |
| 96 | + per_op_mode=True, |
| 97 | + ) |
| 98 | + edge = to_edge_transform_and_lower( |
| 99 | + ep, |
| 100 | + compile_config=EdgeCompileConfig(_check_ir_validity=False), |
| 101 | + partitioner=[XnnpackPartitioner()], |
| 102 | + generate_etrecord=False, |
| 103 | + ) |
| 104 | + exec = edge.to_executorch(ExecutorchBackendConfig()) |
| 105 | + |
| 106 | + program_buffer = exec.buffer |
| 107 | + data_buffer = bytes(exec._tensor_data.pop("model")) |
| 108 | + |
| 109 | + from executorch.extension.pybindings import portable_lib as runtime |
| 110 | + |
| 111 | + module = runtime._load_for_executorch_from_buffer(program_buffer, data_buffer) |
| 112 | + output = module.forward(test_inputs) |
0 commit comments