Skip to content

Commit 4f4611b

Browse files
committed
Update on "make to_edge_transform_and_lower support etrecord generation"
Differential Revision: [D79336982](https://our.internmc.facebook.com/intern/diff/D79336982/) umbrella issue: #12961 [ghstack-poisoned]
2 parents 4b786bd + cc9df36 commit 4f4611b

File tree

10 files changed

+154
-70
lines changed

10 files changed

+154
-70
lines changed

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 80 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
PartitionResult,
2121
)
2222
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
23+
from executorch.exir.dialects._ops import ops as exir_ops
2324
from torch.export.exported_program import ExportedProgram
2425
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
2526
from torch.fx.passes.operator_support import OperatorSupportBase
@@ -56,6 +57,80 @@ def log_once(self, msg: str) -> None:
5657
logger.info(msg)
5758
self._logged_msgs.add(msg)
5859

60+
def should_skip_op_for_delegation(self, node_target_name: str) -> bool:
61+
skipped_ops = self.skip_ops_for_coreml_delegation or []
62+
if node_target_name in skipped_ops:
63+
assert (
64+
not self.lower_full_graph
65+
), f"Cannot skip {node_target_name} because lower_full_graph is True. Please set skip_ops_for_coreml_delegation=None or lower_full_graph=False in the CoreMLPartitioner"
66+
self.log_once(
67+
"Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
68+
+ node_target_name
69+
)
70+
return True
71+
return False
72+
73+
def should_override_support(self, node) -> bool:
74+
# https://github.com/apple/coremltools/issues/2573
75+
if (
76+
node.target
77+
in [
78+
torch.ops.aten.sub.Tensor,
79+
exir_ops.edge.aten.sub.Tensor,
80+
torch.ops.aten.add.Tensor,
81+
exir_ops.edge.aten.add.Tensor,
82+
]
83+
and "alpha" in node.kwargs
84+
and node.kwargs["alpha"] != 1
85+
):
86+
self.log_once(
87+
"torch.ops.aten.{sub, add}.Tensor with alpha != 1 is not supported by CoreML. Overriding support."
88+
)
89+
return True
90+
91+
# https://github.com/apple/coremltools/issues/2565
92+
if node.target in [
93+
torch.ops.aten.diagonal.default,
94+
torch.ops.aten.diagonal_copy.default,
95+
exir_ops.edge.aten.diagonal.default,
96+
exir_ops.edge.aten.diagonal_copy.default,
97+
]:
98+
self.log_once(
99+
"torch.ops.aten.diagonal.default has a bug in CoreML. Overriding op support."
100+
)
101+
return True
102+
103+
# https://github.com/apple/coremltools/issues/2569
104+
if node.target in [
105+
torch.ops.aten.acosh.default,
106+
exir_ops.edge.aten.acosh.default,
107+
torch.ops.aten.asinh.default,
108+
exir_ops.edge.aten.asinh.default,
109+
]:
110+
self.log_once(
111+
"torch.ops.aten.{acosh, asinh}.default is not supported by CoreML. Overriding op support."
112+
)
113+
return True
114+
115+
# TODO: enable this after bugs in ExecuTorch's partitioner are fixed
116+
# # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
117+
# # in the placeholders due to partitioning, which CoreML does not support
118+
# if not self.lower_full_graph and any(
119+
# isinstance(arg, torch.fx.Node)
120+
# and isinstance(
121+
# arg.meta.get("val", None),
122+
# (torch.SymInt, torch.SymBool, torch.SymFloat),
123+
# )
124+
# for arg in node.args
125+
# ):
126+
# self.log_once(
127+
# "Skipping op for CoreML delegation because it contains symbolic args: "
128+
# + node_target_name
129+
# )
130+
# return True
131+
132+
return False
133+
59134
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
60135
# get_attr node can always be supported on any backend
61136
if node.op == "get_attr":
@@ -64,38 +139,17 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
64139
elif node.op == "call_function":
65140
# skip ops if specified by user
66141
node_target_name = getattr(node.target, "__name__", "").lower()
67-
if node_target_name in (self.skip_ops_for_coreml_delegation or []):
68-
self.log_once(
69-
"Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
70-
+ node_target_name
71-
)
72-
assert (
73-
not self.lower_full_graph
74-
), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
75-
return False
76142

77-
# TODO: enable this after bugs in ExecuTorch's partitioner are fixed
78-
# # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
79-
# # in the placeholders due to partitioning, which CoreML does not support
80-
# if not self.lower_full_graph and any(
81-
# isinstance(arg, torch.fx.Node)
82-
# and isinstance(
83-
# arg.meta.get("val", None),
84-
# (torch.SymInt, torch.SymBool, torch.SymFloat),
85-
# )
86-
# for arg in node.args
87-
# ):
88-
# self.log_once(
89-
# "Skipping op for CoreML delegation because it contains symbolic args: "
90-
# + node_target_name
91-
# )
92-
# assert not self.lower_full_graph
93-
# return False
143+
if self.should_skip_op_for_delegation(node_target_name):
144+
return False
94145

95146
# query coremltools to see if node is supported
96147
is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(
97148
node
98149
)
150+
if self.should_override_support(node):
151+
is_supported = False
152+
99153
if not is_supported:
100154
if self.lower_full_graph:
101155
raise NotImplementedError(
@@ -126,7 +180,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
126180

127181

128182
class CoreMLPartitioner(Partitioner):
129-
130183
def __init__(
131184
self,
132185
*,

backends/arm/test/ops/test_asinh.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
from executorch.backends.arm.test import common
1111
from executorch.backends.arm.test.tester.test_pipeline import (
12-
EthosU55PipelineBI,
13-
EthosU85PipelineBI,
14-
TosaPipelineBI,
15-
TosaPipelineMI,
12+
EthosU55PipelineINT,
13+
EthosU85PipelineINT,
14+
TosaPipelineFP,
15+
TosaPipelineINT,
1616
)
1717

1818
input_t = Tuple[torch.Tensor] # Input x
@@ -36,8 +36,8 @@ def forward(self, x):
3636

3737

3838
@common.parametrize("test_data", test_data_suite)
39-
def test_asin_tosa_MI(test_data: Tuple):
40-
pipeline = TosaPipelineMI[input_t](
39+
def test_asinh_tosa_FP(test_data: Tuple):
40+
pipeline = TosaPipelineFP[input_t](
4141
Asinh(),
4242
(test_data(),),
4343
aten_op,
@@ -47,8 +47,8 @@ def test_asin_tosa_MI(test_data: Tuple):
4747

4848

4949
@common.parametrize("test_data", test_data_suite)
50-
def test_asin_tosa_BI(test_data: Tuple):
51-
pipeline = TosaPipelineBI[input_t](
50+
def test_asinh_tosa_INT(test_data: Tuple):
51+
pipeline = TosaPipelineINT[input_t](
5252
Asinh(),
5353
(test_data(),),
5454
aten_op=[],
@@ -59,8 +59,8 @@ def test_asin_tosa_BI(test_data: Tuple):
5959

6060
@common.parametrize("test_data", test_data_suite)
6161
@common.XfailIfNoCorstone300
62-
def test_asin_u55_BI(test_data: Tuple):
63-
pipeline = EthosU55PipelineBI[input_t](
62+
def test_asinh_u55_INT(test_data: Tuple):
63+
pipeline = EthosU55PipelineINT[input_t](
6464
Asinh(),
6565
(test_data(),),
6666
aten_ops=[],
@@ -70,8 +70,8 @@ def test_asin_u55_BI(test_data: Tuple):
7070

7171
@common.parametrize("test_data", test_data_suite)
7272
@common.XfailIfNoCorstone320
73-
def test_asin_u85_BI(test_data: Tuple):
74-
pipeline = EthosU85PipelineBI[input_t](
73+
def test_asinh_u85_INT(test_data: Tuple):
74+
pipeline = EthosU85PipelineINT[input_t](
7575
Asinh(),
7676
(test_data(),),
7777
aten_ops=[],

extension/android/BUCK

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ non_fbcode_target(_kind = fb_android_library,
1313
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
1414
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
1515
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
16-
"executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java",
17-
"executorch_android/src/main/java/org/pytorch/executorch/SGD.java",
1816
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
17+
"executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java",
18+
"executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java",
1919
],
2020
autoglob = False,
2121
language = "JAVA",
Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,24 @@
55
* This source code is licensed under the BSD-style license found in the
66
* LICENSE file in the root directory of this source tree.
77
*/
8-
package org.pytorch.executorch
8+
9+
package org.pytorch.executorch.training
910

1011
import android.Manifest
1112
import android.util.Log
1213
import androidx.test.ext.junit.runners.AndroidJUnit4
1314
import androidx.test.rule.GrantPermissionRule
14-
import java.io.File
15-
import java.io.IOException
16-
import java.net.URISyntaxException
1715
import org.apache.commons.io.FileUtils
1816
import org.junit.Assert
1917
import org.junit.Rule
2018
import org.junit.Test
2119
import org.junit.runner.RunWith
22-
import org.pytorch.executorch.TestFileUtils.getTestFilePath
20+
import org.pytorch.executorch.EValue
21+
import org.pytorch.executorch.Tensor
22+
import org.pytorch.executorch.TestFileUtils
23+
import java.io.File
24+
import java.io.IOException
25+
import java.net.URISyntaxException
2326
import kotlin.random.Random
2427
import kotlin.test.assertContains
2528

@@ -36,17 +39,20 @@ class TrainingModuleE2ETest {
3639
val pteFilePath = "/xor.pte"
3740
val ptdFilePath = "/xor.ptd"
3841

39-
val pteFile = File(getTestFilePath(pteFilePath))
42+
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
4043
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
4144
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
4245
pteInputStream.close()
4346

44-
val ptdFile = File(getTestFilePath(ptdFilePath))
47+
val ptdFile = File(TestFileUtils.getTestFilePath(ptdFilePath))
4548
val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath)
4649
FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile)
4750
ptdInputStream.close()
4851

49-
val module = TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath))
52+
val module = TrainingModule.load(
53+
TestFileUtils.getTestFilePath(pteFilePath),
54+
TestFileUtils.getTestFilePath(ptdFilePath)
55+
)
5056
val params = module.namedParameters("forward")
5157

5258
Assert.assertEquals(4, params.size)
@@ -75,7 +81,10 @@ class TrainingModuleE2ETest {
7581
val targetDex = inputDex + 1
7682
val input = dataset.get(inputDex)
7783
val target = dataset.get(targetDex)
78-
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
84+
val out = module.executeForwardBackward("forward",
85+
EValue.from(input),
86+
EValue.from(target)
87+
)
7988
val gradients = module.namedGradients("forward")
8089

8190
if (i == 0) {
@@ -96,7 +105,9 @@ class TrainingModuleE2ETest {
96105
input.getDataAsFloatArray()[0],
97106
input.getDataAsFloatArray()[1],
98107
out[1].toTensor().getDataAsLongArray()[0],
99-
target.getDataAsLongArray()[0]));
108+
target.getDataAsLongArray()[0]
109+
)
110+
);
100111
}
101112

102113
sgd.step(gradients)
@@ -113,12 +124,12 @@ class TrainingModuleE2ETest {
113124
fun testTrainXOR_PTEOnly() {
114125
val pteFilePath = "/xor_full.pte"
115126

116-
val pteFile = File(getTestFilePath(pteFilePath))
127+
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
117128
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
118129
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
119130
pteInputStream.close()
120131

121-
val module = TrainingModule.load(getTestFilePath(pteFilePath));
132+
val module = TrainingModule.load(TestFileUtils.getTestFilePath(pteFilePath));
122133
val params = module.namedParameters("forward")
123134

124135
Assert.assertEquals(4, params.size)
@@ -147,7 +158,10 @@ class TrainingModuleE2ETest {
147158
val targetDex = inputDex + 1
148159
val input = dataset.get(inputDex)
149160
val target = dataset.get(targetDex)
150-
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
161+
val out = module.executeForwardBackward("forward",
162+
EValue.from(input),
163+
EValue.from(target)
164+
)
151165
val gradients = module.namedGradients("forward")
152166

153167
if (i == 0) {
@@ -168,7 +182,9 @@ class TrainingModuleE2ETest {
168182
input.getDataAsFloatArray()[0],
169183
input.getDataAsFloatArray()[1],
170184
out[1].toTensor().getDataAsLongArray()[0],
171-
target.getDataAsLongArray()[0]));
185+
target.getDataAsLongArray()[0]
186+
)
187+
);
172188
}
173189

174190
sgd.step(gradients)
@@ -184,24 +200,33 @@ class TrainingModuleE2ETest {
184200
@Throws(IOException::class)
185201
fun testMissingPteFile() {
186202
val exception = Assert.assertThrows(RuntimeException::class.java) {
187-
TrainingModule.load(getTestFilePath(MISSING_PTE_NAME))
203+
TrainingModule.load(TestFileUtils.getTestFilePath(MISSING_PTE_NAME))
188204
}
189-
Assert.assertEquals(exception.message, "Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME))
205+
Assert.assertEquals(
206+
exception.message,
207+
"Cannot load model path!! " + TestFileUtils.getTestFilePath(MISSING_PTE_NAME)
208+
)
190209
}
191210

192211
@Test
193212
@Throws(IOException::class)
194213
fun testMissingPtdFile() {
195214
val exception = Assert.assertThrows(RuntimeException::class.java) {
196215
val pteFilePath = "/xor.pte"
197-
val pteFile = File(getTestFilePath(pteFilePath))
216+
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
198217
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
199218
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
200219
pteInputStream.close()
201220

202-
TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME))
221+
TrainingModule.load(
222+
TestFileUtils.getTestFilePath(pteFilePath),
223+
TestFileUtils.getTestFilePath(MISSING_PTD_NAME)
224+
)
203225
}
204-
Assert.assertEquals(exception.message, "Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME))
226+
Assert.assertEquals(
227+
exception.message,
228+
"Cannot load data path!! " + TestFileUtils.getTestFilePath(MISSING_PTD_NAME)
229+
)
205230
}
206231

207232
companion object {
@@ -212,4 +237,4 @@ class TrainingModuleE2ETest {
212237
private const val MISSING_PTE_NAME = "/missing.pte"
213238
private const val MISSING_PTD_NAME = "/missing.ptd"
214239
}
215-
}
240+
}

0 commit comments

Comments
 (0)