Commit 1067754
Add pass to remove unused parameters in to_executorch (pytorch#10484)
Summary:
Currently, ExecuTorch will serialize any parameters in the exported program, regardless of whether they are actually used. Exporting with strict=True will remove unused parameters, but strict=False will not, and export recently switched to non-strict as the default behavior. This causes PTE bloat when doing pt2e quantization (unquantized weights are left in the graph) or sometimes when exporting multiple methods (an encoder and a decoder, for example).

This PR adds a new pass (`remove_unused_parameters_pass`) to strip unused parameters from the `ExportedProgram`. It is run as part of `to_executorch`. Parameters are considered unused if their placeholder node has no users. Parameters are removed by stripping them from the state_dict, the input specs, and the graph.

As a question for reviewers: should we run this pass earlier, as part of to_edge? My rationale for running it as part of to_executorch was that it could theoretically clean up anything else left behind by partitioning and lowering, but I'm not aware of any concrete use cases for this.

cc JacobSzwejbka angelayi

Pull Request resolved: pytorch#10484
Reviewed By: digantdesai, JacobSzwejbka
Differential Revision: D73654202
Pulled By: GregoryComer
1 parent 3064308 commit 1067754
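
For context, a minimal sketch of the problem this change addresses; the `ModelWithUnusedParam` module and its shapes below are illustrative, not taken from the PR. With non-strict export, parameters that the graph never reads still end up in the exported state dict and would previously be serialized into the .pte:

    import torch

    class ModelWithUnusedParam(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 16)
            # Never called in forward(), so its weight and bias are dead parameters.
            self.unused = torch.nn.Linear(1024, 1024)

        def forward(self, x):
            return self.linear(x)

    ep = torch.export.export(
        ModelWithUnusedParam().eval(), (torch.randn(1, 16),), strict=False
    )
    # Non-strict export keeps the dead parameters in the program.
    print("unused.weight" in ep.state_dict)  # True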

6 files changed: +296 −0 lines changed

exir/passes/TARGETS

Lines changed: 12 additions & 0 deletions
@@ -21,6 +21,7 @@ python_library(
         ":quant_fusion_pass",
         ":quantize_io_pass",
         ":remove_noop_pass",
+        ":remove_unused_parameters_pass",
         ":replace_aten_with_edge_pass",
         ":replace_broken_ops_with_function_ops_pass",
         ":replace_edge_with_backend_pass",
@@ -390,3 +391,14 @@ python_library(
         "//executorch/exir/dialects:lib",
     ],
 )
+
+python_library(
+    name = "remove_unused_parameters_pass",
+    srcs = [
+        "remove_unused_parameters_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir/dialects:lib",
+    ],
+)

exir/passes/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -45,6 +45,9 @@
 from executorch.exir.passes.prune_empty_tensors_pass import PruneEmptyTensorsPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass
+from executorch.exir.passes.remove_unused_parameters_pass import (
+    remove_unused_parameters_pass,
+)
 from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
 from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import (
     ReplaceBrokenOpsWithFunctionalOpsPass,
@@ -71,6 +74,7 @@
     "MemoryPlanningPass",
     "HintBasedSymShapeEvalPass",
     "insert_write_back_for_buffers_pass",
+    "remove_unused_parameters_pass",
     "weights_to_outputs_pass",
 ]

exir/passes/remove_unused_parameters_pass.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch

from torch.export.exported_program import ExportedProgram, InputKind


def remove_unused_parameters_pass(
    ep: ExportedProgram,
) -> ExportedProgram:
    """
    Remove unused parameters from the exported program.
    """

    placeholder_nodes = {
        node.target: node
        for node in ep.graph_module.graph.nodes
        if node.op == "placeholder"
    }

    unused_parameters = [
        s
        for s in ep.graph_signature.input_specs
        if s.kind == InputKind.PARAMETER
        and not _is_parameter_used(ep, s.arg.name, placeholder_nodes)
    ]

    # Remove params from the state dict, graph, and signature.
    new_signature = copy.deepcopy(ep.graph_signature)
    for param in unused_parameters:
        new_signature.input_specs.remove(param)
        del ep._state_dict[param.target]
        ep.graph_module.graph.erase_node(placeholder_nodes[param.arg.name])

    ep._graph_signature = new_signature
    ep.graph_module.recompile()
    return ep


def _is_parameter_used(
    ep: ExportedProgram, parameter: str, placeholder_nodes: dict[str, torch.fx.Node]
) -> bool:
    placeholder_node = placeholder_nodes.get(parameter)
    if placeholder_node is None:
        raise RuntimeError(
            f"Invalid graph. No placeholder for {parameter} found in graph."
        )

    return len(placeholder_node.users) > 0
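
The new pass can also be called directly on an `ExportedProgram` before lowering. A minimal sketch; the `Tiny` module and its shapes are assumptions for illustration, not from the PR:

    import torch
    from executorch.exir.passes import remove_unused_parameters_pass

    class Tiny(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.used = torch.nn.Linear(4, 4)
            self.dead = torch.nn.Linear(8, 8)  # never referenced in forward()

        def forward(self, x):
            return self.used(x)

    ep = torch.export.export(Tiny().eval(), (torch.randn(1, 4),), strict=False)
    ep = remove_unused_parameters_pass(ep)

    # The dead parameters are gone from the state dict, signature, and graph.
    assert "dead.weight" not in ep.state_dict
    assert "dead.weight" not in ep.graph_signature.parameters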

exir/program/_program.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,7 @@
     EdgeToBackendOpsPass,
     MemoryFormatOpsPass,
     OpReplacePass,
+    remove_unused_parameters_pass,
 )
 from executorch.exir.passes.external_constants_pass import (
     external_constants_pass,
@@ -801,6 +802,9 @@ def _generate_edge_program(
     assert gm_res is not None
     gm = gm_res.graph_module

+    # Remove unused parameters
+    program = remove_unused_parameters_pass(program)
+
     if config._check_ir_validity:
         try:
             EXIRATenDialectVerifier(
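
With this hook in place, the cleanup happens automatically in the standard lowering flow. A minimal end-to-end sketch (the `Bloated` module is illustrative; the tests below bound the serialized size of a similar model at 10 KB):

    import torch
    from executorch.exir import to_edge

    class Bloated(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 16)
            self.unused = torch.nn.Linear(1024, 1024)  # ~4 MB of dead fp32 weights

        def forward(self, x):
            return self.linear(x)

    ep = torch.export.export(Bloated().eval(), (torch.randn(1, 16),), strict=False)
    executorch_program = to_edge(ep).to_executorch()

    # Without the pass, the .pte would carry the ~4 MB of unused weights;
    # with it, the serialized buffer stays small.
    print(len(executorch_program.buffer))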

exir/tests/TARGETS

Lines changed: 16 additions & 0 deletions
@@ -432,6 +432,22 @@ python_unittest(
     ],
 )

+python_unittest(
+    name = "test_remove_unused_parameters_pass",
+    srcs = [
+        "test_remove_unused_parameters_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/xnnpack:xnnpack_delegate",
+        "//executorch/exir:lib",
+        "//executorch/exir:memory",
+        "//executorch/exir/capture:config",
+        "//executorch/exir/passes:lib",
+        "//executorch/runtime:runtime",
+    ],
+)
+
 python_unittest(
     name = "test_remove_view_copy",
     srcs = [
exir/tests/test_remove_unused_parameters_pass.py

Lines changed: 205 additions & 0 deletions

@@ -0,0 +1,205 @@
import unittest
from typing import Sequence

import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge, to_edge_transform_and_lower
from executorch.exir.passes import remove_unused_parameters_pass
from executorch.runtime import Runtime
from torch.export import ExportedProgram


class TestRemoveUnusedParametersPass(unittest.TestCase):
    class SimpleModelWithUnusedParameters(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(16, 16)
            self.unused_linear = torch.nn.Linear(1024, 1024)

        def forward(self, x):
            return self.linear1(x)

    class NestedModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mod1 = TestRemoveUnusedParametersPass.SimpleModelWithUnusedParameters()
            self.mod2 = TestRemoveUnusedParametersPass.SimpleModelWithUnusedParameters()

        def forward(self, x):
            y = self.mod1(x) + self.mod2(x)
            y += self.mod1.unused_linear(x.repeat([1, 64]))[:, :16]
            return y

    def test_remove_unused_parameters_simple(self):
        model = self.SimpleModelWithUnusedParameters()
        model.eval()
        example_inputs = (torch.randn(1, 16),)
        eager_outputs = model(*example_inputs)
        ep = torch.export.export(model, example_inputs, strict=False)

        unused_param_names_and_args = {
            "unused_linear.weight": "p_unused_linear_weight",
            "unused_linear.bias": "p_unused_linear_bias",
        }

        self._test_pass(ep, unused_param_names_and_args, example_inputs, eager_outputs)

    def test_remove_unused_parameters_nested(self):
        model = self.NestedModel()
        model.eval()
        example_inputs = (torch.randn(1, 16),)
        eager_outputs = model(*example_inputs)
        ep = torch.export.export(model, example_inputs, strict=False)

        unused_param_names_and_args = {
            "mod2.unused_linear.weight": "p_mod2_unused_linear_weight",
            "mod2.unused_linear.bias": "p_mod2_unused_linear_bias",
        }

        self._test_pass(ep, unused_param_names_and_args, example_inputs, eager_outputs)

    def test_remove_unused_parameters_simple_e2e_to_edge(self):
        model = self.SimpleModelWithUnusedParameters().eval()
        example_inputs = (torch.randn(1, 16),)

        # There are approximately 1M unused fp32 parameters - ~4Mb.
        # Without the unused params, the expected size is ~2.5Kb.
        size_bound = 10000

        for strict in [False, True]:
            for delegate in [False, True]:
                self._test_pass_e2e(
                    model,
                    example_inputs,
                    strict=strict,
                    use_to_edge=True,
                    delegate=delegate,
                    size_bound=size_bound,
                )

    def test_remove_unused_parameters_simple_e2e_to_edge_transform_and_lower(self):
        model = self.SimpleModelWithUnusedParameters().eval()
        example_inputs = (torch.randn(1, 16),)

        # There are approximately 1M unused fp32 parameters - ~4Mb.
        # Without the unused params, the expected size is ~2.5Kb.
        size_bound = 10000

        for strict in [False, True]:
            for delegate in [False, True]:
                self._test_pass_e2e(
                    model,
                    example_inputs,
                    strict=strict,
                    use_to_edge=False,
                    delegate=delegate,
                    size_bound=size_bound,
                )

    def test_remove_unused_parameters_nested_e2e_to_edge(self):
        model = self.NestedModel().eval()
        example_inputs = (torch.randn(1, 16),)

        size_bound = 20000 + 1024 * 1024 * 4

        for strict in [False, True]:
            for delegate in [False, True]:
                self._test_pass_e2e(
                    model,
                    example_inputs,
                    strict=strict,
                    use_to_edge=True,
                    delegate=delegate,
                    size_bound=size_bound,
                )

    def test_remove_unused_parameters_nested_e2e_to_edge_transform_and_lower(self):
        model = self.SimpleModelWithUnusedParameters().eval()
        example_inputs = (torch.randn(1, 16),)

        size_bound = 20000 + 1024 * 1024 * 4

        for strict in [False, True]:
            for delegate in [False, True]:
                self._test_pass_e2e(
                    model,
                    example_inputs,
                    strict=strict,
                    use_to_edge=False,
                    delegate=delegate,
                    size_bound=size_bound,
                )

    def _test_pass(
        self,
        ep: ExportedProgram,
        unused_param_names_and_args: dict[str, str],
        example_inputs: Sequence[torch.Tensor],
        expected_outputs: torch.Tensor,
    ):
        # Verify EP state before running the pass.
        placeholders = {
            n.target for n in ep.graph_module.graph.nodes if n.op == "placeholder"
        }
        for param_name, param_arg in unused_param_names_and_args.items():
            self.assertIn(param_name, ep.state_dict.keys())
            self.assertIn(param_name, ep.graph_signature.parameters)
            self.assertIn(param_arg, placeholders)

        new_ep = remove_unused_parameters_pass(ep)

        # Verify that the unused params are not in the state dict,
        # graph signature, or graph.
        new_placeholders = {
            n.target for n in new_ep.graph_module.graph.nodes if n.op == "placeholder"
        }
        for param_name, param_arg in unused_param_names_and_args.items():
            self.assertNotIn(param_name, new_ep.state_dict.keys())
            self.assertNotIn(param_name, new_ep.graph_signature.parameters)
            self.assertNotIn(param_arg, new_placeholders)

        # Verify that the outputs are unchanged.
        new_outputs = new_ep.module()(*example_inputs)
        self.assertTrue(torch.allclose(new_outputs, expected_outputs))

    def _test_pass_e2e(
        self,
        model: torch.nn.Module,
        example_inputs: Sequence[torch.Tensor],
        strict: bool,
        use_to_edge: bool,
        delegate: bool,
        size_bound: int,
    ):
        eager_outputs = model(*example_inputs)
        ep = torch.export.export(model, example_inputs, strict=strict)

        if use_to_edge:
            lowered = to_edge(ep)
            if delegate:
                lowered = lowered.to_backend(XnnpackPartitioner())
        else:  # use to_edge_transform_and_lower
            lowered = to_edge_transform_and_lower(
                ep,
                partitioner=[XnnpackPartitioner()] if delegate else [],
            )

        lowered = lowered.to_executorch()
        self.assertLess(len(lowered.buffer), size_bound)

        # Make sure we can load and run the serialized .pte.
        runtime = Runtime.get()
        program = runtime.load_program(lowered.buffer)
        method = program.load_method("forward")
        runtime_outputs = method.execute([*example_inputs])

        self.assertEqual(1, len(runtime_outputs))
        self.assertTrue(
            torch.allclose(runtime_outputs[0], eager_outputs, atol=2e-6),
            "Values out of tolerance.\n"
            + f" Strict: {strict}, ToEdge: {use_to_edge}, Delegate: {delegate}.\n"
            + f" Eager: {eager_outputs}.\n"
            + f" Pybind: {runtime_outputs[0]}.\n"
            + f" Error: {eager_outputs - runtime_outputs[0]}",
        )
