Skip to content

Commit 4b9d9eb

Browse files
committed
Add CoreML tester implementation
ghstack-source-id: 557f6f5 ghstack-comment-id: 3003538299 Pull-Request: #11959
1 parent 2bb37f2 commit 4b9d9eb

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
git # Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type
8+
9+
import executorch
10+
import executorch.backends.test.harness.stages as BaseStages
11+
12+
import torch
13+
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
14+
from executorch.backends.test.harness import Tester as TesterBase
15+
from executorch.backends.test.harness.stages import StageType
16+
from executorch.exir import EdgeCompileConfig
17+
from executorch.exir.backend.partitioner import Partitioner
18+
19+
20+
class Partition(BaseStages.Partition):
21+
def __init__(self, partitioner: Optional[Partitioner] = None):
22+
super().__init__(
23+
partitioner=partitioner or CoreMLPartitioner,
24+
)
25+
26+
27+
class ToEdgeTransformAndLower(BaseStages.ToEdgeTransformAndLower):
28+
def __init__(
29+
self,
30+
partitioners: Optional[List[Partitioner]] = None,
31+
edge_compile_config: Optional[EdgeCompileConfig] = None,
32+
):
33+
super().__init__(
34+
default_partitioner_cls=CoreMLPartitioner,
35+
partitioners=partitioners,
36+
edge_compile_config=edge_compile_config,
37+
)
38+
39+
40+
class CoreMLTester(TesterBase):
41+
def __init__(
42+
self,
43+
module: torch.nn.Module,
44+
example_inputs: Tuple[torch.Tensor],
45+
dynamic_shapes: Optional[Tuple[Any]] = None,
46+
):
47+
# Specialize for XNNPACK
48+
stage_classes = (
49+
executorch.backends.test.harness.Tester.default_stage_classes()
50+
| {
51+
StageType.PARTITION: Partition,
52+
StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
53+
}
54+
)
55+
56+
super().__init__(
57+
module=module,
58+
stage_classes=stage_classes,
59+
example_inputs=example_inputs,
60+
dynamic_shapes=dynamic_shapes,
61+
)
62+

0 commit comments

Comments
 (0)