Skip to content

Commit 684171d

Browse files
authored
fix: correct spirv node incoming edge mapping
1 parent cb5d88f commit 684171d

File tree

4 files changed

+35
-29
lines changed

4 files changed

+35
-29
lines changed

src/vgf_adapter_model_explorer/builder/builder.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -71,27 +71,25 @@ def _build_segment_spirv_nodes(
7171
spirv_nodes = self._build_spirv_nodes(self.vgf_data, module.index)
7272
if not spirv_nodes:
7373
return []
74-
self._clean_spirv_edges(spirv_nodes)
75-
self._connect_spirv_node(spirv_nodes[0], input_nodes)
74+
self._connect_spirv_nodes(spirv_nodes, input_nodes)
7675
return spirv_nodes
7776

78-
def _clean_spirv_edges(self, spirv_nodes: list[gb.GraphNode]) -> None:
79-
"""Clean incoming edges of SPIR-V nodes."""
80-
for node in spirv_nodes:
81-
new_incoming_edges = []
82-
for e in node.incomingEdges:
83-
if e.sourceNodeId != GRAPH_INPUT_ANNOTATION:
84-
new_incoming_edges.append(e)
85-
node.incomingEdges[:] = new_incoming_edges
86-
87-
def _connect_spirv_node(
88-
self, spirv_node: gb.GraphNode, input_nodes: list[gb.GraphNode]
77+
def _connect_spirv_nodes(
78+
self, spirv_nodes: list[gb.GraphNode], input_nodes: list[gb.GraphNode]
8979
) -> None:
90-
"""Connect SPIR-V node to the graph."""
91-
for node in input_nodes:
92-
spirv_node.incomingEdges.append(
93-
gb.IncomingEdge(sourceNodeId=node.id, targetNodeInputId="0")
94-
)
80+
"""Connect SPIR-V nodes to the graph by mapping block arguments to input nodes."""
81+
for spirv_node in spirv_nodes:
82+
new_incoming_edges = []
83+
for edge in spirv_node.incomingEdges:
84+
if edge.sourceNodeId == GRAPH_INPUT_ANNOTATION:
85+
arg_number = int(edge.sourceNodeOutputId or "0")
86+
if arg_number < len(input_nodes):
87+
edge.sourceNodeId = input_nodes[arg_number].id
88+
edge.sourceNodeOutputId = "0"
89+
new_incoming_edges.append(edge)
90+
else:
91+
new_incoming_edges.append(edge)
92+
spirv_node.incomingEdges[:] = new_incoming_edges
9593

9694
def _build_segment_input_nodes(
9795
self, inputs: list[IOBase], module: Module

src/vgf_adapter_model_explorer/tests/fixtures/mobilenet_v2/expected.json

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@
7373
"sourceNodeOutputId": "0",
7474
"targetNodeInputId": "4"
7575
},
76+
{
77+
"sourceNodeId": "mrt_115",
78+
"sourceNodeOutputId": "0",
79+
"targetNodeInputId": "5"
80+
},
7681
{
7782
"sourceNodeId": "5",
7883
"sourceNodeOutputId": "0",
@@ -92,11 +97,6 @@
9297
"sourceNodeId": "8",
9398
"sourceNodeOutputId": "0",
9499
"targetNodeInputId": "9"
95-
},
96-
{
97-
"sourceNodeId": "mrt_115",
98-
"sourceNodeOutputId": "0",
99-
"targetNodeInputId": "0"
100100
}
101101
],
102102
"outputsMetadata": [

src/vgf_adapter_model_explorer/tests/fixtures/vww4/expected.json

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@
7373
"sourceNodeOutputId": "0",
7474
"targetNodeInputId": "4"
7575
},
76+
{
77+
"sourceNodeId": "mrt_95",
78+
"sourceNodeOutputId": "0",
79+
"targetNodeInputId": "5"
80+
},
7681
{
7782
"sourceNodeId": "5",
7883
"sourceNodeOutputId": "0",
@@ -92,11 +97,6 @@
9297
"sourceNodeId": "8",
9398
"sourceNodeOutputId": "0",
9499
"targetNodeInputId": "9"
95-
},
96-
{
97-
"sourceNodeId": "mrt_95",
98-
"sourceNodeOutputId": "0",
99-
"targetNodeInputId": "0"
100100
}
101101
],
102102
"outputsMetadata": [

src/vgf_adapter_model_explorer/tests/test_vgf_builder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@
9191

9292

9393
def test_builder():
94+
from ..constants import GRAPH_INPUT_ANNOTATION
95+
9496
mock_get_spirv_nodes = Mock(
9597
return_value=[
9698
gb.GraphNode(
@@ -99,7 +101,13 @@ def test_builder():
99101
namespace="",
100102
subgraphIds=[],
101103
attrs=[],
102-
incomingEdges=[],
104+
incomingEdges=[
105+
gb.IncomingEdge(
106+
sourceNodeId=GRAPH_INPUT_ANNOTATION,
107+
sourceNodeOutputId="0",
108+
targetNodeInputId="0",
109+
)
110+
],
103111
outputsMetadata=[
104112
gb.MetadataItem(
105113
id="0",

0 commit comments

Comments
 (0)