Skip to content

Commit e9af248

Browse files
committed
Update
[ghstack-poisoned]
2 parents 2a5ab06 + a70b7a8 commit e9af248

35 files changed

+778
-260
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
a3942627f5ac048e06b4b1d703b0a6a53bf6da5b
1+
eea657ddbdeb1118943a92fb73c289985c3ee1ba

.github/workflows/android-perf.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,8 @@ jobs:
355355
"--recipe" "xnnpack"
356356
"--use_custom_sdpa"
357357
"--use_custom_kv_cache"
358-
"--qlinear"
359-
"--qembedding"
358+
"--qlinear" "8da4w"
359+
"--qembedding" "8w"
360360
"--output_dir" ".."
361361
)
362362

.github/workflows/apple-perf.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,8 @@ jobs:
360360
"--recipe" "xnnpack"
361361
"--use_custom_sdpa"
362362
"--use_custom_kv_cache"
363-
"--qlinear"
364-
"--qembedding"
363+
"--qlinear" "8da4w"
364+
"--qembedding" "8w"
365365
"--output_dir" ".."
366366
)
367367

.github/workflows/trunk.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -711,8 +711,8 @@ jobs:
711711
"--recipe" "xnnpack"
712712
"--use_custom_sdpa"
713713
"--use_custom_kv_cache"
714-
"--qlinear"
715-
"--qembedding"
714+
"--qlinear" "8da4w"
715+
"--qembedding" "8w"
716716
"--output_dir" "${OUTPUT_DIR}"
717717
)
718718

backends/apple/coreml/compiler/torch_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
_get_inputs,
1616
NUM_TO_NUMPY_DTYPE,
1717
NUM_TO_TORCH_DTYPE,
18+
split,
1819
transpose,
1920
unbind,
2021
)
@@ -37,6 +38,12 @@ def unbind_copy(context, node):
3738
unbind(context, node)
3839

3940

41+
# https://github.com/apple/coremltools/pull/2563
42+
@register_torch_op(override=False)
def split_copy(context, node):
    """Convert a ``split_copy`` node by delegating to the ``split`` converter.

    Registered with ``override=False``. The conversion itself is performed
    entirely by ``split``; nothing is returned from this wrapper.

    Args:
        context: Conversion context forwarded unchanged to ``split``.
        node: The node being converted, forwarded unchanged to ``split``.
    """
    split(context, node)
45+
46+
4047
# https://github.com/apple/coremltools/pull/2558
4148
@register_torch_op(
4249
torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],

codegen/gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def __call__(
297297
f"""
298298
Kernel(
299299
"{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != "default" else ""}
300-
[]({contextArg.defn()}, EValue** stack) {{
300+
[]({contextArg.defn()}, Span<EValue*> stack) {{
301301
{code_connector.join(code_list)}
302302
303303
{exception_boundary_begin}

codegen/test/test_executorch_gen.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ def test_codegen_unboxed_specialized(self) -> None:
507507
Kernel(
508508
"custom_1::op_1",
509509
"v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3",
510-
[](torch::executor::KernelRuntimeContext & context, EValue** stack) {
510+
[](torch::executor::KernelRuntimeContext & context, Span<EValue*> stack) {
511511
"""
512512
+ """
513513
@@ -605,7 +605,7 @@ def test_codegen_unboxed_default(self) -> None:
605605
"""
606606
Kernel(
607607
"custom_1::op_1",
608-
[](torch::executor::KernelRuntimeContext & context, EValue** stack) {
608+
[](torch::executor::KernelRuntimeContext & context, Span<EValue*> stack) {
609609
"""
610610
+ """
611611
@@ -632,7 +632,7 @@ def test_codegen_unboxed_default(self) -> None:
632632
"""
633633
Kernel(
634634
"custom_1::op_1",
635-
[](torch::executor::KernelRuntimeContext & context, EValue** stack) {
635+
[](torch::executor::KernelRuntimeContext & context, Span<EValue*> stack) {
636636
"""
637637
+ """
638638
@@ -675,7 +675,7 @@ def test_codegen_unboxed_default_kernel_key_selected(self) -> None:
675675
"""
676676
Kernel(
677677
"custom_1::op_1",
678-
[](torch::executor::KernelRuntimeContext & context, EValue** stack) {
678+
[](torch::executor::KernelRuntimeContext & context, Span<EValue*> stack) {
679679
"""
680680
+ """
681681

devtools/inspector/_inspector.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
map_runtime_aot_intermediate_outputs,
6363
merge_runtime_overlapping_debug_handles,
6464
ProgramOutput,
65+
propagate_back_debug_handle,
6566
RESERVED_FRAMEWORK_EVENT_NAMES,
6667
TimeScale,
6768
verify_debug_data_equivalence,
@@ -1166,7 +1167,18 @@ def _get_aot_intermediate_outputs_and_op_names(
11661167
"""
11671168
if self._etrecord._representative_inputs is None:
11681169
return {}, {}
1169-
export_program = self._etrecord.edge_dialect_program
1170+
1171+
export_program = None
1172+
1173+
# Will use the exported program to extract intermediate output if and only if exported_program has been provided, and it is the greatest ancestor of the edge_dialect_program
1174+
if self._etrecord.exported_program and propagate_back_debug_handle(
1175+
self._etrecord.exported_program,
1176+
self._etrecord.export_graph_id,
1177+
self._etrecord.edge_dialect_program,
1178+
):
1179+
export_program = self._etrecord.exported_program
1180+
else:
1181+
export_program = self._etrecord.edge_dialect_program
11701182
graph_module = export_program.module()
11711183
aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(
11721184
graph_module

devtools/inspector/_inspector_utils.py

Lines changed: 145 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Sequence
1212
from dataclasses import dataclass
1313
from enum import Enum
14-
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
14+
from typing import Any, Dict, IO, List, Mapping, Optional, Set, Tuple, TypeAlias, Union
1515

1616
import executorch.devtools.etdump.schema_flatcc as flatcc
1717

@@ -37,7 +37,7 @@
3737

3838
from executorch.exir.debug_handle_utils import (
3939
DEBUG_HANDLE_KEY,
40-
get_greatest_ancestor_node_identifier,
40+
FROM_NODE_KEY,
4141
UNSET_DEBUG_HANDLE,
4242
)
4343

@@ -46,6 +46,7 @@
4646
from tabulate import tabulate
4747

4848
from torch.export import ExportedProgram
49+
from torch.fx import Node
4950

5051
FORWARD = "forward"
5152
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"
@@ -936,6 +937,133 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
936937
)
937938

938939

940+
def get_ancestor_node_identifiers(node: Node) -> List[str]:
    """Return the identifiers of every ancestor recorded for ``node``.

    The walk starts at the last entry of ``node.meta[FROM_NODE_KEY]`` and then
    repeatedly follows the last entry of each source's ``from_node`` list until
    a source with no ``from_node`` entries is reached. Every source visited
    contributes one identifier of the form ``"<name>.<graph_id>"``.

    Args:
        node: The node whose provenance chain is walked.

    Returns:
        The identifiers in visiting order (the source stored directly on the
        node first). An empty list is returned when the node has no
        ``FROM_NODE_KEY`` entry or the entry is empty, mirroring how
        ``get_parent_node_identifier`` tolerates nodes without provenance.
    """
    sources = node.meta.get(FROM_NODE_KEY)
    if not sources:
        # No provenance recorded: there is nothing to walk, so report no
        # ancestors instead of failing with KeyError/IndexError.
        return []

    node_source = sources[-1]
    ancestor_node_ids: List[str] = [f"{node_source.name}.{node_source.graph_id}"]

    # Only the last entry of each ``from_node`` list is followed.
    while node_source.from_node:
        node_source = node_source.from_node[-1]
        ancestor_node_ids.append(f"{node_source.name}.{node_source.graph_id}")

    return ancestor_node_ids
959+
960+
961+
def get_parent_node_identifier(node: Node) -> Optional[str]:
    """Return the identifier of the direct parent recorded for ``node``.

    The parent is the last entry of ``node.meta[FROM_NODE_KEY]``. Its
    identifier is the parent's name and graph id joined by a dot, i.e.
    ``"<name>.<graph_id>"``.

    Args:
        node: The node whose parent identifier is requested.

    Returns:
        The parent's identifier, or ``None`` when ``node.meta`` has no
        ``FROM_NODE_KEY`` entry.
    """
    try:
        recorded_sources = node.meta[FROM_NODE_KEY]
    except KeyError:
        # The node carries no provenance information at all.
        return None

    parent = recorded_sources[-1]
    return f"{parent.name}.{parent.graph_id!s}"
976+
977+
978+
def _extract_ancestor_debug_handles(
    edge_dialect_program: ExportedProgram,
) -> Dict[str, int]:
    """Build a mapping from ancestor node identifier to debug handle.

    Every node of the edge dialect program's graph module that is neither a
    placeholder nor an output is visited. Each identifier reported by
    ``get_ancestor_node_identifiers`` for that node is associated with the
    node's ``DEBUG_HANDLE_KEY`` metadata value.

    Args:
        edge_dialect_program: Program whose graph module is traversed.

    Returns:
        Dictionary keyed by ancestor identifier with the debug handle as value.

    Raises:
        AssertionError: If one ancestor identifier is reached from nodes that
            carry different debug handles.
    """
    handle_by_ancestor_id: Dict[str, int] = {}

    def _record_node(node: Node) -> None:
        # Graph inputs and outputs are not mapped.
        if node.op in ("placeholder", "output"):
            return
        for ancestor_id in get_ancestor_node_identifiers(node):
            handle = node.meta[DEBUG_HANDLE_KEY]
            if ancestor_id in handle_by_ancestor_id:
                # An ancestor seen before must agree on the debug handle.
                assert handle_by_ancestor_id[ancestor_id] == handle
                continue
            handle_by_ancestor_id[ancestor_id] = handle

    bfs_trace_with_node_process(edge_dialect_program.graph_module, _record_node)
    return handle_by_ancestor_id
1002+
1003+
1004+
def _find_matched_debug_handles(
    exported_program: ExportedProgram,
    exported_program_graph_id: int,
    ancestors_node_id_to_debug_handle: Dict[str, int],
) -> Set[int]:
    """Collect the debug handles that can be tied to an exported program node.

    Each node of the exported program that is neither an output nor a
    placeholder is looked up in ``ancestors_node_id_to_debug_handle``, first
    under its own identifier (``"<node.name>.<exported_program_graph_id>"``)
    and, failing that, under its parent's identifier.

    Args:
        exported_program: Program whose graph module is traversed.
        exported_program_graph_id: Graph id used to build each node's own
            identifier.
        ancestors_node_id_to_debug_handle: Mapping from ancestor identifier to
            debug handle.

    Returns:
        The set of debug handles found through either lookup.
    """
    found_handles: Set[int] = set()

    def _match_node(node: Node) -> None:
        if node.op in ("output", "placeholder"):
            return

        own_id = f"{node.name}.{exported_program_graph_id}"
        parent_id = get_parent_node_identifier(node)

        # The node's own identifier takes precedence over its parent's.
        if own_id in ancestors_node_id_to_debug_handle:
            lookup_id = own_id
        elif parent_id and parent_id in ancestors_node_id_to_debug_handle:
            lookup_id = parent_id
        else:
            return
        found_handles.add(ancestors_node_id_to_debug_handle[lookup_id])

    bfs_trace_with_node_process(exported_program.graph_module, _match_node)
    return found_handles
1024+
1025+
1026+
def _verify_graph_match(
    edge_dialect_program: ExportedProgram, matched_debug_handles: Set[int]
) -> bool:
    """Check that every edge dialect debug handle was matched.

    All nodes of the edge dialect program's graph module other than outputs
    and placeholders are visited, and each node's ``DEBUG_HANDLE_KEY``
    metadata value is tested for membership in ``matched_debug_handles``.

    Args:
        edge_dialect_program: Program whose graph module is traversed.
        matched_debug_handles: Debug handles already tied to exported program
            nodes.

    Returns:
        ``True`` when no visited node has a debug handle outside
        ``matched_debug_handles``; ``False`` otherwise.
    """
    unmatched_handles: List[int] = []

    def _note_unmatched(node: Node) -> None:
        if node.op in ("output", "placeholder"):
            return
        handle = node.meta[DEBUG_HANDLE_KEY]
        if handle not in matched_debug_handles:
            unmatched_handles.append(handle)

    # The whole graph is traversed; there is no early exit on a mismatch.
    bfs_trace_with_node_process(edge_dialect_program.graph_module, _note_unmatched)
    return not unmatched_handles
1041+
1042+
1043+
def _apply_debug_handles(
    exported_program: ExportedProgram,
    exported_program_graph_id: int,
    ancestors_node_id_to_debug_handle: Dict[str, int],
) -> None:
    """Write debug handles into the metadata of the exported program's nodes.

    For every node that is neither an output nor a placeholder, the handle is
    taken from ``ancestors_node_id_to_debug_handle`` using the node's own
    identifier (``"<node.name>.<exported_program_graph_id>"``) or, failing
    that, its parent's identifier. Nodes found under neither identifier get
    ``UNSET_DEBUG_HANDLE``. The result is stored in
    ``node.meta[DEBUG_HANDLE_KEY]``, so the exported program is modified in
    place.

    Args:
        exported_program: Program whose nodes receive debug handles.
        exported_program_graph_id: Graph id used to build each node's own
            identifier.
        ancestors_node_id_to_debug_handle: Mapping from ancestor identifier to
            debug handle.
    """

    def _assign_handle(node: Node) -> None:
        if node.op in ("output", "placeholder"):
            return

        own_id = f"{node.name}.{exported_program_graph_id}"
        parent_id = get_parent_node_identifier(node)

        # The node's own identifier takes precedence over its parent's.
        if own_id in ancestors_node_id_to_debug_handle:
            handle = ancestors_node_id_to_debug_handle[own_id]
        elif parent_id and parent_id in ancestors_node_id_to_debug_handle:
            handle = ancestors_node_id_to_debug_handle[parent_id]
        else:
            handle = UNSET_DEBUG_HANDLE
        node.meta[DEBUG_HANDLE_KEY] = handle

    bfs_trace_with_node_process(exported_program.graph_module, _assign_handle)
1065+
1066+
9391067
def propagate_back_debug_handle(
9401068
exported_program: ExportedProgram,
9411069
exported_program_graph_id: int,
@@ -953,47 +1081,24 @@ def propagate_back_debug_handle(
9531081
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
9541082
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
9551083
956-
Return: True if:
957-
a. every debug handle in the edge dialect program has a corresponding node in the exported program
958-
b. the exported program is the greatest ancestor of the edge dialect program
959-
960-
Otherwise, return False.
1084+
Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
9611085
"""
1086+
# 1. Extract mapping from ancestor node identifiers to debug handles
1087+
ancestors_node_id_to_debug_handle = _extract_ancestor_debug_handles(
1088+
edge_dialect_program
1089+
)
9621090

963-
# 1. set up a mapping from debug handle to identifier of export program's node
964-
# using edge dialect program nodes' debug handles and from_node info
965-
export_graph_node_id_to_debug_handle = {
966-
get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY]
967-
for node in edge_dialect_program.graph.nodes
968-
if node.op not in ("placeholder", "output")
969-
}
970-
971-
# 2. equip debug handle to the exported program's nodes using the mapping
972-
# number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle
973-
n_matched_node = 0
974-
975-
def _find_n_match_node(node: torch.fx.Node) -> None:
976-
nonlocal n_matched_node
977-
if node.name in ("output", "placeholder"):
978-
return
979-
node_id = f"{node.name}.{exported_program_graph_id}"
980-
if node_id in export_graph_node_id_to_debug_handle:
981-
n_matched_node += 1
982-
983-
def _equip_debug_handle(node: torch.fx.Node) -> None:
984-
if node.name in ("output", "placeholder"):
985-
return
986-
node_id = f"{node.name}.{exported_program_graph_id}"
987-
if node_id in export_graph_node_id_to_debug_handle:
988-
node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id]
989-
else:
990-
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
991-
992-
bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
1091+
# 2. Find debug handles that have corresponding nodes in the exported program
1092+
matched_debug_handles = _find_matched_debug_handles(
1093+
exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle
1094+
)
9931095

994-
# if any node in the edge dialect program has no corresponding node in the exported program, match failed
995-
if n_matched_node != len(export_graph_node_id_to_debug_handle):
1096+
# 3. Verify if every debug handle in edge dialect program has a corresponding node
1097+
if not _verify_graph_match(edge_dialect_program, matched_debug_handles):
9961098
return False
9971099

998-
bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
1100+
# 4. Apply debug handles to the exported program
1101+
_apply_debug_handles(
1102+
exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle
1103+
)
9991104
return True

0 commit comments

Comments
 (0)