Skip to content

Commit ca3b217

Browse files
author
jorgep31415
committed
Update on "[ET-VK] Move save_cache from Runtime dtor to model destroy"
## Issue `Runtime` is a local static variable. Hence we'd expect the Runtime dtor to be called on program exit. But on Android devices it's not being invoked. This behavior is different than that seen 6 months ago (D57085281). It's unclear what changed. This means the cache is not saved due to the following chain never being invoked. `~Runtime()` > `~Adapter()` > `~ComputePipelineCache()` > `save_cache()`.\ ## Solution Move cache saving to `VulkanBackend.cpp`'s model destroy. This makes sense since the cache is tied to the model and not the runtime. ## Resources https://medium.com/martin00001313/mastering-static-objects-in-c-initialization-destruction-and-best-practices-760b17734195 Differential Revision: [D66179917](https://our.internmc.facebook.com/intern/diff/D66179917/) [ghstack-poisoned]
2 parents a0909c2 + 75cf8a4 commit ca3b217

File tree

10 files changed

+248
-13
lines changed

10 files changed

+248
-13
lines changed

.github/scripts/check_labels.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ def main() -> None:
4545

4646
try:
4747
if not has_required_labels(pr):
48-
print(LABEL_ERR_MSG)
48+
print(LABEL_ERR_MSG, flush=True)
4949
add_label_err_comment(pr)
5050
if args.exit_non_zero:
51-
sys.exit(1)
51+
raise RuntimeError("PR does not have required labels")
5252
else:
5353
delete_all_label_err_comments(pr)
5454
except Exception as e:
5555
if args.exit_non_zero:
56-
sys.exit(1)
56+
raise RuntimeError(f"Error checking labels: {e}") from e
5757

5858
sys.exit(0)
5959

.github/scripts/github_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ def gh_fetch_url(
7272
headers: Optional[Dict[str, str]] = None,
7373
data: Union[Optional[Dict[str, Any]], str] = None,
7474
method: Optional[str] = None,
75-
reader: Callable[[Any], Any] = lambda x: x.read(),
75+
reader: Callable[[Any], Any] = json.load,
7676
) -> Any:
7777
return gh_fetch_url_and_headers(
78-
url, headers=headers, data=data, reader=json.load, method=method
78+
url, headers=headers, data=data, reader=reader, method=method
7979
)[1]
8080

8181

@@ -169,7 +169,7 @@ def gh_post_commit_comment(
169169

170170
def gh_delete_comment(org: str, repo: str, comment_id: int) -> None:
171171
url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/comments/{comment_id}"
172-
gh_fetch_url(url, method="DELETE")
172+
gh_fetch_url(url, method="DELETE", reader=lambda x: x.read())
173173

174174

175175
def gh_fetch_merge_base(org: str, repo: str, base: str, head: str) -> str:

backends/arm/operators/op_add.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def define_node(
8282

8383
if needs_rescale:
8484
# Scale output back to 8 bit
85+
# pyre-ignore
8586
tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
8687

8788

backends/cadence/aot/TARGETS

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,18 @@ python_library(
132132
],
133133
)
134134

135+
python_library(
136+
name = "graph_builder",
137+
srcs = [
138+
"graph_builder.py",
139+
],
140+
typing = True,
141+
deps = [
142+
"fbcode//caffe2:torch",
143+
"fbcode//executorch/exir:pass_base",
144+
],
145+
)
146+
135147
python_library(
136148
name = "fuse_ops",
137149
srcs = [
@@ -150,3 +162,20 @@ python_library(
150162
"//executorch/exir/passes:spec_prop_pass",
151163
],
152164
)
165+
166+
python_unittest(
167+
name = "test_graph_builder",
168+
srcs = [
169+
"tests/test_graph_builder.py",
170+
],
171+
typing = True,
172+
deps = [
173+
"//caffe2:torch",
174+
"//executorch/backends/cadence/aot:graph_builder",
175+
"//executorch/backends/cadence/aot:pass_utils",
176+
"//executorch/exir:pass_base",
177+
"//executorch/exir/dialects:lib",
178+
"//later:lib",
179+
":ops_registrations"
180+
],
181+
)

backends/cadence/aot/compiler.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,26 @@ def export_to_edge(
196196
# Export the model and lower it to an EdgeProgramManager (in edge IR), and
197197
# apply passes specific to Cadence DSP execution. Return both to print the
198198
# differences.
199-
def export_to_cadence_edge_executorch(
199+
def export_to_cadence(
200+
model: torch.nn.Module,
201+
inputs: tuple[object, ...],
202+
dump_graphs: bool = False,
203+
output_dir: Optional[str] = None,
204+
opt_level: int = 1,
205+
) -> EdgeProgramManager:
206+
edge_prog_manager = export_to_edge(model, inputs)
207+
cadence_passes = get_cadence_passes(opt_level)
208+
209+
# Run a couple required passes for quant/dequant ops
210+
cadence_prog_manager = edge_prog_manager.transform(
211+
cast(
212+
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
213+
)
214+
)
215+
return cadence_prog_manager
216+
217+
218+
def export_to_executorch_gen_etrecord(
200219
model: torch.nn.Module,
201220
inputs: tuple[object, ...],
202221
dump_graphs: bool = False,

backends/cadence/aot/export_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from executorch.backends.cadence.aot.compiler import (
1818
convert_pt2,
19-
export_to_cadence_edge_executorch,
19+
export_to_executorch_gen_etrecord,
2020
fuse_pt2,
2121
)
2222

@@ -86,8 +86,8 @@ def export_model(
8686
quantized_model = fuse_pt2(converted_model, quantizer)
8787

8888
# Get edge program after Cadence specific passes
89-
exec_prog: ExecutorchProgramManager = export_to_cadence_edge_executorch(
90-
quantized_model, example_inputs, working_dir
89+
exec_prog: ExecutorchProgramManager = export_to_executorch_gen_etrecord(
90+
quantized_model, example_inputs, output_dir=working_dir
9191
)
9292

9393
logging.info("Final exported graph:\n")
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
import logging
6+
from typing import Optional, Sequence, Union
7+
8+
import torch
9+
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
10+
from torch._subclasses import FakeTensor, FakeTensorMode
11+
from torch.fx.node import Argument, Target
12+
from torch.utils import _pytree as pytree
13+
14+
15+
class GraphBuilder(ExportPass):
16+
"""Utility class for creating a graph module with user-specified ops.
17+
18+
This class allows us to create test graph modules with any ops we want
19+
directly, rather than relying on decomposition or passes.
20+
21+
Usage:
22+
builder = GraphBuilder()
23+
# To insert placeholders, use builder.placeholder.
24+
x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
25+
# To insert an op, use builder.call_operator.
26+
op = builder.call_operator(
27+
some_op
28+
(x, other_args, ...),
29+
)
30+
# Insert outputs as a list of ProxyValues using builder.output.
31+
builder.output([op])
32+
# Get GraphModule from builder.
33+
gm = builder.get_graph_module()
34+
"""
35+
36+
def __init__(self) -> None:
37+
self.exporter = ExportPass()
38+
self.tracer: ExportPass.ExportTracer = self.ExportTracer(
39+
self, torch.fx.graph.CodeGen()
40+
)
41+
self.fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
42+
self.tracer.fake_tensor_mode = self.fake_tensor_mode
43+
44+
# This will be called to create nodes in tracer.
45+
self.interpreter = torch.fx.Interpreter(
46+
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
47+
)
48+
49+
# pyre-ignore[14]: Inconsistent override.
50+
def placeholder(
51+
self, target: str, fake_tensor: Union[FakeTensor, torch.Tensor]
52+
) -> ProxyValue:
53+
if not isinstance(fake_tensor, FakeTensor):
54+
fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor)
55+
logging.info(f"Creating placeholder {target} => {fake_tensor.shape}")
56+
placeholder = super().placeholder(target, fake_tensor, NodeMetadata({}))
57+
return placeholder
58+
59+
# pyre-ignore[14]: Inconsistent override.
60+
def output(self, results: list[ProxyValue]) -> ProxyValue:
61+
logging.info(f"Creating outputs {results}")
62+
return super().output(results, NodeMetadata({}))
63+
64+
def get_graph_module(self) -> torch.fx.GraphModule:
65+
return torch.fx.GraphModule(self.tracer.root, self.tracer.graph)
66+
67+
def call_operator(
68+
self,
69+
op, # pyre-ignore
70+
args: tuple[Argument, ...],
71+
kwargs: Optional[dict[str, Argument]] = None,
72+
meta: Optional[NodeMetadata] = None,
73+
) -> ProxyValue:
74+
if meta is None:
75+
meta = NodeMetadata({})
76+
if kwargs is None:
77+
kwargs = {}
78+
return super().call_operator(op, args, kwargs, meta)
79+
80+
81+
def single_op_builder(
82+
placeholders: Sequence[Union[torch.Tensor, FakeTensor]],
83+
op: Target,
84+
args: Sequence[Argument],
85+
kwargs: Optional[dict[str, Argument]] = None,
86+
) -> torch.fx.GraphModule:
87+
"""Create a graph module with a single op.
88+
89+
Args:
90+
placeholders: Placeholders to be used as inputs to the GraphModule.
91+
op: The op to be inserted.
92+
args: The args to be passed to the op.
93+
kwargs: The kwargs to be passed to the op.
94+
95+
Returns:
96+
A graph module with a single op
97+
"""
98+
builder = GraphBuilder()
99+
op_to_placeholder_dict = {
100+
p: builder.placeholder(f"p_{i}", p) for i, p in enumerate(placeholders)
101+
}
102+
proxy_args, proxy_kwargs = pytree.tree_map_only(
103+
(torch.Tensor, FakeTensor), lambda x: op_to_placeholder_dict[x], (args, kwargs)
104+
)
105+
node = builder.call_operator(op, proxy_args, proxy_kwargs)
106+
builder.output([node])
107+
return builder.get_graph_module()

backends/cadence/aot/pass_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,12 @@ def get_node_names_list_from_gm(
8989
continue
9090
graph_nodes.append(node.name)
9191
return graph_nodes
92+
93+
94+
def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) -> int:
95+
"""Count the number of nodes with target `target` in the graph."""
96+
total = 0
97+
for node in graph_module.graph.nodes:
98+
if node.op == "call_function" and node.target == target:
99+
total += 1
100+
return total
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
4+
import executorch.backends.cadence.aot.ops_registrations # noqa
5+
import torch
6+
from executorch.backends.cadence.aot.graph_builder import (
7+
GraphBuilder,
8+
single_op_builder,
9+
)
10+
from executorch.backends.cadence.aot.pass_utils import count_node
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass
13+
from later.unittest import TestCase
14+
15+
16+
class TestGraphBuilder(TestCase):
17+
def test_graph_with_single_im2row(self) -> None:
18+
# Create a graph with a single im2row node.
19+
builder = GraphBuilder()
20+
x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
21+
pad_value = builder.placeholder("pad", torch.randn(1))
22+
channels_last = False
23+
im2row = builder.call_operator(
24+
exir_ops.edge.cadence.im2row.default,
25+
# pyre-ignore
26+
(
27+
x,
28+
(2, 2),
29+
(1, 1),
30+
(0, 0),
31+
(1, 1),
32+
pad_value,
33+
channels_last,
34+
),
35+
)
36+
builder.output([im2row])
37+
gm = builder.get_graph_module()
38+
# Check if graph module is valid by running exportpass on it.
39+
gm = ExportPass().call(gm).graph_module
40+
41+
# Check graph has a single im2row node.
42+
self.assertEqual(len([gm.graph.nodes]), 1)
43+
self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
44+
45+
46+
class TestSingleOpBuilderUtility(TestCase):
47+
def test_graph_with_single_im2row(self) -> None:
48+
# Create a graph with a single im2row node.
49+
x = torch.randn(1, 3, 224, 224)
50+
pad_value = torch.randn(1)
51+
channels_last = False
52+
gm = single_op_builder(
53+
(x, pad_value),
54+
exir_ops.edge.cadence.im2row.default,
55+
(
56+
x,
57+
(2, 2),
58+
(1, 1),
59+
(0, 0),
60+
(1, 1),
61+
pad_value,
62+
channels_last,
63+
),
64+
)
65+
# Check if graph module is valid by running exportpass on it.
66+
gm = ExportPass().call(gm).graph_module
67+
68+
# Check graph has a single im2row node.
69+
self.assertEqual(len([gm.graph.nodes]), 1)
70+
self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)

examples/demo-apps/apple_ios/ExecuTorchDemo/ExecuTorchDemo/Sources/MobileNet/Test/MobileNetClassifierTest.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ final class MobileNetClassifierTest: XCTestCase {
4848
modelFilePath: modelFilePath,
4949
labelsFilePath: labelsFilePath)
5050
for expectedClassification in [
51-
Classification(label: "Arctic fox", confidence: 0.92),
52-
Classification(label: "Samoyed", confidence: 0.74),
53-
Classification(label: "hot pot", confidence: 0.82),
51+
Classification(label: "Arctic fox", confidence: 0.9),
52+
Classification(label: "Samoyed", confidence: 0.7),
53+
Classification(label: "hot pot", confidence: 0.8),
5454
] {
5555
guard
5656
let imagePath = Bundle(for: type(of: self))

0 commit comments

Comments
 (0)