Skip to content

Commit 4d86272

Browse files
Add dim order guard for mediatek backend AOT part
1 parent 45df002 commit 4d86272

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

backends/mediatek/partitioner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
4444
return False
4545

4646
op_type = node.target.__name__
47+
48+
# Skip until we can handle the dimension order representation
49+
if op_type == 'aten._to_copy.default':
50+
return False
51+
4752
if op_type in self._op_types_to_skip or node.name in self._op_names_to_skip:
4853
print(
4954
f"[Neuropilot Backend] The {op_type} operator with name '{node.name}' is skipped."

backends/mediatek/preprocess.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,23 @@
2222
SKIP_COMPILE_SPEC_KEYS = {"ImportForever"}
2323

2424

25+
def assert_default_dim_order(edge_graph_module: torch.fx.GraphModule) -> None:
26+
for node in edge_graph_module.graph.nodes:
27+
if node.op != "placeholder":
28+
continue
29+
30+
# We expect the default dim order for all tensor-like inputs i.e. inputs, buffers, and params
31+
t = node.meta.get("val", None)
32+
if t is not None and getattr(t, "dim_order", None) is not None:
33+
default_dim_order = tuple(range(t.dim()))
34+
if t.dim_order() != default_dim_order:
35+
raise RuntimeError(
36+
f"Neuropilot backend only supports contiguous memory format for inputs."
37+
f"Expecting dim_order: {default_dim_order}, but got "
38+
f"{node.meta['val'].dim_order()} for a placeholder node {node}."
39+
)
40+
41+
2542
@final
2643
class NeuropilotBackend(BackendDetails):
2744

@@ -30,6 +47,10 @@ def preprocess(
3047
cls, edge_program: ExportedProgram, module_compile_spec: List[CompileSpec]
3148
) -> PreprocessResult:
3249

50+
# Make sure all inputs are contiguous_format or NCHW or default dim order
51+
print('here')
52+
assert_default_dim_order(edge_program.graph_module)
53+
3354
name_to_node_mappings = {node.name: node for node in edge_program.graph.nodes}
3455
input_names = edge_program.graph_signature.user_inputs
3556
output_names = edge_program.graph_signature.user_outputs

0 commit comments

Comments
 (0)