
Commit 9738edc

Merge branch 'main' into gh/gasoonjia/58/orig
2 parents 50df211 + 50f96a0 commit 9738edc

File tree: 12 files changed (+345, -53 lines)


backends/qualcomm/quantizer/qconfig.py

Lines changed: 3 additions & 1 deletion
@@ -52,7 +52,9 @@ def _derive_bias_qparams_fn(
         act_scale, weight_scale
     )
     derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32)
-    derived_zero = torch.zeros(derived_scale.size()).to(torch.int32)
+    derived_zero = torch.zeros(derived_scale.size(), device=weight_zp.device).to(
+        torch.int32
+    )
     if isinstance(weight_obs_or_fq, PerBlockParamObserver):
         # keep maximum scale of each channel for bias
        derived_scale = (
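A minimal sketch of why the device= argument matters (the tensors below are made-up stand-ins for the observer values in the hunk above): torch.zeros defaults to the CPU device, so deriving the zero points without device= can leave them on a different device than weight_zp and trip cross-device errors downstream.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
weight_zp = torch.zeros(4, dtype=torch.int32, device=device)  # stand-in weight zero points
derived_scale = torch.rand(4, device=weight_zp.device)        # stand-in derived scale

# Old form: allocates on the default (CPU) device, which may not match weight_zp.
derived_zero_old = torch.zeros(derived_scale.size()).to(torch.int32)

# Patched form: allocate the zero points on the same device as weight_zp.
derived_zero_new = torch.zeros(derived_scale.size(), device=weight_zp.device).to(torch.int32)
assert derived_zero_new.device == weight_zp.device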

backends/test/harness/stages/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 from .export import Export
 from .partition import Partition
-from .quantize import Quantize
+from .quantize import Quantize, Quantize_
 from .run_passes import RunPasses
 from .serialize import Serialize
 from .stage import Stage, StageType
@@ -12,6 +12,7 @@
     "Export",
     "Partition",
     "Quantize",
+    "Quantize_",
     "RunPasses",
     "Serialize",
     "Stage",

backends/test/harness/stages/quantize.py

Lines changed: 48 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, Callable, Optional, Sequence, Tuple
 
 import torch
 
@@ -15,6 +15,8 @@
     prepare_qat_pt2e,
 )
 from torchao.quantization.pt2e.quantizer import Quantizer
+from torchao.quantization.quant_api import quantize_
+from torchao.utils import unwrap_tensor_subclass
 
 
 class Quantize(Stage):
@@ -79,3 +81,48 @@ def graph_module(self) -> str:
 
     def run_artifact(self, inputs):
         return self.converted_graph.forward(*inputs)
+
+
+class Quantize_(Stage):
+    """
+    TorchAO quantization stage using the quantize_ API.
+    """
+
+    def __init__(
+        self,
+        config: Any,
+        filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
+    ):
+        """
+        Args:
+            config: TorchAO quantization config (e.g., Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig)
+            filter_fn: Optional filter function to select which modules to quantize
+        """
+        self.config = config
+        self.filter_fn = filter_fn
+        self.quantized_module = None
+
+    def stage_type(self) -> str:
+        return StageType.QUANTIZE
+
+    def run(
+        self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
+    ) -> None:
+        # Apply quantize_ to the model
+        quantize_(artifact, self.config, self.filter_fn)
+
+        # Unwrap tensor subclasses for export compatibility
+        unwrap_tensor_subclass(artifact)
+
+        self.quantized_module = artifact
+
+    @property
+    def artifact(self) -> torch.nn.Module:
+        return self.quantized_module
+
+    @property
+    def graph_module(self) -> torch.nn.Module:
+        return self.quantized_module
+
+    def run_artifact(self, inputs):
+        return self.quantized_module.forward(*inputs)
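A short usage sketch of the new stage (the module and inputs below are placeholders; the config mirrors the XNNPACK test added later in this commit):

import torch
from executorch.backends.test.harness.stages import Quantize_
from executorch.backends.xnnpack.test.tester import Tester
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

model = torch.nn.Linear(32, 2)   # placeholder eager model
inputs = (torch.randn(1, 32),)   # placeholder example inputs

config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
)
Tester(model, inputs).quantize(Quantize_(config=config))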

backends/xnnpack/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,9 @@
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+from executorch.backends.xnnpack._passes.propagate_custom_meta_pass import (
+    PropagateCustomMetaPass,
+)
 from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
     RemoveRedundantCopyPass,
 )
@@ -59,6 +62,7 @@ def __init__(
             DimOrderOpsRevertPass,
             ConvertToUpsampleBilinear2d,
             ConvertToLinearPass,
+            PropagateCustomMetaPass,
             ConvertToSDPAPass,
             ConstPropPass,
             FuseBatchNormPass,
backends/xnnpack/_passes/propagate_custom_meta_pass.py (new file)

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
+from executorch.exir.pass_base import PassResult
+
+
+class PropagateCustomMetaPass(XNNPACKPass):
+    """
+    Pass to propagate node.meta['custom'] from parent nodes to their q/dq child nodes.
+    For all quantize/dequantize nodes in the graph, if the parent node has a
+    node.meta['custom'] entry, this pass will copy that value to the q/dq node's meta.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+
+        for node in graph.nodes:
+            if not (is_quant(node) or is_dequant(node)):
+                continue
+
+            # Get the parent node (first input argument)
+            if len(node.all_input_nodes) == 0:
+                continue
+
+            parent_node = node.args[0]
+            if not isinstance(parent_node, torch.fx.Node):
+                continue
+
+            if "custom" in parent_node.meta:
+                node.meta["custom"] = parent_node.meta["custom"]
+
+        graph_module.recompile()
+
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
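A standalone toy sketch of the propagation rule on a plain torch.fx graph (hypothetical tag payload; the real pass above only copies the entry onto quantize/dequantize nodes identified via is_quant/is_dequant):

import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) * 2

gm = symbolic_trace(Toy())

# Pretend an earlier pass attached custom metadata to the relu node.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.relu:
        node.meta["custom"] = {"external_tag": "model"}  # hypothetical payload

# Propagation step: each node copies its first parent's "custom" entry,
# mirroring what PropagateCustomMetaPass does for q/dq children.
for node in gm.graph.nodes:
    if not node.all_input_nodes:
        continue
    parent = node.args[0]
    if isinstance(parent, torch.fx.Node) and "custom" in parent.meta:
        node.meta["custom"] = parent.meta["custom"]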

backends/xnnpack/runtime/XNNWeightsCache.cpp

Lines changed: 36 additions & 9 deletions
@@ -11,6 +11,9 @@
 #include <executorch/runtime/core/memory_allocator.h>
 #include <sys/stat.h>
 #include <xnnpack.h>
+#include <exception>
+#include <memory>
+#include <new>
 #include <string>
 #include <vector>
 
@@ -155,21 +158,45 @@ size_t XNNWeightsCache::look_up(
   return packed_weight_entry->second.offset;
 }
 
+/**
+ * Reserve space in the weight cache for n bytes of weight data, aligned to
+ * context->kPackedAllocationAlignment. This function will return nullptr if
+ * the allocation fails.
+ */
 void* XNNWeightsCache::reserve_space(XNNWeightsCache* context, size_t n) {
   // MemoryAllocator* allocator = context->runtime_allocator_;
   // void* reserved_pointer = allocator->allocate(n,
   // context->kPackedAllocationAlignment);
 
   // return reserved_pointer;
-  std::string data_container;
-  data_container.resize(n + context->kPackedAllocationAlignment);
-  void* maybe_aligned_space = data_container.data();
-  void* aligned_space = (void*)((intptr_t)maybe_aligned_space + 64 -
-      (intptr_t)maybe_aligned_space % 64);
-
-  context->packed_pointer_to_container_[aligned_space] =
-      std::move(data_container);
-  return aligned_space;
+  try {
+    std::string data_container;
+    size_t raw_allocation_size = n + context->kPackedAllocationAlignment - 1;
+    data_container.resize(raw_allocation_size);
+
+    void* maybe_aligned_space = data_container.data();
+    void* aligned_space = std::align(
+        context->kPackedAllocationAlignment,
+        n,
+        maybe_aligned_space,
+        raw_allocation_size // Note that std::align mutates this value.
+    );
+    ET_CHECK_MSG(aligned_space != nullptr, "Memory alignment failed.");
+
+    context->packed_pointer_to_container_[aligned_space] =
+        std::move(data_container);
+    return aligned_space;
+  } catch (std::bad_alloc& e) {
+    // XNNPACK can gracefully handle allocation failures, so return nullptr.
+    // We want to be able to recover from a failed attempt to load a large
+    // model without a crash.
+    ET_LOG(
+        Error,
+        "XNN weight cache failed to allocate %zu bytes: %s.",
+        n,
+        e.what());
+    return nullptr;
+  }
 }
 
 size_t XNNWeightsCache::look_up_or_insert(
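The over-allocation plus std::align pattern above reduces to a small piece of arithmetic, sketched here as a Python stand-in with illustrative numbers: reserving n + alignment - 1 bytes guarantees that rounding the start address up to the next multiple of alignment still leaves n usable bytes.

def align_up(addr: int, alignment: int) -> int:
    # Round addr up to the next multiple of alignment (what std::align does to the pointer).
    return (addr + alignment - 1) // alignment * alignment

n, alignment = 100, 64
raw_allocation_size = n + alignment - 1  # worst-case padding, mirrors the C++ change

# Even for the worst-aligned start address, an aligned block of n bytes still fits.
for base in range(256):
    aligned = align_up(base, alignment)
    assert aligned + n <= base + raw_allocation_size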

Lines changed: 161 additions & 0 deletions

@@ -0,0 +1,161 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import Tuple, Union
+
+import executorch.backends.test.harness.stages as BaseStages
+
+import torch
+from executorch.backends.xnnpack.partition.config.xnnpack_config import (
+    ConfigPrecisionType,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+)
+from executorch.backends.xnnpack.test.tester import Quantize as XNNPackQuantize, Tester
+from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
+from executorch.exir.passes.external_constants_pass import (
+    delegate_external_constants_pass_unlifted,
+)
+
+from torchao.quantization.granularity import PerGroup
+from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
+
+try:
+    import executorch.extension.pybindings.portable_lib  # noqa[F401]
+    import executorch.kernels.quantized  # noqa[F401]
+
+    has_quantized_ops = True
+except:
+    has_quantized_ops = False
+    print("Missing quantized ops")
+
+
+class TestPropagateCustomMetaPass(unittest.TestCase):
+    class ModuleLinear(torch.nn.Module):
+        def __init__(
+            self,
+            in_size: int = 2,
+            input_channels: int = 4,
+            output_channels: int = 4,
+            dtype: torch.dtype = torch.float,
+            use_bias: bool = False,
+        ):
+            super().__init__()
+            self.linear = torch.nn.Linear(
+                input_channels, output_channels, bias=use_bias
+            ).to(dtype=dtype)
+
+            self.ic = input_channels
+            self.oc = output_channels
+            assert dtype in [torch.float, torch.half], "Unsupported op dtype"
+            self.op_dtype = dtype
+            self.in_size = in_size
+
+        def forward(self, x: torch.Tensor):
+            return self.linear(x)
+
+        def get_random_inputs(self):
+            inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
+            return (inp,)
+
+    class Export(BaseStages.Export):
+        def run(
+            self,
+            artifact: torch.nn.Module,
+            inputs: Tuple[torch.Tensor],
+        ) -> None:
+
+            tagged_module = torch.export.export(
+                artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
+            ).module()
+            delegate_external_constants_pass_unlifted(
+                module=tagged_module,
+                gen_tag_fn=lambda x: "model",  # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
+            )
+            self.exported_program = torch.export.export(
+                tagged_module, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
+            )
+
+    def _test_linear(
+        self,
+        partitioner: XnnpackPartitioner,
+        quantization_stage: Union[BaseStages.Quantize, BaseStages.Quantize_],
+    ):
+        eager_model = self.ModuleLinear(
+            in_size=1,
+            input_channels=32,
+            output_channels=2,
+        )
+        test_inputs = eager_model.get_random_inputs()
+
+        tester = Tester(eager_model, test_inputs)
+        tester.quantize(quantization_stage)
+        tester.export(self.Export())
+        tester.to_edge_transform_and_lower(
+            ToEdgeTransformAndLower([partitioner])
+        ).to_executorch()
+        tester.run_method_and_compare_outputs()
+
+        exec = tester.get_artifact()
+        program_buffer = exec.buffer
+        self.assertEqual(len(exec._tensor_data), 1)
+        data_buffer = bytes(exec._tensor_data.pop("model"))
+        self.assertTrue(len(data_buffer) > 200)
+        from executorch.extension.pybindings import portable_lib as runtime
+
+        module = runtime._load_for_executorch_from_buffer(program_buffer, data_buffer)
+        output = module.forward(test_inputs)
+        reference_output = exec.exported_program().module()(
+            test_inputs[0],
+        )
+        self.assertTrue(torch.allclose(output[0], reference_output, 1e-2))
+
+        # with self.assertRaises(RuntimeError):
+        #     runtime._load_for_executorch_from_buffer(program_buffer).forward(
+        #         test_inputs
+        #     )
+
+    def test_quantize_(self):
+        # Quantize with torchao quantize_ API.
+        DynamicallyQuantizedPartitioner = XnnpackPartitioner(
+            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
+            per_op_mode=False,
+        )
+        linear_config = Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=torch.int4,
+            weight_granularity=PerGroup(32),
+        )
+        self._test_linear(
+            DynamicallyQuantizedPartitioner, BaseStages.Quantize_(config=linear_config)
+        )
+
+    def test_pt2e_quantize(self):
+        # Quantize with pt2e quantize.
+        quant_configs = [
+            # per_tensor
+            get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False),
+            # per_channel
+            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False),
+            # per_channel_dynamic
+            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True),
+        ]
+        for quant_config in quant_configs:
+            precision = (
+                ConfigPrecisionType.DYNAMIC_QUANT
+                if quant_config.input_activation.is_dynamic
+                else ConfigPrecisionType.STATIC_QUANT
+            )
+            for per_op_mode in [True, False]:
+                partitioner = XnnpackPartitioner(
+                    config_precisions=precision, per_op_mode=per_op_mode
+                )
+                self._test_linear(
+                    partitioner, XNNPackQuantize(quantization_config=quant_config)
+                )

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 import sys
 from typing import Any
 
-import pytorch_sphinx_theme2  # type: ignore[import-untyped]
+import pytorch_sphinx_theme2  # type: ignore[import-not-found]
 
 # To let us import ./custom_directives.py
 sys.path.insert(0, os.path.abspath("."))
