
Commit 5f941b1

GregoryComer authored and facebook-github-bot committed
Add pass to remove unused parameters in to_executorch (pytorch#10484)
Summary:

Currently, ExecuTorch will serialize any parameters in the exported program, regardless of whether they are actually used. Exporting with strict=True removes unused parameters, but strict=False does not, and export recently switched to non-strict as the default behavior. This causes PTE bloat when doing pt2e quantization (unquantized weights are left in the graph) or sometimes when exporting multiple methods (an encoder and a decoder, for example).

This PR adds a new pass (`remove_unused_parameters_pass`) to strip unused parameters from the `ExportedProgram`. It is run as part of `to_executorch`. A parameter is considered unused if its placeholder node has no users. Unused parameters are removed from the state dict, the input specs, and the graph.

As a question for reviewers: should we run this pass earlier, as part of `to_edge`? My rationale for running it as part of `to_executorch` was that it could theoretically clean up anything else left behind by partitioning and lowering, but I'm not aware of any concrete use cases for this.

cc JacobSzwejbka angelayi

Reviewed By: digantdesai, JacobSzwejbka

Differential Revision: D73654202

Pulled By: GregoryComer
1 parent 3064308 commit 5f941b1
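
To make the summary concrete, here is a minimal sketch of the scenario this change targets. The toy module and the size check are illustrative only (not taken from this PR); the size figures echo the comments in the test file below, and the pass is described in the summary as running automatically inside `to_executorch`.

```python
import torch
from executorch.exir import to_edge


class ToyModel(torch.nn.Module):  # illustrative module, not from this PR
    def __init__(self):
        super().__init__()
        self.used = torch.nn.Linear(16, 16)
        # Roughly 1M fp32 parameters that forward() never touches (~4 MB if serialized).
        self.unused = torch.nn.Linear(1024, 1024)

    def forward(self, x):
        return self.used(x)


# Non-strict export (now the default) keeps unused.weight / unused.bias
# in the ExportedProgram's state dict and graph signature.
ep = torch.export.export(ToyModel(), (torch.randn(1, 16),), strict=False)

# With this change, the lowering flow strips the unused parameters, so the
# serialized PTE should only carry the few KB of parameters actually used.
pte = to_edge(ep).to_executorch()
print(f"PTE size: {len(pte.buffer)} bytes")
```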

File tree: 6 files changed, +303 −0 lines changed


exir/passes/TARGETS

Lines changed: 12 additions & 0 deletions

@@ -21,6 +21,7 @@ python_library(
         ":quant_fusion_pass",
         ":quantize_io_pass",
         ":remove_noop_pass",
+        ":remove_unused_parameters_pass",
         ":replace_aten_with_edge_pass",
         ":replace_broken_ops_with_function_ops_pass",
         ":replace_edge_with_backend_pass",
@@ -390,3 +391,14 @@ python_library(
         "//executorch/exir/dialects:lib",
     ],
 )
+
+python_library(
+    name = "remove_unused_parameters_pass",
+    srcs = [
+        "remove_unused_parameters_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir/dialects:lib",
+    ],
+)

exir/passes/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -45,6 +45,7 @@
 from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass
+from executorch.exir.passes.remove_unused_parameters_pass import remove_unused_parameters_pass
 from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
 from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import (
     ReplaceBrokenOpsWithFunctionalOpsPass,
@@ -71,6 +72,7 @@
     "MemoryPlanningPass",
     "HintBasedSymShapeEvalPass",
     "insert_write_back_for_buffers_pass",
+    "remove_unused_parameters_pass",
     "weights_to_outputs_pass",
 ]

exir/passes/remove_unused_parameters_pass.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import torch
+
+from torch.export.exported_program import (
+    ExportedProgram,
+    InputKind,
+)
+
+def remove_unused_parameters_pass(
+    ep: ExportedProgram,
+) -> ExportedProgram:
+    """
+    Remove unused parameters from the exported program.
+    """
+
+    placeholder_nodes = {
+        node.target: node for node in ep.graph_module.graph.nodes if node.op == "placeholder"
+    }
+
+    unused_parameters = [
+        s
+        for s in ep.graph_signature.input_specs
+        if s.kind == InputKind.PARAMETER and not _is_parameter_used(ep, s.arg.name, placeholder_nodes)
+    ]
+
+    # Remove params from the state dict, graph, and signature.
+    new_signature = copy.deepcopy(ep.graph_signature)
+    for param in unused_parameters:
+        new_signature.input_specs.remove(param)
+        del ep._state_dict[param.target]
+        ep.graph_module.graph.erase_node(placeholder_nodes[param.arg.name])
+
+    ep._graph_signature = new_signature
+    ep.graph_module.recompile()
+    return ep
+
+def _is_parameter_used(ep: ExportedProgram, parameter: str, placeholder_nodes: dict[str, torch.fx.Node]) -> bool:
+    placeholder_node = placeholder_nodes.get(parameter)
+    if placeholder_node is None:
+        raise RuntimeError(f"Invalid graph. No placeholder for {parameter} found in graph.")
+
+    return len(placeholder_node.users) > 0
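
As a quick usage sketch (not part of the diff), the pass can also be applied directly to an `ExportedProgram`, with the effect visible in `graph_signature.parameters`. The `Toy` module here is hypothetical.

```python
import torch
from executorch.exir.passes import remove_unused_parameters_pass


class Toy(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.dead = torch.nn.Linear(8, 8)  # never called in forward()

    def forward(self, x):
        return self.linear(x)


ep = torch.export.export(Toy(), (torch.randn(2, 8),), strict=False)
print("before:", sorted(ep.graph_signature.parameters))
# Expected to include dead.weight / dead.bias under non-strict export.

ep = remove_unused_parameters_pass(ep)
print("after:", sorted(ep.graph_signature.parameters))
# Only linear.weight / linear.bias should remain; outputs are unchanged.
```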

exir/program/_program.py

Lines changed: 4 additions & 0 deletions

@@ -39,6 +39,7 @@
     base_post_op_replace_passes,
     base_pre_op_replace_passes,
     dead_code_elimination_pass,
+    remove_unused_parameters_pass,
     EdgeToBackendOpsPass,
     MemoryFormatOpsPass,
     OpReplacePass,
@@ -800,6 +801,9 @@ def _generate_edge_program(
     gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm)
     assert gm_res is not None
     gm = gm_res.graph_module
+
+    # Remove unused parameters
+    program = remove_unused_parameters_pass(program)
 
     if config._check_ir_validity:
         try:

exir/tests/TARGETS

Lines changed: 16 additions & 0 deletions

@@ -432,6 +432,22 @@ python_unittest(
     ],
 )
 
+python_unittest(
+    name = "test_remove_unused_parameters_pass",
+    srcs = [
+        "test_remove_unused_parameters_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/xnnpack:xnnpack_delegate",
+        "//executorch/exir:lib",
+        "//executorch/exir:memory",
+        "//executorch/exir/capture:config",
+        "//executorch/exir/passes:lib",
+        "//executorch/runtime:runtime",
+    ],
+)
+
 python_unittest(
     name = "test_remove_view_copy",
     srcs = [

exir/tests/test_remove_unused_parameters_pass.py

Lines changed: 221 additions & 0 deletions

@@ -0,0 +1,221 @@
+import unittest
+from typing import Sequence
+
+import torch
+
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.exir import to_edge, to_edge_transform_and_lower
+from executorch.exir.passes import remove_unused_parameters_pass
+from executorch.runtime import Runtime
+from torch.export import ExportedProgram
+
+class TestRemoveUnusedParametersPass(unittest.TestCase):
+    class SimpleModelWithUnusedParameters(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear1 = torch.nn.Linear(16, 16)
+            self.unused_linear = torch.nn.Linear(1024, 1024)
+
+        def forward(self, x):
+            return self.linear1(x)
+
+    class NestedModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.mod1 = TestRemoveUnusedParametersPass.SimpleModelWithUnusedParameters()
+            self.mod2 = TestRemoveUnusedParametersPass.SimpleModelWithUnusedParameters()
+
+        def forward(self, x):
+            y = self.mod1(x) + self.mod2(x)
+            y += self.mod1.unused_linear(x.repeat([1, 64]))[:, :16]
+            return y
+
+
+    def test_remove_unused_parameters_simple(self):
+        model = self.SimpleModelWithUnusedParameters()
+        model.eval()
+        example_inputs = (torch.randn(1, 16),)
+        eager_outputs = model(*example_inputs)
+        ep = torch.export.export(model, example_inputs, strict=False)
+
+        unused_param_names_and_args = {
+            "unused_linear.weight": "p_unused_linear_weight",
+            "unused_linear.bias": "p_unused_linear_bias",
+        }
+
+        self._test_pass(
+            ep,
+            unused_param_names_and_args,
+            example_inputs,
+            eager_outputs,
+        )
+
+    def test_remove_unused_parameters_nested(self):
+        model = self.NestedModel()
+        model.eval()
+        example_inputs = (torch.randn(1, 16),)
+        eager_outputs = model(*example_inputs)
+        ep = torch.export.export(model, example_inputs, strict=False)
+
+        unused_param_names_and_args = {
+            "mod2.unused_linear.weight": "p_mod2_unused_linear_weight",
+            "mod2.unused_linear.bias": "p_mod2_unused_linear_bias",
+        }
+
+        self._test_pass(
+            ep,
+            unused_param_names_and_args,
+            example_inputs,
+            eager_outputs,
+        )
+
+    def test_remove_unused_parameters_simple_e2e_to_edge(self):
+        model = self.SimpleModelWithUnusedParameters().eval()
+        example_inputs = (torch.randn(1, 16),)
+
+        # There are approximately 1M unused fp32 parameters - ~4Mb.
+        # Without the unused params, the expected size is ~2.5Kb.
+        size_bound = 10000
+
+        for strict in [False, True]:
+            for delegate in [False, True]:
+                self._test_pass_e2e(
+                    model,
+                    example_inputs,
+                    strict=strict,
+                    use_to_edge=True,
+                    delegate=delegate,
+                    size_bound=size_bound,
+                )
+
+    def test_remove_unused_parameters_simple_e2e_to_edge_transform_and_lower(self):
+        model = self.SimpleModelWithUnusedParameters().eval()
+        example_inputs = (torch.randn(1, 16),)
+
+        # There are approximately 1M unused fp32 parameters - ~4Mb.
+        # Without the unused params, the expected size is ~2.5Kb.
+        size_bound = 10000
+
+        for strict in [False, True]:
+            for delegate in [False, True]:
+                self._test_pass_e2e(
+                    model,
+                    example_inputs,
+                    strict=strict,
+                    use_to_edge=False,
+                    delegate=delegate,
+                    size_bound=size_bound,
+                )
+
+    def test_remove_unused_parameters_nested_e2e_to_edge(self):
+        model = self.NestedModel().eval()
+        example_inputs = (torch.randn(1, 16),)
+
+        size_bound = 20000 + 1024 * 1024 * 4
+
+        for strict in [False, True]:
+            for delegate in [False, True]:
+                self._test_pass_e2e(
+                    model,
+                    example_inputs,
+                    strict=strict,
+                    use_to_edge=True,
+                    delegate=delegate,
+                    size_bound=size_bound,
+                )
+
+    def test_remove_unused_parameters_nested_e2e_to_edge_transform_and_lower(self):
+        model = self.SimpleModelWithUnusedParameters().eval()
+        example_inputs = (torch.randn(1, 16),)
+
+        size_bound = 20000 + 1024 * 1024 * 4
+
+        for strict in [False, True]:
+            for delegate in [False, True]:
+                self._test_pass_e2e(
+                    model,
+                    example_inputs,
+                    strict=strict,
+                    use_to_edge=False,
+                    delegate=delegate,
+                    size_bound=size_bound,
+                )
+
+    def _test_pass(
+        self,
+        ep: ExportedProgram,
+        unused_param_names_and_args: dict[str, str],
+        example_inputs: Sequence[torch.Tensor],
+        expected_outputs: torch.Tensor,
+    ):
+        # Verify EP state before running the pass.
+        placeholders = set(
+            n.target for n in ep.graph_module.graph.nodes if n.op == "placeholder"
+        )
+        for param_name, param_arg in unused_param_names_and_args.items():
+            self.assertIn(param_name, ep.state_dict.keys())
+            self.assertIn(param_name, ep.graph_signature.parameters)
+            self.assertIn(param_arg, placeholders)
+
+        new_ep = remove_unused_parameters_pass(ep)
+
+        # Verify that the unused params are not in the state dict,
+        # graph signature, or graph.
+        new_placeholders = set(
+            n.target for n in new_ep.graph_module.graph.nodes if n.op == "placeholder"
+        )
+        for param_name, param_arg in unused_param_names_and_args.items():
+            self.assertNotIn(param_name, new_ep.state_dict.keys())
+            self.assertNotIn(param_name, new_ep.graph_signature.parameters)
+            self.assertNotIn(param_arg, new_placeholders)
+
+        # Verify that the outputs are unchanged.
+        new_outputs = new_ep.module()(*example_inputs)
+        self.assertTrue(torch.allclose(new_outputs, expected_outputs))
+
+    def _test_pass_e2e(
+        self,
+        model: torch.nn.Module,
+        example_inputs: Sequence[torch.Tensor],
+        strict: bool,
+        use_to_edge: bool,
+        delegate: bool,
+        size_bound: int,
+    ):
+        eager_outputs = model(*example_inputs)
+        ep = torch.export.export(model, example_inputs, strict=strict)
+
+        if use_to_edge:
+            lowered = to_edge(ep)
+            if delegate:
+                lowered = lowered.to_backend(XnnpackPartitioner())
+        else:  # use to_edge_transform_and_lower
+            lowered = to_edge_transform_and_lower(
+                ep,
+                partitioner=[XnnpackPartitioner()] if delegate else [],
+            )
+
+        lowered = lowered.to_executorch()
+        self.assertLess(len(lowered.buffer), size_bound)
+
+        # Make sure we can load and run the serialized .pte.
+        runtime = Runtime.get()
+        program = runtime.load_program(lowered.buffer)
+        method = program.load_method("forward")
+        runtime_outputs = method.execute([*example_inputs])
+
+        self.assertEqual(1, len(runtime_outputs))
+        self.assertTrue(
+            torch.allclose(runtime_outputs[0], eager_outputs, atol=2e-6),
+            f"Values out of tolerance.\n" +
+            f" Strict: {strict}, ToEdge: {use_to_edge}, Delegate: {delegate}.\n" +
+            f" Eager: {eager_outputs}.\n" +
+            f" Pybind: {runtime_outputs[0]}.\n" +
+            f" Error: {eager_outputs - runtime_outputs[0]}",
+        )
