Skip to content

Commit c072bbb

Browse files
authored
chore: minor code tidy up (#105)
1 parent c30d323 commit c072bbb

File tree

7 files changed

+36
-67
lines changed

7 files changed

+36
-67
lines changed

scripts/generate-mlir-bin-sbom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def main():
6666
"open-source-office@arm.com",
6767
)
6868
],
69-
creator_comment="THIS SOFTWARE BILL OF MATERIALS (\"SBOM\") IS PROVIDED BY ARM LIMITED \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT ARE DISCLAIMED. IN NO EVENT SHALL ARM LIMITED BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SBOM, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.",
69+
creator_comment='THIS SOFTWARE BILL OF MATERIALS ("SBOM") IS PROVIDED BY ARM LIMITED "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT ARE DISCLAIMED. IN NO EVENT SHALL ARM LIMITED BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SBOM, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.',
7070
created=build_time,
7171
)
7272
document = Document(creation_info)

src/vgf_adapter_model_explorer/builder/builder.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ..constants import GRAPH_INPUT_ANNOTATION, GRAPH_OUTPUT_ANNOTATION
1313
from ..generic import append_attr_to_metadata_list
1414
from ..parser.types import IOBase, Module, Resource, Vgf
15-
from .utils import extend_resource, find_item, format_index
15+
from .utils import extend_resource, format_index
1616

1717

1818
class VgfGraphBuilder:
@@ -26,6 +26,7 @@ def __init__(
2626
"""Builds a Model Explorer GraphCollection from VGF data."""
2727

2828
self.vgf_data = vgf_data
29+
self.vgf_resources = {r.index: r for r in vgf_data.resources}
2930
self._build_spirv_nodes = build_spirv_nodes
3031
self.graph_collection = self._build_graph_collection()
3132

@@ -41,9 +42,11 @@ def _build_graph_collection(self) -> gb.GraphCollection:
4142
def _build_nodes(self) -> list[gb.GraphNode]:
4243
"""Build all nodes for the graph."""
4344
nodes = [self._build_graph_input_node()]
45+
modules = {m.index: m for m in self.vgf_data.modules}
46+
graph_output_node = self._build_graph_output_node()
4447

4548
for segment in self.vgf_data.model_sequence.segments:
46-
module = find_item(segment.module_index, self.vgf_data.modules)
49+
module = modules.get(segment.module_index)
4750
if not module:
4851
continue
4952

@@ -56,9 +59,7 @@ def _build_nodes(self) -> list[gb.GraphNode]:
5659
output_nodes = self._build_segment_output_nodes(
5760
segment.outputs, module, spirv_nodes
5861
)
59-
nodes.extend(
60-
input_nodes + output_nodes + [self._build_graph_output_node()]
61-
)
62+
nodes.extend(input_nodes + output_nodes + [graph_output_node])
6263

6364
return nodes
6465

@@ -98,7 +99,7 @@ def _build_segment_input_nodes(
9899
nodes: list[gb.GraphNode] = []
99100

100101
for input in inputs:
101-
resource = find_item(input.mrt_index, self.vgf_data.resources)
102+
resource = self.vgf_resources.get(input.mrt_index)
102103
if not resource:
103104
continue
104105
nodes.append(self._build_node(resource, module, input.mrt_index))
@@ -113,7 +114,7 @@ def _build_segment_output_nodes(
113114
source_id = spirv_nodes[-1].id if spirv_nodes else None
114115

115116
for output in outputs:
116-
resource = find_item(output.mrt_index, self.vgf_data.resources)
117+
resource = self.vgf_resources.get(output.mrt_index)
117118
if not resource:
118119
continue
119120
output_nodes.append(
@@ -157,9 +158,7 @@ def _build_node(
157158
) -> gb.GraphNode:
158159
"""Build a single graph node."""
159160
if source_id:
160-
incoming_edges = [
161-
gb.IncomingEdge(sourceNodeId=source_id, targetNodeInputId="0")
162-
]
161+
incoming_edges = [gb.IncomingEdge(sourceNodeId=source_id)]
163162
else:
164163
incoming_edges = self._build_incoming_edges(resource)
165164

src/vgf_adapter_model_explorer/builder/utils.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,9 @@
44
# Licensed under the Apache License v2.0
55
# See http://www.apache.org/licenses/LICENSE-2.0 for license information.
66

7-
from typing import Iterable, TypeVar
8-
97
from model_explorer import graph_builder as gb
108

11-
from ..parser.types import Module, Resource
12-
13-
T = TypeVar("T")
14-
15-
16-
def find_item(index: int, items: Iterable[T]) -> T | None:
17-
"""Finds the item at the given index."""
18-
return next(
19-
filter(lambda x: getattr(x, "index", None) == index, items), None
20-
)
9+
from ..parser.types import Resource
2110

2211

2312
def format_index(prefix: str, index: int) -> str:
@@ -34,14 +23,3 @@ def extend_resource(node: gb.GraphNode, resource: Resource) -> None:
3423
gb.KeyValue(key="Format", value=resource.vk_format),
3524
]
3625
)
37-
38-
39-
def extend_module(node: gb.GraphNode, module: Module) -> None:
40-
"""Extends module nodes with attributes."""
41-
node.attrs.extend(
42-
[
43-
gb.KeyValue(key="Has Spirv", value=str(module.has_spirv)),
44-
gb.KeyValue(key="Type", value=str(module.type)),
45-
gb.KeyValue(key="Entry Point", value=str(module.entry_point)),
46-
]
47-
)

src/vgf_adapter_model_explorer/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
GRAPH_OUTPUT_ANNOTATION: str = "GraphOutputs"
1414
GRAPH_TENSOR_IDX: str = "tensor_index"
1515
GRAPH_TENSOR_TYPE: str = "tensor_shape"
16-
GRAPH_TENSOR_TAG: str = "__tensor_tag"
1716

1817
TERMINATOR_OPS: Set[str] = {"func.return"}
1918

src/vgf_adapter_model_explorer/exec/mlir_translate.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,10 @@
44
# Licensed under the Apache License v2.0
55
# See http://www.apache.org/licenses/LICENSE-2.0 for license information.
66

7-
import importlib.resources
8-
import sys
97
from pathlib import Path
108

119
from vgf_adapter_model_explorer.exec.exec_cmd import exec_cmd
12-
13-
14-
def get_binary_path(binary_name: str) -> Path:
15-
"""Get path to bundled binary, accounting for platform extensions."""
16-
if sys.platform.startswith("win"):
17-
binary_name = binary_name + ".exe"
18-
19-
return Path(
20-
str(
21-
importlib.resources.files(
22-
"vgf_adapter_model_explorer.bin"
23-
).joinpath(binary_name)
24-
)
25-
)
10+
from vgf_adapter_model_explorer.exec.utils import get_binary_path
2611

2712

2813
def exec_mlir_translate(spirv_path: Path) -> str:
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License v2.0
5+
# See http://www.apache.org/licenses/LICENSE-2.0 for license information.
6+
7+
import importlib.resources
8+
import sys
9+
from pathlib import Path
10+
11+
12+
def get_binary_path(binary_name: str) -> Path:
13+
"""Get path to bundled binary, accounting for platform extensions."""
14+
if sys.platform.startswith("win"):
15+
binary_name = binary_name + ".exe"
16+
17+
return Path(
18+
str(
19+
importlib.resources.files(
20+
"vgf_adapter_model_explorer.bin"
21+
).joinpath(binary_name)
22+
)
23+
)

src/vgf_adapter_model_explorer/exec/vgf_dump.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,12 @@
44
# Licensed under the Apache License v2.0
55
# See http://www.apache.org/licenses/LICENSE-2.0 for license information.
66

7-
import importlib.resources
8-
import sys
97
import tempfile
108
from pathlib import Path
119
from typing import Optional
1210

1311
from vgf_adapter_model_explorer.exec.exec_cmd import exec_cmd
14-
15-
16-
def get_binary_path(binary_name: str) -> Path:
17-
"""Get path to bundled binary, accounting for platform extensions."""
18-
if sys.platform.startswith("win"):
19-
binary_name = binary_name + ".exe"
20-
21-
return Path(
22-
str(
23-
importlib.resources.files(
24-
"vgf_adapter_model_explorer.bin"
25-
).joinpath(binary_name)
26-
)
27-
)
12+
from vgf_adapter_model_explorer.exec.utils import get_binary_path
2813

2914

3015
def exec_vgf_dump(

0 commit comments

Comments
 (0)