import torch

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass
from torch.export import export
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
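
# End-to-end sketch: quantize a small linear model with torchao (8-bit
# dynamic activations, 4-bit weights), export it, tag its weights for
# external-data emission, and lower it to ExecuTorch via XNNPACK, checking
# numerics after each step.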
# Optional: quantize embedding layers to 8-bit weights, per channel.
# (Left commented out; ModuleLinear below has no embedding layers.)
# embedding_config = IntxWeightOnlyConfig(
#     weight_dtype=torch.int8,
#     granularity=PerAxis(0),
# )
# quantize_(
#     eager_model,
#     embedding_config,
#     lambda m, fqn: isinstance(m, torch.nn.Embedding),
# )

torch.manual_seed(0)

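# A minimal linear module that serves as the quantization target.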
class ModuleLinear(torch.nn.Module):
    def __init__(
        self,
        in_size: int = 2,
        input_channels: int = 4,
        output_channels: int = 4,
        dtype: torch.dtype = torch.float,
        use_bias: bool = False,
    ):
        super().__init__()
        self.linear = torch.nn.Linear(
            input_channels, output_channels, bias=use_bias
        ).to(dtype=dtype)

        self.ic = input_channels
        self.oc = output_channels
        assert dtype in [torch.float, torch.half], "Unsupported op dtype"
        self.op_dtype = dtype
        self.in_size = in_size

    def forward(self, x: torch.Tensor):
        return self.linear(x)

    def get_random_inputs(self):
        inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
        return (inp,)
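
# Instantiate the model and record a float baseline for later comparisons.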
eager_model = ModuleLinear(
    in_size=1,
    input_channels=32,
    output_channels=2,
)

test_inputs = eager_model.get_random_inputs()
eager_result = eager_model(*test_inputs)
print("eager result: ", eager_result)
# Quantize linear layers with 8-bit dynamic activations and 4-bit weights.
linear_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
)
quantize_(eager_model, linear_config)

quantized_result = eager_model(*test_inputs)
print("quantized results: ", quantized_result)
print(torch.allclose(eager_result, quantized_result, atol=1e-1))
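
# Unwrap torchao tensor subclasses into plain tensors and quantization ops
# so torch.export can trace the model; numerics should be unchanged.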
unwrap_tensor_subclass(eager_model)
unwrapped_result = eager_model(*test_inputs)
print("unwrapped results: ", unwrapped_result)
print(torch.allclose(quantized_result, unwrapped_result, atol=1e-3))

from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_unlifted,
)

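# Export the quantized model and verify the exported graph's output.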
ep1 = export(eager_model, test_inputs, dynamic_shapes=None, strict=True)
exported_result = ep1.module()(*test_inputs)
print("exported program: ", exported_result)
print(torch.allclose(quantized_result, exported_result, atol=1e-3))
print("Graph: ")
ep1.graph_module.print_readable()
# Tag the unlifted ep1.module() so its weights become external constants.
tagged_module = ep1.module()
delegate_external_constants_pass_unlifted(
    module=tagged_module,
    # gen_tag_fn chooses the file the weights are saved to; tagging with
    # "model" saves them as "model.ptd".
    gen_tag_fn=lambda x: "model",
)
# Re-export the tagged module and confirm numerics are unchanged.
ep = export(tagged_module, test_inputs, dynamic_shapes=None, strict=True)
exported_result = ep.module()(*test_inputs)
print("exported program (after tagging): ", exported_result)
print(torch.allclose(quantized_result, exported_result, atol=1e-3))
# Check tagged nodes:
for node in list(ep.graph.nodes):
    if 'custom' in node.meta:
        print(f"Node: {node.name}, meta: {node.meta['custom']}")

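# Partition with an XNNPACK partitioner restricted to dynamically quantized
# ops; per_op_mode delegates each op in its own partition.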
DynamicallyQuantizedPartitioner = XnnpackPartitioner(
    config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
    per_op_mode=True,
)
edge = to_edge_transform_and_lower(
    ep,
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
    partitioner=[DynamicallyQuantizedPartitioner],
    generate_etrecord=False,
)
# After lowering, the delegated partitions absorb most of the graph; the
# remaining top-level nodes may be little more than a
# torchao_dequantize_affine_default call.
edge_result = edge.exported_program().module()(*test_inputs)
print("edge program: ", edge_result)
print(torch.allclose(quantized_result, edge_result, atol=1e-3))
edge.exported_program().graph_module.print_readable()
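
# Convert the edge program to an ExecuTorch program and check the final
# output against the quantized baseline.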
executorch_program = edge.to_executorch(ExecutorchBackendConfig())
exec_result = executorch_program.exported_program().module()(*test_inputs)
print("executorch program: ", exec_result)
print(torch.allclose(quantized_result, exec_result, atol=1e-3))
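
# A possible final step (a sketch, not part of the original script): write
# the serialized program to disk. ExecutorchProgramManager.buffer holds the
# .pte bytes; because the weights were tagged with gen_tag_fn -> "model",
# they are expected to land in a separate "model.ptd" file, though the exact
# API for emitting that external-data file may vary across ExecuTorch
# versions.
with open("model.pte", "wb") as f:
    f.write(executorch_program.buffer)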