
Commit 88b8531

Author: ssjia
Parents: 33ec615 + 3f34709

Update base for Update on "[etrecord] Implement generic fallback for GraphModuleSerializer.handle_call_function"

Title says it all! Implement the case where `node.target` is neither `torch._ops.OpOverload` nor `torch._ops.HigherOrderOperator`, instead of throwing an exception.

Differential Revision: [D88216198](https://our.internmc.facebook.com/intern/diff/D88216198/)

[ghstack-poisoned]
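The serializer change named in the title is not part of the base files shown below, so as a rough, hypothetical illustration only: the dispatch it describes distinguishes `torch._ops.OpOverload` and `torch._ops.HigherOrderOperator` targets and, rather than raising for anything else, falls back to identifying the callable generically (e.g. by qualified name). The helper below and its string return values are illustrative assumptions, not the serializer's actual API.

```python
import operator
import torch

def classify_call_function_target(target) -> str:
    """Report which serialization branch a call_function target would take (sketch)."""
    if isinstance(target, torch._ops.OpOverload):
        return "op_overload"                      # existing, specialized path
    if isinstance(target, torch._ops.HigherOrderOperator):
        return "higher_order_operator"            # existing, specialized path
    # Generic fallback: identify any other callable by its qualified name
    # instead of raising an exception.
    return f"generic: {target.__module__}.{target.__qualname__}"

# aten ops hit the OpOverload branch; a plain Python callable such as
# operator.getitem falls through to the generic fallback.
print(classify_call_function_target(torch.ops.aten.add.Tensor))  # op_overload
print(classify_call_function_target(operator.getitem))           # generic: _operator.getitem
```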

File tree: 10 files changed (+311, -312 lines)


backends/apple/coreml/runtime/delegate/ETCoreMLStrings.mm

Lines changed: 38 additions & 27 deletions
@@ -101,39 +101,50 @@ + (NSString *)debugSymbolToHandlesKeyName {
 }
 
 + (nullable NSString *)assetsDirectoryPath {
-    static dispatch_once_t onceToken;
-    static NSString *result = nil;
-    dispatch_once(&onceToken, ^{
-        NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSCachesDirectory, NSUserDomainMask, YES);
-        if (paths.count > 0) {
-            result = [paths.lastObject stringByAppendingPathComponent:self.productName];
-        }
-    });
-
-    return result;
+#if defined(EXECUTORCH_COREML_ASSETS_DIRECTORY_PATH)
+    return @(EXECUTORCH_COREML_ASSETS_DIRECTORY_PATH);
+#else
+    static dispatch_once_t onceToken;
+    static NSString *result = nil;
+    dispatch_once(&onceToken, ^{
+        NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSCachesDirectory, NSUserDomainMask, YES);
+        if (paths.count > 0) {
+            result = [paths.lastObject stringByAppendingPathComponent:self.productName];
+        }
+    });
+
+    return result;
+#endif
 }
 
 + (nullable NSString *)trashDirectoryPath {
-    static dispatch_once_t onceToken;
-    static NSString *result = nil;
-    dispatch_once(&onceToken, ^{
-        result = [NSTemporaryDirectory() stringByAppendingPathComponent:self.productName];
-    });
-
-    return result;
+#if defined(EXECUTORCH_COREML_TRASH_DIRECTORY_PATH)
+    return @(EXECUTORCH_COREML_TRASH_DIRECTORY_PATH);
+#else
+    static dispatch_once_t onceToken;
+    static NSString *result = nil;
+    dispatch_once(&onceToken, ^{
+        result = [NSTemporaryDirectory() stringByAppendingPathComponent:self.productName];
+    });
+
+    return result;
+#endif
 }
 
 + (nullable NSString *)databaseDirectoryPath {
-    static dispatch_once_t onceToken;
-    static NSString *result = nil;
-    dispatch_once(&onceToken, ^{
-        NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSApplicationSupportDirectory, NSUserDomainMask, YES);
-        if (paths.count > 0) {
-            result = [paths.lastObject stringByAppendingPathComponent:self.productName];
-        }
-    });
-
-    return result;
+#if defined(EXECUTORCH_COREML_DATABASE_DIRECTORY_PATH)
+    return @(EXECUTORCH_COREML_DATABASE_DIRECTORY_PATH);
+#else
+    static dispatch_once_t onceToken;
+    static NSString *result = nil;
+    dispatch_once(&onceToken, ^{
+        NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSApplicationSupportDirectory, NSUserDomainMask, YES);
+        if (paths.count > 0) {
+            result = [paths.lastObject stringByAppendingPathComponent:self.productName];
+        }
+    });
+    return result;
+#endif
 }
 
backends/cadence/aot/fuse_ops.py

Lines changed: 50 additions & 51 deletions
@@ -34,6 +34,7 @@
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
     register_cadence_pass,
+    RemoveOrReplacePassInterface,
 )
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -454,7 +455,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseCascadedTransposeOrPermuteOps(ExportPass):
+class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface):
     """
     Fuse a cascaded chain of transpose and permute ops
     """
@@ -464,63 +465,61 @@ class FuseCascadedTransposeOrPermuteOps(ExportPass):
         exir_ops.edge.aten.permute_copy.default,
     }
 
-    # Find a chain of transpose or permute ops, and fuse them into a single permute op.
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return list(self.transpose_or_permute_target)
 
-    def fuse_cascaded_transpose_or_permute_ops(
-        self, graph_module: torch.fx.GraphModule
-    ):
-        graph = graph_module.graph
-        for node in graph.nodes:
-            # We are only interested in permute/transpose ops
-            if node.target not in self.transpose_or_permute_target:
-                continue
-            # Get the cascaded chain of transpose/permute ops starting at node
-            cascaded_transpose_or_permute_ops = get_cascaded_ops(
-                [node], self.transpose_or_permute_target
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Get the cascaded chain of transpose/permute ops starting at node
+        cascaded_transpose_or_permute_ops = get_cascaded_ops(
+            [node], self.transpose_or_permute_target
+        )
+        # The chain must have more than 1 node
+        if len(cascaded_transpose_or_permute_ops) == 1:
+            return False
+
+        # Get shape from node metadata
+        val = node.meta.get("val")
+        if val is None:
+            return False
+        out_shape = val.shape
+        out_dims = len(out_shape)
+
+        # This is the trivial dimension order
+        dims = list(range(out_dims))
+        # Compute the effect of the chain on dims
+        for tp in cascaded_transpose_or_permute_ops:
+            dims = (
+                get_transposed_dims(tp, dims)
+                if tp.target == exir_ops.edge.aten.transpose_copy.int
+                else get_permuted_dims(tp, dims)
             )
-            # The chain must have more than 1 node
-            if len(cascaded_transpose_or_permute_ops) == 1:
-                continue
 
-            out_shape = get_shape(graph_module, node)
-            assert out_shape is not None
-            out_dims = len(out_shape)
-            # This is the trivial dimension order
-            dims = list(range(out_dims))
-            # Compute the effect of the chain on dims
-            for tp in cascaded_transpose_or_permute_ops:
-                dims = (
-                    get_transposed_dims(tp, dims)
-                    if tp.target == exir_ops.edge.aten.transpose_copy.int
-                    else get_permuted_dims(tp, dims)
-                )
+        graph = node.graph
 
-            # In case the permute chain cancelled each other, the final dims will
-            # be the same as the initial order. In that case, the chain was nop.
-            # Otherwise create a new permute op that encompasses the effect of the
-            # chain.
-            if dims == list(range(out_dims)):
-                cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(
-                    node.args[0]
+        # In case the permute chain cancelled each other, the final dims will
+        # be the same as the initial order. In that case, the chain was nop.
+        # Otherwise create a new permute op that encompasses the effect of the
+        # chain.
+        if dims == list(range(out_dims)):
+            cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(
+                cast(torch.fx.Node, node.args[0])
+            )
+        else:
+            with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]):
+                new_permute = graph.call_function(
+                    exir_ops.edge.aten.permute_copy.default,
+                    args=(node.args[0], dims),
                 )
-            else:
-                with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]):
-                    new_permute = graph.call_function(
-                        exir_ops.edge.aten.permute_copy.default,
-                        args=(node.args[0], dims),
-                    )
-                cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute)
+            new_permute.meta = cascaded_transpose_or_permute_ops[-1].meta
+            cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute)
 
-            # Now erase the chain
-            for tp in reversed(cascaded_transpose_or_permute_ops):
-                graph.erase_node(tp)
-
-        graph_module.recompile()
+        # Now erase the chain (except the first node which will be handled by the interface)
+        for tp in reversed(cascaded_transpose_or_permute_ops[1:]):
+            graph.erase_node(tp)
 
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self.fuse_cascaded_transpose_or_permute_ops(graph_module)
-        result = super().call(graph_module)
-        return result
+        # Return True to indicate the first node in the chain should be removed
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
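As a side note on the fusion logic above: the pass folds a chain of transpose/permute ops by composing their dimension orders, starting from the trivial order, and only materializes a single `permute_copy` when the composed order is not the identity. Below is a standalone sketch of that composition with a hypothetical helper `compose_permutes`, not the pass's own `get_permuted_dims`/`get_transposed_dims`.

```python
import torch

def compose_permutes(num_dims: int, orders: list[list[int]]) -> list[int]:
    """Fold a chain of permute orders into one equivalent order."""
    dims = list(range(num_dims))          # the trivial (identity) dimension order
    for order in orders:
        dims = [dims[i] for i in order]   # effect of one more permute on the chain
    return dims

x = torch.randn(2, 3, 4)
chain = [[2, 0, 1], [0, 2, 1]]            # two cascaded permutes
fused = compose_permutes(x.dim(), chain)

# Applying the chain op-by-op matches the single fused permute. If `fused`
# came out as [0, 1, 2], the chain would be a no-op and could be dropped.
assert torch.equal(x.permute(chain[0]).permute(chain[1]), x.permute(fused))
print(fused)  # [2, 1, 0]
```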
