Skip to content

Commit c856189

Browse files
committed
Update base for Update on "add more export modules after ertrecod created"
we need to support etrecord recording custom export modules for further usage. This diff makes that happen by creating new function inside ETRecord Differential Revision: [D79279401](https://our.internmc.facebook.com/intern/diff/D79279401/) umbrella issue: #12961 [ghstack-poisoned]
2 parents b033a40 + f497f7f commit c856189

File tree

10 files changed

+154
-70
lines changed

10 files changed

+154
-70
lines changed

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 80 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
PartitionResult,
2121
)
2222
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
23+
from executorch.exir.dialects._ops import ops as exir_ops
2324
from torch.export.exported_program import ExportedProgram
2425
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
2526
from torch.fx.passes.operator_support import OperatorSupportBase
@@ -56,6 +57,80 @@ def log_once(self, msg: str) -> None:
5657
logger.info(msg)
5758
self._logged_msgs.add(msg)
5859

60+
def should_skip_op_for_delegation(self, node_target_name: str) -> bool:
61+
skipped_ops = self.skip_ops_for_coreml_delegation or []
62+
if node_target_name in skipped_ops:
63+
assert (
64+
not self.lower_full_graph
65+
), f"Cannot skip {node_target_name} because lower_full_graph is True. Please set skip_ops_for_coreml_delegation=None or lower_full_graph=False in the CoreMLPartitioner"
66+
self.log_once(
67+
"Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
68+
+ node_target_name
69+
)
70+
return True
71+
return False
72+
73+
def should_override_support(self, node) -> bool:
74+
# https://github.com/apple/coremltools/issues/2573
75+
if (
76+
node.target
77+
in [
78+
torch.ops.aten.sub.Tensor,
79+
exir_ops.edge.aten.sub.Tensor,
80+
torch.ops.aten.add.Tensor,
81+
exir_ops.edge.aten.add.Tensor,
82+
]
83+
and "alpha" in node.kwargs
84+
and node.kwargs["alpha"] != 1
85+
):
86+
self.log_once(
87+
"torch.ops.aten.{sub, add}.Tensor with alpha != 1 is not supported by CoreML. Overriding support."
88+
)
89+
return True
90+
91+
# https://github.com/apple/coremltools/issues/2565
92+
if node.target in [
93+
torch.ops.aten.diagonal.default,
94+
torch.ops.aten.diagonal_copy.default,
95+
exir_ops.edge.aten.diagonal.default,
96+
exir_ops.edge.aten.diagonal_copy.default,
97+
]:
98+
self.log_once(
99+
"torch.ops.aten.diagonal.default has a bug in CoreML. Overriding op support."
100+
)
101+
return True
102+
103+
# https://github.com/apple/coremltools/issues/2569
104+
if node.target in [
105+
torch.ops.aten.acosh.default,
106+
exir_ops.edge.aten.acosh.default,
107+
torch.ops.aten.asinh.default,
108+
exir_ops.edge.aten.asinh.default,
109+
]:
110+
self.log_once(
111+
"torch.ops.aten.{acosh, asinh}.default is not supported by CoreML. Overriding op support."
112+
)
113+
return True
114+
115+
# TODO: enable this after bugs in ExecuTorch's partitioner are fixed
116+
# # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
117+
# # in the placeholders due to partitioning, which CoreML does not support
118+
# if not self.lower_full_graph and any(
119+
# isinstance(arg, torch.fx.Node)
120+
# and isinstance(
121+
# arg.meta.get("val", None),
122+
# (torch.SymInt, torch.SymBool, torch.SymFloat),
123+
# )
124+
# for arg in node.args
125+
# ):
126+
# self.log_once(
127+
# "Skipping op for CoreML delegation because it contains symbolic args: "
128+
# + node_target_name
129+
# )
130+
# return True
131+
132+
return False
133+
59134
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
60135
# get_attr node can always be supported on any backend
61136
if node.op == "get_attr":
@@ -64,38 +139,17 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
64139
elif node.op == "call_function":
65140
# skip ops if specified by user
66141
node_target_name = getattr(node.target, "__name__", "").lower()
67-
if node_target_name in (self.skip_ops_for_coreml_delegation or []):
68-
self.log_once(
69-
"Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
70-
+ node_target_name
71-
)
72-
assert (
73-
not self.lower_full_graph
74-
), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
75-
return False
76142

77-
# TODO: enable this after bugs in ExecuTorch's partitioner are fixed
78-
# # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
79-
# # in the placeholders due to partitioning, which CoreML does not support
80-
# if not self.lower_full_graph and any(
81-
# isinstance(arg, torch.fx.Node)
82-
# and isinstance(
83-
# arg.meta.get("val", None),
84-
# (torch.SymInt, torch.SymBool, torch.SymFloat),
85-
# )
86-
# for arg in node.args
87-
# ):
88-
# self.log_once(
89-
# "Skipping op for CoreML delegation because it contains symbolic args: "
90-
# + node_target_name
91-
# )
92-
# assert not self.lower_full_graph
93-
# return False
143+
if self.should_skip_op_for_delegation(node_target_name):
144+
return False
94145

95146
# query coremltools to see if node is supported
96147
is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(
97148
node
98149
)
150+
if self.should_override_support(node):
151+
is_supported = False
152+
99153
if not is_supported:
100154
if self.lower_full_graph:
101155
raise NotImplementedError(
@@ -126,7 +180,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
126180

127181

128182
class CoreMLPartitioner(Partitioner):
129-
130183
def __init__(
131184
self,
132185
*,

backends/arm/test/ops/test_asinh.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
from executorch.backends.arm.test import common
1111
from executorch.backends.arm.test.tester.test_pipeline import (
12-
EthosU55PipelineBI,
13-
EthosU85PipelineBI,
14-
TosaPipelineBI,
15-
TosaPipelineMI,
12+
EthosU55PipelineINT,
13+
EthosU85PipelineINT,
14+
TosaPipelineFP,
15+
TosaPipelineINT,
1616
)
1717

1818
input_t = Tuple[torch.Tensor] # Input x
@@ -36,8 +36,8 @@ def forward(self, x):
3636

3737

3838
@common.parametrize("test_data", test_data_suite)
39-
def test_asin_tosa_MI(test_data: Tuple):
40-
pipeline = TosaPipelineMI[input_t](
39+
def test_asinh_tosa_FP(test_data: Tuple):
40+
pipeline = TosaPipelineFP[input_t](
4141
Asinh(),
4242
(test_data(),),
4343
aten_op,
@@ -47,8 +47,8 @@ def test_asin_tosa_MI(test_data: Tuple):
4747

4848

4949
@common.parametrize("test_data", test_data_suite)
50-
def test_asin_tosa_BI(test_data: Tuple):
51-
pipeline = TosaPipelineBI[input_t](
50+
def test_asinh_tosa_INT(test_data: Tuple):
51+
pipeline = TosaPipelineINT[input_t](
5252
Asinh(),
5353
(test_data(),),
5454
aten_op=[],
@@ -59,8 +59,8 @@ def test_asin_tosa_BI(test_data: Tuple):
5959

6060
@common.parametrize("test_data", test_data_suite)
6161
@common.XfailIfNoCorstone300
62-
def test_asin_u55_BI(test_data: Tuple):
63-
pipeline = EthosU55PipelineBI[input_t](
62+
def test_asinh_u55_INT(test_data: Tuple):
63+
pipeline = EthosU55PipelineINT[input_t](
6464
Asinh(),
6565
(test_data(),),
6666
aten_ops=[],
@@ -70,8 +70,8 @@ def test_asin_u55_BI(test_data: Tuple):
7070

7171
@common.parametrize("test_data", test_data_suite)
7272
@common.XfailIfNoCorstone320
73-
def test_asin_u85_BI(test_data: Tuple):
74-
pipeline = EthosU85PipelineBI[input_t](
73+
def test_asinh_u85_INT(test_data: Tuple):
74+
pipeline = EthosU85PipelineINT[input_t](
7575
Asinh(),
7676
(test_data(),),
7777
aten_ops=[],

extension/android/BUCK

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ non_fbcode_target(_kind = fb_android_library,
1313
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
1414
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
1515
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
16-
"executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java",
17-
"executorch_android/src/main/java/org/pytorch/executorch/SGD.java",
1816
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
17+
"executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java",
18+
"executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java",
1919
],
2020
autoglob = False,
2121
language = "JAVA",
Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,24 @@
55
* This source code is licensed under the BSD-style license found in the
66
* LICENSE file in the root directory of this source tree.
77
*/
8-
package org.pytorch.executorch
8+
9+
package org.pytorch.executorch.training
910

1011
import android.Manifest
1112
import android.util.Log
1213
import androidx.test.ext.junit.runners.AndroidJUnit4
1314
import androidx.test.rule.GrantPermissionRule
14-
import java.io.File
15-
import java.io.IOException
16-
import java.net.URISyntaxException
1715
import org.apache.commons.io.FileUtils
1816
import org.junit.Assert
1917
import org.junit.Rule
2018
import org.junit.Test
2119
import org.junit.runner.RunWith
22-
import org.pytorch.executorch.TestFileUtils.getTestFilePath
20+
import org.pytorch.executorch.EValue
21+
import org.pytorch.executorch.Tensor
22+
import org.pytorch.executorch.TestFileUtils
23+
import java.io.File
24+
import java.io.IOException
25+
import java.net.URISyntaxException
2326
import kotlin.random.Random
2427
import kotlin.test.assertContains
2528

@@ -36,17 +39,20 @@ class TrainingModuleE2ETest {
3639
val pteFilePath = "/xor.pte"
3740
val ptdFilePath = "/xor.ptd"
3841

39-
val pteFile = File(getTestFilePath(pteFilePath))
42+
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
4043
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
4144
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
4245
pteInputStream.close()
4346

44-
val ptdFile = File(getTestFilePath(ptdFilePath))
47+
val ptdFile = File(TestFileUtils.getTestFilePath(ptdFilePath))
4548
val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath)
4649
FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile)
4750
ptdInputStream.close()
4851

49-
val module = TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath))
52+
val module = TrainingModule.load(
53+
TestFileUtils.getTestFilePath(pteFilePath),
54+
TestFileUtils.getTestFilePath(ptdFilePath)
55+
)
5056
val params = module.namedParameters("forward")
5157

5258
Assert.assertEquals(4, params.size)
@@ -75,7 +81,10 @@ class TrainingModuleE2ETest {
7581
val targetDex = inputDex + 1
7682
val input = dataset.get(inputDex)
7783
val target = dataset.get(targetDex)
78-
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
84+
val out = module.executeForwardBackward("forward",
85+
EValue.from(input),
86+
EValue.from(target)
87+
)
7988
val gradients = module.namedGradients("forward")
8089

8190
if (i == 0) {
@@ -96,7 +105,9 @@ class TrainingModuleE2ETest {
96105
input.getDataAsFloatArray()[0],
97106
input.getDataAsFloatArray()[1],
98107
out[1].toTensor().getDataAsLongArray()[0],
99-
target.getDataAsLongArray()[0]));
108+
target.getDataAsLongArray()[0]
109+
)
110+
);
100111
}
101112

102113
sgd.step(gradients)
@@ -113,12 +124,12 @@ class TrainingModuleE2ETest {
113124
fun testTrainXOR_PTEOnly() {
114125
val pteFilePath = "/xor_full.pte"
115126

116-
val pteFile = File(getTestFilePath(pteFilePath))
127+
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
117128
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
118129
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
119130
pteInputStream.close()
120131

121-
val module = TrainingModule.load(getTestFilePath(pteFilePath));
132+
val module = TrainingModule.load(TestFileUtils.getTestFilePath(pteFilePath));
122133
val params = module.namedParameters("forward")
123134

124135
Assert.assertEquals(4, params.size)
@@ -147,7 +158,10 @@ class TrainingModuleE2ETest {
147158
val targetDex = inputDex + 1
148159
val input = dataset.get(inputDex)
149160
val target = dataset.get(targetDex)
150-
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
161+
val out = module.executeForwardBackward("forward",
162+
EValue.from(input),
163+
EValue.from(target)
164+
)
151165
val gradients = module.namedGradients("forward")
152166

153167
if (i == 0) {
@@ -168,7 +182,9 @@ class TrainingModuleE2ETest {
168182
input.getDataAsFloatArray()[0],
169183
input.getDataAsFloatArray()[1],
170184
out[1].toTensor().getDataAsLongArray()[0],
171-
target.getDataAsLongArray()[0]));
185+
target.getDataAsLongArray()[0]
186+
)
187+
);
172188
}
173189

174190
sgd.step(gradients)
@@ -184,24 +200,33 @@ class TrainingModuleE2ETest {
184200
@Throws(IOException::class)
185201
fun testMissingPteFile() {
186202
val exception = Assert.assertThrows(RuntimeException::class.java) {
187-
TrainingModule.load(getTestFilePath(MISSING_PTE_NAME))
203+
TrainingModule.load(TestFileUtils.getTestFilePath(MISSING_PTE_NAME))
188204
}
189-
Assert.assertEquals(exception.message, "Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME))
205+
Assert.assertEquals(
206+
exception.message,
207+
"Cannot load model path!! " + TestFileUtils.getTestFilePath(MISSING_PTE_NAME)
208+
)
190209
}
191210

192211
@Test
193212
@Throws(IOException::class)
194213
fun testMissingPtdFile() {
195214
val exception = Assert.assertThrows(RuntimeException::class.java) {
196215
val pteFilePath = "/xor.pte"
197-
val pteFile = File(getTestFilePath(pteFilePath))
216+
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
198217
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
199218
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
200219
pteInputStream.close()
201220

202-
TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME))
221+
TrainingModule.load(
222+
TestFileUtils.getTestFilePath(pteFilePath),
223+
TestFileUtils.getTestFilePath(MISSING_PTD_NAME)
224+
)
203225
}
204-
Assert.assertEquals(exception.message, "Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME))
226+
Assert.assertEquals(
227+
exception.message,
228+
"Cannot load data path!! " + TestFileUtils.getTestFilePath(MISSING_PTD_NAME)
229+
)
205230
}
206231

207232
companion object {
@@ -212,4 +237,4 @@ class TrainingModuleE2ETest {
212237
private const val MISSING_PTE_NAME = "/missing.pte"
213238
private const val MISSING_PTD_NAME = "/missing.ptd"
214239
}
215-
}
240+
}

0 commit comments

Comments
 (0)