Skip to content

Commit 8155889

Browse files
authored
Merge branch 'main' into jgibson/grid_sample_2d_nchw_portable_kernels
2 parents 3512a18 + c00d726 commit 8155889

35 files changed

+818
-432
lines changed

backends/apple/coreml/runtime/delegate/ETCoreMLStrings.mm

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -101,39 +101,50 @@ + (NSString *)debugSymbolToHandlesKeyName {
101101
}
102102

103103
+ (nullable NSString *)assetsDirectoryPath {
104-
static dispatch_once_t onceToken;
105-
static NSString *result = nil;
106-
dispatch_once(&onceToken, ^{
107-
NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSCachesDirectory, NSUserDomainMask, YES);
108-
if (paths.count > 0) {
109-
result = [paths.lastObject stringByAppendingPathComponent:self.productName];
110-
}
111-
});
112-
113-
return result;
104+
#if defined(EXECUTORCH_COREML_ASSETS_DIRECTORY_PATH)
105+
return @(EXECUTORCH_COREML_ASSETS_DIRECTORY_PATH);
106+
#else
107+
static dispatch_once_t onceToken;
108+
static NSString *result = nil;
109+
dispatch_once(&onceToken, ^{
110+
NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSCachesDirectory, NSUserDomainMask, YES);
111+
if (paths.count > 0) {
112+
result = [paths.lastObject stringByAppendingPathComponent:self.productName];
113+
}
114+
});
115+
116+
return result;
117+
#endif
114118
}
115119

116120
+ (nullable NSString *)trashDirectoryPath {
117-
static dispatch_once_t onceToken;
118-
static NSString *result = nil;
119-
dispatch_once(&onceToken, ^{
120-
result = [NSTemporaryDirectory() stringByAppendingPathComponent:self.productName];
121-
});
122-
123-
return result;
121+
#if defined(EXECUTORCH_COREML_TRASH_DIRECTORY_PATH)
122+
return @(EXECUTORCH_COREML_TRASH_DIRECTORY_PATH);
123+
#else
124+
static dispatch_once_t onceToken;
125+
static NSString *result = nil;
126+
dispatch_once(&onceToken, ^{
127+
result = [NSTemporaryDirectory() stringByAppendingPathComponent:self.productName];
128+
});
129+
130+
return result;
131+
#endif
124132
}
125133

126134
+ (nullable NSString *)databaseDirectoryPath {
127-
static dispatch_once_t onceToken;
128-
static NSString *result = nil;
129-
dispatch_once(&onceToken, ^{
130-
NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSApplicationSupportDirectory, NSUserDomainMask, YES);
131-
if (paths.count > 0) {
132-
result = [paths.lastObject stringByAppendingPathComponent:self.productName];
133-
}
134-
});
135-
136-
return result;
135+
#if defined(EXECUTORCH_COREML_DATABASE_DIRECTORY_PATH)
136+
return @(EXECUTORCH_COREML_DATABASE_DIRECTORY_PATH);
137+
#else
138+
static dispatch_once_t onceToken;
139+
static NSString *result = nil;
140+
dispatch_once(&onceToken, ^{
141+
NSArray<NSString *> *paths = NSSearchPathForDirectoriesInDomains(NSApplicationSupportDirectory, NSUserDomainMask, YES);
142+
if (paths.count > 0) {
143+
result = [paths.lastObject stringByAppendingPathComponent:self.productName];
144+
}
145+
});
146+
return result;
147+
#endif
137148
}
138149

139150

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
1010
from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass # noqa
1111
from .broadcast_args_pass import BroadcastArgsPass # noqa
12-
from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa
1312
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
1413
from .cast_to_int32_pass import CastToInt32Pass # noqa
1514
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
@@ -101,6 +100,7 @@
101100
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
102101
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
103102
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
103+
from .promote_bool_operands_pass import PromoteBoolOperandsPass # noqa
104104
from .remove_getitem_pass import RemoveGetItemPass # noqa
105105
from .remove_graph_asserts_pass import RemoveGraphAssertsPass # noqa
106106
from .remove_noop_pass import RemoveNoopPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
AnnotateDecomposedMatmulPass,
1515
AnnotateOutputDimOrderPass,
1616
BroadcastArgsPass,
17-
CastBoolToInt8Pass,
1817
CastInt64BuffersToInt32Pass,
1918
CastToInt32Pass,
2019
ComputeConstantOpsAOTPass,
@@ -93,6 +92,7 @@
9392
InsertTableOpsPass,
9493
MatchArgDtypePass,
9594
MatchArgRanksPass,
95+
PromoteBoolOperandsPass,
9696
QuantizeClampArgumentsPass,
9797
RemoveGetItemPass,
9898
RemoveGraphAssertsPass,
@@ -218,7 +218,7 @@ def _tosa_pipeline(
218218
DecomposeEluPass(),
219219
DecomposeExpm1Pass(),
220220
DecomposeIntPowPass(),
221-
CastBoolToInt8Pass(),
221+
PromoteBoolOperandsPass(),
222222
DecomposeSinhPass(),
223223
DecomposeSignPass(),
224224
DecomposeFloorDividePass(),
@@ -330,7 +330,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
330330
DecomposeScaledDotProductAttentionPass(),
331331
DecomposeRoundPass(),
332332
DecomposeLogitPass(),
333-
CastBoolToInt8Pass(),
333+
PromoteBoolOperandsPass(),
334334
DecomposeSignPass(),
335335
DecomposeAddmmPass(),
336336
DecomposeRemainderPass(),

backends/arm/_passes/cast_bool_to_int8_pass.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

backends/arm/_passes/fuse_equal_placeholders_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
6060
is_int48,
6161
str(t_cpu.dtype),
6262
tuple(t_cpu.shape),
63-
hashlib.sha1(data_bytes).hexdigest(),
63+
hashlib.sha1(data_bytes, usedforsecurity=False).hexdigest(),
6464
)
6565
hash_buckets[key].append((node, t_cpu))
6666

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool inputs.
7+
# When a targeted op receives boolean tensors, we promote them to an integer type before
8+
# invocation and cast the result back to the expected dtype afterwards.
9+
10+
from typing import Set, Type
11+
12+
import torch
13+
14+
from executorch.backends.arm._passes.arm_pass import ArmPass
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
from executorch.exir.pass_base import ExportPass
17+
18+
19+
class PromoteBoolOperandsPass(ArmPass):
20+
"""Promote boolean operands to the appropriate integer dtype for unsupported ops."""
21+
22+
_passes_required_after: Set[Type[ExportPass]] = set()
23+
24+
targeted_ops = {
25+
exir_ops.edge.aten.bitwise_and.Tensor,
26+
exir_ops.edge.aten.bitwise_or.Tensor,
27+
exir_ops.edge.aten.bitwise_xor.Tensor,
28+
exir_ops.edge.aten.mul.Tensor,
29+
}
30+
31+
def call_operator(self, op, args, kwargs, meta):
32+
if op not in self.targeted_ops:
33+
return super().call_operator(op, args, kwargs, meta)
34+
35+
original_dtypes = [arg.data.dtype for arg in args]
36+
if torch.bool not in original_dtypes:
37+
return super().call_operator(op, args, kwargs, meta)
38+
39+
# select the first non-bool dtype, or None if all bool
40+
promoted_dtype = next((dt for dt in original_dtypes if dt != torch.bool), None)
41+
42+
# if we don't have a dtype specified by the op, promote to default choice for the op
43+
if promoted_dtype is None:
44+
if op == exir_ops.edge.aten.mul.Tensor:
45+
# mul as int32
46+
promoted_dtype = torch.int32
47+
else:
48+
# bitwise ops can be int8
49+
promoted_dtype = torch.int8
50+
51+
target_dtypes = []
52+
for dt in original_dtypes:
53+
if dt == torch.bool:
54+
target_dtypes.append(promoted_dtype)
55+
else:
56+
target_dtypes.append(dt)
57+
58+
new_args = []
59+
for arg, original_dtype, target_dtype in zip(
60+
args, original_dtypes, target_dtypes
61+
):
62+
if original_dtype == target_dtype:
63+
new_args.append(arg)
64+
else:
65+
new_args.append(
66+
super().call_operator(
67+
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
68+
(arg,),
69+
{"dtype": target_dtype},
70+
meta,
71+
)
72+
)
73+
74+
output = super().call_operator(
75+
op,
76+
tuple(new_args),
77+
kwargs,
78+
meta,
79+
)
80+
81+
if all(dtype == torch.bool for dtype in original_dtypes):
82+
output = super().call_operator(
83+
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
84+
(output,),
85+
{"dtype": torch.bool},
86+
meta,
87+
)
88+
return output

backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from executorch.exir import ExportedProgram
1212
from executorch.exir.pass_base import ExportPass, PassResult
1313
from torch._export.utils import is_buffer, is_param
14+
from torch.export.graph_signature import InputKind
1415

1516

1617
class UnsqueezeScalarPlaceholdersPass(ArmPass):
@@ -42,17 +43,30 @@ def call(self, graph_module: torch.fx.GraphModule):
4243
else:
4344
continue
4445

45-
tensor = self.exported_program.state_dict[name]
46+
tensor = self.exported_program.state_dict.get(name)
4647

48+
# If we have a persistent=False buffer with no entry in state_dict
49+
spec = next(
50+
s
51+
for s in self.exported_program.graph_signature.input_specs
52+
if getattr(s.arg, "name", None) == node.name
53+
)
54+
is_non_persistent_buffer = (
55+
spec.kind is InputKind.BUFFER and spec.persistent is False
56+
)
57+
if tensor is None and is_non_persistent_buffer:
58+
fake = node.meta["val"]
59+
tensor = torch.ones_like(fake)
60+
61+
# If we have a scalar, unsqueeze it
4762
if tensor.dim() == 0:
48-
self.exported_program.state_dict[name] = tensor.unsqueeze(0)
49-
node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
50-
tensor.unsqueeze(0), static_shapes=True
51-
)
52-
else:
53-
node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
54-
tensor, static_shapes=True
55-
)
63+
tensor = tensor.unsqueeze(0)
64+
65+
# update or create entry in state_dict, recreate fake
66+
self.exported_program.state_dict[name] = tensor
67+
node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
68+
tensor, static_shapes=True
69+
)
5670

5771
graph_module.recompile()
5872
graph_module = super().call(graph_module).graph_module

backends/arm/test/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def pytest_addoption(parser):
4040
def try_addoption(*args, **kwargs):
4141
try:
4242
parser.addoption(*args, **kwargs)
43-
except Exception:
43+
except Exception: # nosec B110 - pytest redefines options, safe to ignore
4444
pass
4545

4646
try_addoption("--arm_quantize_io", action="store_true", help="Deprecated.")
@@ -85,7 +85,7 @@ def set_random_seed():
8585

8686
if os.environ.get("ARM_TEST_SEED", "RANDOM") == "RANDOM":
8787
random.seed() # reset seed, in case any other test has fiddled with it
88-
seed = random.randint(0, 2**32 - 1)
88+
seed = random.randint(0, 2**32 - 1) # nosec B311 - non-crypto seed for tests
8989
torch.manual_seed(seed)
9090
else:
9191
seed_str = os.environ.get("ARM_TEST_SEED", "0")

backends/arm/test/misc/test_debug_hook.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _get_action_str() -> str:
3131
name="convolution",
3232
target="aten.convolution.default",
3333
graph_id=6052414368,
34-
pass_name="ExportedProgram.module()",
34+
pass_name="ExportedProgram.module()", # nosec B106 - static test string, not a secret
3535
action="create",
3636
from_node=[],
3737
_get_action_string=_get_action_str,
@@ -41,7 +41,7 @@ def _get_action_str() -> str:
4141
name="convolution",
4242
target="aten.convolution.default",
4343
graph_id=5705954832,
44-
pass_name="Interpreter_PropagateUnbackedSymInts",
44+
pass_name="Interpreter_PropagateUnbackedSymInts", # nosec B106 - static test string, not a secret
4545
action="create",
4646
from_node=[from_node_2],
4747
_get_action_string=_get_action_str,
@@ -69,7 +69,7 @@ def _get_action_str() -> str:
6969
name="convolution",
7070
target="aten.convolution.default",
7171
graph_id=5705954832,
72-
pass_name="Interpreter_PropagateUnbackedSymInts",
72+
pass_name="Interpreter_PropagateUnbackedSymInts", # nosec B106 - static test string, not a secret
7373
action="create",
7474
from_node=[],
7575
_get_action_string=_get_action_str,

backends/arm/test/misc/test_save_exported_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def test_save_load_exported_int_model():
4848
torch.export.save(quantized_exported_module, file_path)
4949

5050
# Verify that we can load the model back
51-
loaded_model = torch.export.load(file_path)
51+
loaded_model = torch.export.load(
52+
file_path
53+
) # nosec B614 - loads trusted test artifact
5254
for original_node, loaded_node in zip(
5355
quantized_exported_module.graph.nodes, loaded_model.graph.nodes
5456
):

0 commit comments

Comments
 (0)