Commit 71602a0

mcr229 authored and facebook-github-bot committed
Allow Partitioner to Force Dynamic Linear Computation (#5338)
Summary: Pull Request resolved: #5338

# Motivation

A current drawback of XNNPACK is that weights are duplicated across delegate instances when they do not solely belong to one partition. Ops like LSTM reuse the same few weights and biases across multiple linear nodes, so the LSTM weight/bias tensors get duplicated for every instance of linear, which can blow up the model size. XNNPACK supports dynamic linear, in which weights are supplied at runtime rather than packed ahead of time (AoT). This change lets us force the partitioner not to partition the weights, so the XNNPACK delegate does not own them and therefore does not duplicate them. This is currently only supported for FP32 weights, but it can be used to trade some performance for smaller file sizes.

Reviewed By: GregoryComer

Differential Revision: D62621998

fbshipit-source-id: 646f25af5f532718e88695173b9c17b6b03ff293
1 parent 034e098 commit 71602a0
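
For context, a minimal sketch of how a caller might enable this behavior when lowering a model. The `force_fp32_dynamic_linear` flag comes from the test added in this commit; the surrounding `torch.export` / `to_edge_transform_and_lower` flow is assumed to follow the usual ExecuTorch lowering path and is not part of this diff:

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 10)

    def forward(self, x):
        return self.linear(x)


ep = torch.export.export(TinyLinear().eval(), (torch.rand(1, 32),))

# With force_fp32_dynamic_linear=True, FP32 linear weights are left out of the
# partition, so the delegate does not serialize (and duplicate) them; XNNPACK
# receives them at runtime via its dynamic linear path instead.
edge = to_edge_transform_and_lower(
    ep,
    partitioner=[XnnpackPartitioner(force_fp32_dynamic_linear=True)],
)
et_program = edge.to_executorch()
```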

File tree

5 files changed: +92 -13 lines changed


backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 20 additions & 7 deletions
```diff
@@ -52,8 +52,8 @@ class GEMMConfig(XNNPartitionerConfig):
     different ops
     """
 
-    def __init__(self, weight_idx, bias_idx, act_idx, fused_acts):
-        super().__init__()
+    def __init__(self, weight_idx, bias_idx, act_idx, fused_acts, **kwargs):
+        super().__init__(**kwargs)
         self.weight_idx = weight_idx
         self.bias_idx = bias_idx
         self.act_idx = act_idx
@@ -250,17 +250,28 @@ def _get_act_deps(
 class LinearConfig(GEMMConfig):
     target_name = "linear.default"
 
-    def __init__(self):
+    def __init__(self, **kwargs):
         super().__init__(
             weight_idx=1,
             bias_idx=2,
             act_idx=0,
             fused_acts=["relu.default", "hardtanh.default"],
+            **kwargs,
         )
 
     def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
         return torch.ops.aten.linear.default
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.DYNAMIC_QUANT,
@@ -272,12 +283,13 @@ def supported_precision_types(self):
 class ConvolutionConfig(GEMMConfig):
     target_name = "convolution.default"
 
-    def __init__(self):
+    def __init__(self, **kwargs):
         super().__init__(
             weight_idx=1,
             bias_idx=2,
             act_idx=0,
             fused_acts=["relu.default", "hardtanh.default"],
+            **kwargs,
         )
 
     def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
@@ -314,12 +326,13 @@ class AddmmConfig(GEMMConfig):
 
     target_name = "addmm.default"
 
-    def __init__(self):
+    def __init__(self, **kwargs):
         super().__init__(
             weight_idx=2,
             bias_idx=0,
             act_idx=1,
             fused_acts=["relu.default", "hardtanh.default"],
+            **kwargs,
         )
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]
@@ -417,8 +430,8 @@ def supported_precision_types(self):
 class MMConfig(AddmmConfig):
     target_name = "mm.default"
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
         self.bias_idx = None
         self.weight_idx = 1
         self.act_idx = 0
```

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -25,13 +25,13 @@
 
 
 class GenericNodePartitionerConfig(XNNPartitionerConfig):
-    def __init__(self, fused_act: Optional[List[str]] = None):
+    def __init__(self, fused_act: Optional[List[str]] = None, **kwargs):
         """
         fused_act is a list of node target names that can be fused with this
         node under quantization
         """
         self.fused_acts = fused_act or []
-        super().__init__()
+        super().__init__(**kwargs)
 
     def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         return self.check_common_constraints(node, ep)
@@ -98,8 +98,8 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
 class AddConfig(GenericNodePartitionerConfig):
     target_name = "add.Tensor"
 
-    def __init__(self):
-        super().__init__(fused_act=["relu.default"])
+    def __init__(self, **kwargs):
+        super().__init__(fused_act=["relu.default"], **kwargs)
 
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]
```

backends/xnnpack/partition/config/xnnpack_config.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -37,9 +37,11 @@ class XNNPartitionerConfig(PartitionerConfig):
     types they want to enable
     """
 
-    def __init__(self):
+    def __init__(self, **kwargs):
         super().__init__()
         self.enabled_precision_types = self.supported_precision_types()
+        # Flag used in GEMMConfig()
+        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", False)
 
     def get_partition(
         self, node: torch.fx.Node, ep: ExportedProgram
```

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -36,6 +36,7 @@ def __init__(
         ] = None,
         per_op_mode=False,
         verbose: bool = False,
+        **kwargs,
     ):
         """
         @verbose: if True, print out more information about the partitioner.
@@ -55,7 +56,7 @@ def __init__(
 
         for config in configs_to_use:
             # Config Classes given to XnnpackPartitioner should no longer be abstract
-            initialized = config() # pyre-ignore
+            initialized = config(**kwargs) # pyre-ignore
             initialized.set_enabled_precision_types(config_precisions)
             initialized_configs.append(initialized)
```
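
Because the partitioner now forwards `**kwargs` to each config, any flag understood by `XNNPartitionerConfig.__init__` can be set for all configs at once through the partitioner constructor. A small illustrative check, assuming the module paths shown in this diff:

```python
from executorch.backends.xnnpack.partition.config.gemm_configs import LinearConfig

# The flag is stored by the XNNPartitionerConfig base class, so every config
# built via config(**kwargs) inside XnnpackPartitioner carries it.
cfg = LinearConfig(force_fp32_dynamic_linear=True)
assert cfg.force_fp32_dynamic_linear
```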

backends/xnnpack/test/ops/lstm.py

Lines changed: 63 additions & 0 deletions
```diff
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+
+from executorch.backends.xnnpack.test.tester import Tester
+from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
+
+
+class TestLSTM(unittest.TestCase):
+    class LSTMLinear(torch.nn.Module):
+        def __init__(self, input_size, hidden_size, out_size):
+            super().__init__()
+            self.lstm = torch.nn.LSTM(
+                input_size=input_size, hidden_size=hidden_size, batch_first=True
+            )
+            self.linear = torch.nn.Linear(hidden_size, hidden_size)
+            self.linear2 = torch.nn.Linear(hidden_size, out_size)
+
+        def forward(self, x):
+            x, hs = self.lstm(x)
+            x = self.linear(x[:, -1, :])
+            x = self.linear2(x)
+            return torch.nn.functional.log_softmax(x, dim=1)
+
+    def test_fp32_lstm(self):
+        (
+            Tester(self.LSTMLinear(32, 32, 10), (torch.rand(1, 32, 32),))
+            .export()
+            .to_edge_transform_and_lower()
+            .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
+            .check_not(
+                ["p_lstm_weight", "p_lstm_bias"]
+            ) # These Should be Consumed by Delegate
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_lstm_force_dynamic_linear(self):
+        (
+            Tester(self.LSTMLinear(32, 32, 10), (torch.rand(1, 32, 32),))
+            .export()
+            .to_edge_transform_and_lower(
+                ToEdgeTransformAndLower(
+                    partitioners=[XnnpackPartitioner(force_fp32_dynamic_linear=True)]
+                )
+            )
+            .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
+            # Weights are supplied as input to linears
+            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0"])
+            # Biases are owned by delegates
+            .check_not(["p_lstm_bias"])
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
```
