
Commit 7336a78

test xnnpack quant with program-data separation
1 parent d95143e

1 file changed: +137 -0 lines changed

export_quant_xnnpack.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
import torch

from torchao.quantization.granularity import PerGroup, PerAxis
from torchao.quantization.quant_api import (
    IntxWeightOnlyConfig,
    Int8DynamicActivationIntxWeightConfig,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass
from torch.export import export, ExportedProgram
from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    ExecutorchBackendConfig,
    ExecutorchProgramManager,
    to_edge_transform_and_lower,
)
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackFloatingPointPartitioner,
    XnnpackPartitioner,
)

# Quantize embeddings with 8 bits, per channel (disabled; to enable, run it
# after eager_model is created below):
# embedding_config = IntxWeightOnlyConfig(
#     weight_dtype=torch.int8,
#     granularity=PerAxis(0),
# )
# quantize_(
#     eager_model,
#     embedding_config,
#     lambda m, fqn: isinstance(m, torch.nn.Embedding),
# )

torch.manual_seed(0)


class ModuleLinear(torch.nn.Module):
    def __init__(
        self,
        in_size: int = 2,
        input_channels: int = 4,
        output_channels: int = 4,
        dtype: torch.dtype = torch.float,
        use_bias: bool = False,
    ):
        super().__init__()
        self.linear = torch.nn.Linear(
            input_channels, output_channels, bias=use_bias
        ).to(dtype=dtype)

        self.ic = input_channels
        self.oc = output_channels
        assert dtype in [torch.float, torch.half], "Unsupported op dtype"
        self.op_dtype = dtype
        self.in_size = in_size

    def forward(self, x: torch.Tensor):
        return self.linear(x)

    def get_random_inputs(self):
        inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
        return (inp,)

eager_model = ModuleLinear(
    in_size=1,
    input_channels=32,
    output_channels=2,
)

test_inputs = eager_model.get_random_inputs()
eager_result = eager_model(*test_inputs)
print("eager result: ", eager_result)

# Quantize linear layers with 8-bit dynamic activations and 4-bit weights.
linear_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
)
quantize_(eager_model, linear_config)

quantized_result = eager_model(*test_inputs)
print("quantized results: ", quantized_result)
print(torch.allclose(eager_result, quantized_result, atol=1e-1))

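# unwrap_tensor_subclass swaps torchao's tensor-subclass weights for plain
# tensors (via module parametrization) so torch.export can trace the model;
# the allclose check below verifies numerics are unchanged.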
unwrap_tensor_subclass(eager_model)
unwrapped_result = eager_model(*test_inputs)
print("unwrapped results: ", unwrapped_result)
print(torch.allclose(quantized_result, unwrapped_result, atol=1e-3))

from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_unlifted,
)

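# Export once to get an unlifted module (ep1.module()) that the external-
# constants tagging pass can mutate; it is re-exported below after tagging.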
ep1 = export(eager_model, test_inputs, dynamic_shapes=None, strict=True)
exported_result = ep1.module()(*test_inputs)
print("exported program: ", exported_result)
print(torch.allclose(quantized_result, exported_result, atol=1e-3))
print("Graph: ")
ep1.graph_module.print_readable()

# Tag the unlifted ep.module().
tagged_module = ep1.module()
delegate_external_constants_pass_unlifted(
    module=tagged_module,
    # gen_tag_fn returns the filename the weights will be saved to; in this
    # case, weights will be saved as "model.ptd".
    gen_tag_fn=lambda x: "model",
)
ep = export(tagged_module, test_inputs, dynamic_shapes=None, strict=True)
exported_result = ep.module()(*test_inputs)
print("exported program (after tagging): ", exported_result)
print(torch.allclose(quantized_result, exported_result, atol=1e-3))

# Check tagged nodes:
for node in list(ep.graph.nodes):
    if 'custom' in node.meta:
        print(f"Node: {node.name}, meta: {node.meta['custom']}")

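# Note: DynamicallyQuantizedPartitioner (per-op dynamic-quant precision) is
# defined here but not used; the default XnnpackPartitioner() is what gets
# passed to to_edge_transform_and_lower below.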
DynamicallyQuantizedPartitioner = XnnpackPartitioner(
    config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
    per_op_mode=True,
)
edge = to_edge_transform_and_lower(
    ep,
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
    partitioner=[XnnpackPartitioner()],
    generate_etrecord=False,
)
# ^ after this, the graph has a single node? torchao_dequantize_affine_default
edge_result = edge.exported_program().module()(*test_inputs)
print("edge program: ", edge_result)
print(torch.allclose(quantized_result, edge_result, atol=1e-3))
edge.exported_program().graph_module.print_readable()

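# to_executorch serializes the lowered program; constants tagged above are
# kept out of the program blob so they can be emitted as external data
# (the program-data separation this commit is testing).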
exec_prog = edge.to_executorch(ExecutorchBackendConfig())
exec_result = exec_prog.exported_program().module()(*test_inputs)
print("executorch program: ", exec_result)
print(torch.allclose(quantized_result, exec_result, atol=1e-3))
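
# Untested sketch (not part of this commit): writing the separated artifacts
# to disk. Assumes ExecutorchProgramManager exposes write_to_file and
# write_tensor_data_to_file in this executorch version.
# with open("model.pte", "wb") as f:
#     exec_prog.write_to_file(f)
# exec_prog.write_tensor_data_to_file(".")  # emits "model.ptd" beside the .pte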

0 commit comments
