Commit 5b8fdb9 (2 parents: 58f5ebd + 066f34b)
Author: ssjia

Update on "[ET-VK][AOT] Enable exporting Q8 Quantized Linear + Convolution"
As title. Introduce fusion patterns to enable fusing quantized convolution and linear graph patterns into a custom op.

## Changes

Introduce the concept of using custom pattern detection functions to detect graph patterns, rather than relying solely on SubgraphMatcher. The issue with SubgraphMatcher is that a large number of graph patterns may need to be exported to cover variants for different combinations of decompositions and quantization workflows; a custom detection function improves maintainability.

Implement detection + replacement functions for quantized linear and quantized conv2d.

Differential Revision: [D81323425](https://our.internmc.facebook.com/intern/diff/D81323425/)

[ghstack-poisoned]
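As a rough sketch of what a custom detection + replacement function can look like (every name below — the match container, the `is_dequant_node` predicate, and the `torch.ops.et_vk.linear_q8csw` handle — is a hypothetical stand-in, not the actual implementation):

```python
# Hypothetical sketch: detect a dequantize -> linear pattern by inspecting
# nodes directly, then splice in a single fused custom op. This avoids
# exporting one reference graph per decomposition/quantization variant,
# which is what SubgraphMatcher-based matching would require.
from dataclasses import dataclass
from typing import Optional

import torch


def is_dequant_node(node: torch.fx.Node) -> bool:
    # Hypothetical predicate; the real pass uses its own utils helpers.
    return node.op == "call_function" and "dequantize" in str(node.target)


@dataclass
class QuantizedLinearMatch:
    # Anchor nodes of a detected quantized-linear pattern.
    dq_input: torch.fx.Node
    dq_weight: torch.fx.Node
    linear: torch.fx.Node


def detect_quantized_linear(node: torch.fx.Node) -> Optional[QuantizedLinearMatch]:
    # Detection works off the node and its inputs, not a reference graph.
    if node.op != "call_function" or node.target != torch.ops.aten.linear.default:
        return None
    dq_input, dq_weight = node.args[0], node.args[1]
    for dq in (dq_input, dq_weight):
        if not (isinstance(dq, torch.fx.Node) and is_dequant_node(dq)):
            return None
    return QuantizedLinearMatch(dq_input, dq_weight, node)


def replace_quantized_linear(graph: torch.fx.Graph, m: QuantizedLinearMatch) -> None:
    # `torch.ops.et_vk.linear_q8csw` is a placeholder for the fused custom op.
    with graph.inserting_before(m.linear):
        fused = graph.call_function(
            torch.ops.et_vk.linear_q8csw,
            args=(m.dq_input.args[0], m.dq_weight.args[0]),
        )
    m.linear.replace_all_uses_with(fused)
```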

File tree: 6 files changed (+125 −22 lines)

backends/vulkan/_passes/fold_qdq.py

Lines changed: 6 additions & 13 deletions

```diff
@@ -23,26 +23,19 @@ def __init__(self, edge_program: torch.export.ExportedProgram):
 
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
-            # Criteria for a foldable Q/DQ node:
-            # - only one user (dequantize)
             if utils.is_quant_node(node):
-                if len(node.users) > 1:
-                    continue
-
-                dq_node = None
+                original_node = node.args[0]
+                assert isinstance(original_node, torch.fx.Node)
+                # For each direct user that is a dequant node, connect the original
+                # node to the users of the dequant node.
                 for user in node.users:
                     if utils.is_dequant_node(user):
                         dq_node = user
-
-                if dq_node is None:
-                    continue
-
-                original_node = node.args[0]
-                assert isinstance(original_node, torch.fx.Node)
-                dq_node.replace_all_uses_with(original_node)
+                        dq_node.replace_all_uses_with(original_node)
 
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)
         # Re-trace to validate everything is ok
         graph_module = super().call(graph_module).graph_module
+
        return PassResult(graph_module, True)
```
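Read as a graph rewrite, the new behavior condenses to a few lines (a standalone sketch; `utils` stands for the backend's helper module with the same predicates used in the diff):

```python
import torch

def fold_qdq_pairs(graph_module: torch.fx.GraphModule, utils) -> None:
    # Before:  original -> quantize -> dequantize -> consumer(s)
    # After:   original -> consumer(s); the orphaned Q/DQ nodes become dead
    # code and are removed by the follow-up DCE pass.
    for node in graph_module.graph.nodes:
        if utils.is_quant_node(node):
            original_node = node.args[0]
            for user in list(node.users):
                if utils.is_dequant_node(user):
                    user.replace_all_uses_with(original_node)
    graph_module.recompile()
```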

backends/vulkan/custom_ops_lib.py

Lines changed: 8 additions & 1 deletion

```diff
@@ -391,6 +391,7 @@ def conv2d_q8ta_q8csw(
     padding: list,
     dilation: list,
     groups: int,
+    out_channels: int,
 ):
     weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
 
@@ -409,6 +410,10 @@ def conv2d_q8ta_q8csw(
     # Reshape to original 4D format (OC, IC, H, W)
     qweights_4d = qweights_transposed.view(OC, IC, H, W)
 
+    # Remove any padding added to output channels dim
+    if out_channels != OC:
+        qweights_4d = qweights_4d[:out_channels, :, :, :]
+
     # Dequantize weights
     weights = torch.ops.quantized_decomposed.dequantize_per_channel(
         qweights_4d,
@@ -443,11 +448,13 @@ def conv2d_q8ta_q8csw(
         SymInt[] stride,
         SymInt[] padding,
         SymInt[] dilation,
-        SymInt groups) -> Tensor
+        SymInt groups,
+        SymInt out_channels) -> Tensor
     """
 )
 lib.impl(name, conv2d_q8ta_q8csw, "CompositeExplicitAutograd")
 conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name)
+
 ######################
 ## apply_rotary_emb ##
 ######################
```
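To illustrate what the new `out_channels` argument buys (the shapes here are invented for the example): if the serialized weight was padded from 30 to 32 output channels at export time, the reference implementation slices the padding back off before dequantizing, so the result matches the original convolution:

```python
import torch

# Hypothetical shapes: weights exported with OC padded 30 -> 32 so that
# OC is a multiple of 4 (texel alignment); IC=8, 3x3 kernel.
OC, IC, H, W = 32, 8, 3, 3
out_channels = 30  # the original, unpadded output channel count

qweights_4d = torch.randint(-128, 127, (OC, IC, H, W), dtype=torch.int8)

# Mirrors the new branch in conv2d_q8ta_q8csw: drop the zero-padding rows
# along the output-channel dimension before dequantization.
if out_channels != OC:
    qweights_4d = qweights_4d[:out_channels, :, :, :]

assert qweights_4d.shape == (30, 8, 3, 3)
```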

backends/vulkan/op_registry.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -490,6 +490,7 @@ def register_quantized_conv_op():
             utils.NO_STORAGE,  # padding (non tensor)
             utils.NO_STORAGE,  # dilation (non tensor)
             utils.NO_STORAGE,  # groups (non tensor)
+            utils.NO_STORAGE,  # original OC count (non tensor)
         ],
         supports_resize=False,
         supports_prepacking=True,
```

backends/vulkan/patterns/quantized_convolution.py

Lines changed: 27 additions & 7 deletions

```diff
@@ -151,19 +151,38 @@ def make_conv2d_q8ta_q8csw_custom_op(
 
     # Reshape weight tensor from (OC, IC, H, W) to (IC * H * W, OC) for matrix multiplication
     # This prepares the weights for Im2Col-based convolution computation
-    OC, IC, H, W = weight_tensor.shape
+    orig_OC, IC, H, W = weight_tensor.shape
+    OC = orig_OC
+    fake_weight = match.weight_node.meta["val"]
+
+    # The implementation requires that for grouped convolutions, a group does not cross
+    # any texel boundary.
+    if match.groups > 1:
+        assert (OC / match.groups) % 4 == 0
+
+    # The implementation requires that OC is a multiple of 4 so that data load/stores
+    # are well aligned with texel boundaries. If the original output channel count is
+    # not a multiple of 4, then add padding.
+    if OC % 4 != 0:
+        num_padding = 4 - (OC % 4)
+        # Pad the OC (output channel) dimension at the end with zeros
+        weight_tensor = torch.nn.functional.pad(
+            weight_tensor, (0, 0, 0, 0, 0, 0, 0, num_padding)
+        )
+        fake_weight = torch.nn.functional.pad(
+            fake_weight, (0, 0, 0, 0, 0, 0, 0, num_padding)
+        )
+        OC, IC, H, W = weight_tensor.shape
 
     weight_tensor_reshaped = (
         weight_tensor.permute(2, 3, 1, 0).contiguous().view(IC * H * W, OC)
     )
+    fake_weight_reshaped = (
+        fake_weight.permute(2, 3, 1, 0).contiguous().view(IC * H * W, OC)
+    )
     utils.update_program_state_dict(ep, match.weight_node.name, weight_tensor_reshaped)
     # Need to make sure the fake tensor matches the updated tensor's properties
-    match.weight_node.meta["val"] = (
-        match.weight_node.meta["val"]
-        .permute(1, 2, 3, 0)
-        .contiguous()
-        .view(IC * H * W, OC)
-    )
+    match.weight_node.meta["val"] = fake_weight_reshaped
 
     first_graph_node = list(graph_module.graph.nodes)[0]
     with graph_module.graph.inserting_before(first_graph_node):
@@ -200,6 +219,7 @@ def make_conv2d_q8ta_q8csw_custom_op(
             match.padding,
             match.dilation,
             match.groups,
+            orig_OC,
         ),
     )
 
```
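One subtlety worth spelling out: in the pad spec `(0, 0, 0, 0, 0, 0, 0, num_padding)`, `torch.nn.functional.pad` consumes (low, high) pairs starting from the *last* dimension, so only the final pair touches dim 0, i.e. OC. A standalone sketch with invented shapes:

```python
import torch
import torch.nn.functional as F

# Made-up weight shape with OC=30, not a multiple of 4.
weight = torch.randn(30, 8, 3, 3)  # (OC, IC, H, W)

OC = weight.shape[0]
if OC % 4 != 0:
    num_padding = 4 - (OC % 4)  # 2 in this example
    # F.pad pairs run from the last dim backwards: (W_lo, W_hi, H_lo, H_hi,
    # IC_lo, IC_hi, OC_lo, OC_hi) -- only the OC tail is padded, with zeros.
    weight = F.pad(weight, (0, 0, 0, 0, 0, 0, 0, num_padding))

assert weight.shape == (32, 8, 3, 3)

# The padded tensor is then laid out for im2col-style matmul, as in the diff:
IC, H, W = weight.shape[1:]
weight_reshaped = weight.permute(2, 3, 1, 0).contiguous().view(IC * H * W, 32)
```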

backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp

Lines changed: 4 additions & 0 deletions

```diff
@@ -478,6 +478,8 @@ void conv2d_q8csw_linear_tiled_impl(
   const ValueRef padding = args.at(idx++);
   const ValueRef dilation = args.at(idx++);
   const ValueRef groups = args.at(idx++);
+  const ValueRef orig_OC = args.at(idx++);
+  (void)orig_OC;
   const ValueRef output = args.at(idx++);
 
   const ValueRef packed_weight = prepack_q8_linear_weight(graph, weight);
@@ -552,6 +554,8 @@ void conv2d_q8ta_q8csw_linear_tiled_impl(
   const ValueRef padding = args.at(idx++);
   const ValueRef dilation = args.at(idx++);
   const ValueRef groups = args.at(idx++);
+  const ValueRef orig_OC = args.at(idx++);
+  (void)orig_OC;
   const ValueRef output = args.at(idx++);
 
   const ValueRef packed_weight = prepack_q8_linear_weight(graph, weight);
```

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 79 additions & 1 deletion

```diff
@@ -2544,7 +2544,85 @@ def forward(self, x):
             0.75,
         )
 
-        input_tensor = torch.ones((1, 3, 32, 32), dtype=torch.float32)
+        # Create sample inputs
+        sample_inputs = (input_tensor,)
+
+        # Create XNNPACK quantizer with symmetric quantization config
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(
+            is_per_channel=True,
+            is_dynamic=False,
+        )
+        quantizer.set_global(operator_config)
+
+        # Test the quantized module using the existing quantize_and_lower_module function
+        # Use higher tolerance since quantization introduces some error
+        edge_program = quantize_and_lower_module(
+            conv_sequence_module, sample_inputs, quantizer
+        )
+
+        et_program = edge_program.to_executorch()
+        self.check_vk_delegation(et_program)
+
+        self.run_delegated_model_and_check_output(
+            et_program,
+            conv_sequence_module,
+            sample_inputs,
+            atol=1e-2,
+            rtol=1e-1,
+        )
+
+    def test_vulkan_backend_xnnpack_pt2e_quantized_conv_sequence_all_reduced(self):
+        """
+        Test a sequence of convolution layers quantized with PT2E quantization.
+        This test creates a module with multiple Conv2d layers in sequence and applies
+        XNNPACK symmetric quantization to test the quantized model execution.
+        Similar to the linear sequence test but using convolution layers.
+        """
+
+        import executorch.backends.vulkan.test.utils as test_utils
+
+        class ConvSequenceModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=32,
+                    kernel_size=3,
+                    padding=1,
+                    bias=False,
+                )
+                self.conv2 = torch.nn.Conv2d(
+                    in_channels=32,
+                    out_channels=1,
+                    kernel_size=3,
+                    padding=1,
+                    bias=False,
+                )
+
+                MAX = 0.75
+                MIN = -0.25
+                self.conv1.weight.data = test_utils.random_uniform_tensor(
+                    self.conv1.weight.shape, MIN, MAX
+                )
+                self.conv2.weight.data = test_utils.random_uniform_tensor(
+                    self.conv2.weight.shape, MIN, MAX
+                )
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.conv2(x)
+                return x
+
+        # Create the module
+        conv_sequence_module = ConvSequenceModule()
+
+        input_tensor = test_utils.random_uniform_tensor(
+            (1, 3, 32, 32),
+            -0.25,
+            0.75,
+        )
+
         # Create sample inputs
         sample_inputs = (input_tensor,)
```
