Skip to content

Commit 1608e43

Browse files
author
Diptorup Deb
authored
Merge pull request #1292 from IntelPython/fix/remove_unused_code_spirv_generator
Remove spirv-tool call generators from spirv_generator.py.
2 parents 92a7a7d + d9eb596 commit 1608e43

File tree

2 files changed

+9
-70
lines changed

2 files changed

+9
-70
lines changed

numba_dpex/core/config.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import logging
66
import os
77

8-
import dpctl
98
from numba.core import config
109

1110

@@ -54,9 +53,6 @@ def __getattr__(name):
5453
# To save intermediate files generated by th compiler
5554
SAVE_IR_FILES = _readenv("NUMBA_DPEX_SAVE_IR_FILES", int, 0)
5655

57-
# Turn SPIRV-VALIDATION ON/OFF switch
58-
SPIRV_VAL = _readenv("NUMBA_DPEX_SPIRV_VAL", int, 0)
59-
6056
# Dump offload diagnostics
6157
OFFLOAD_DIAGNOSTICS = _readenv("NUMBA_DPEX_OFFLOAD_DIAGNOSTICS", int, 0)
6258

numba_dpex/spirv_generator.py

Lines changed: 9 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
"""A wrapper to connect to the SPIR-V binaries (Tools, Translator)."""
5+
"""
6+
A wrapper to call dpcpp's llvm-spirv tool to generate a SPIR-V binary from
7+
a numba-dpex generated LLVM IR module.
8+
"""
69

710
import os
811
import tempfile
@@ -28,47 +31,8 @@ def run_cmd(args, error_message=None):
2831
)
2932

3033

31-
class CmdLine:
32-
def disassemble(self, ipath, opath):
33-
"""
34-
Disassemble a spirv module.
35-
36-
Args:
37-
ipath: Input file path of the spirv module.
38-
opath: Output file path of the disassembled spirv module.
39-
"""
40-
flags = []
41-
run_cmd(
42-
["spirv-dis", *flags, "-o", opath, ipath],
43-
error_message="Error during SPIRV disassemble",
44-
)
45-
46-
def validate(self, ipath):
47-
"""
48-
Validate a spirv module.
49-
50-
Args:
51-
ipath: Input file path of the spirv module.
52-
"""
53-
flags = []
54-
run_cmd(
55-
["spirv-val", *flags, ipath],
56-
error_message="Error during SPIRV validation",
57-
)
58-
59-
def optimize(self, ipath, opath):
60-
"""
61-
Optimize a spirv module.
62-
63-
Args:
64-
ipath: Input file path of the spirv module.
65-
opath: Output file path of the optimized spirv module.
66-
"""
67-
flags = []
68-
run_cmd(
69-
["spirv-opt", *flags, "-o", opath, ipath],
70-
error_message="Error during SPIRV optimization",
71-
)
34+
class _SpirvGenerator:
35+
"""Generates a SPIR-V binary from supplied LLVM IR."""
7236

7337
def generate(self, llvm_spirv_args, ipath, opath):
7438
"""
@@ -115,7 +79,7 @@ def __init__(self, context, llvmir, llvmbc):
11579
"""
11680
self._tmpdir = tempfile.mkdtemp()
11781
self._tempfiles = []
118-
self._cmd = CmdLine()
82+
self._generator = _SpirvGenerator()
11983
self._finalized = False
12084
self.context = context
12185

@@ -182,7 +146,7 @@ def finalize(self):
182146
print("generated_llvm.bc")
183147
print("".center(80, "="))
184148

185-
self._cmd.generate(
149+
self._generator.generate(
186150
llvm_spirv_args=llvm_spirv_args,
187151
ipath=self._llvmfile,
188152
opath=spirv_path,
@@ -199,28 +163,7 @@ def finalize(self):
199163
print("generated_spirv.spir")
200164
print("".center(80, "="))
201165

202-
# Validate the SPIR-V code
203-
if config.SPIRV_VAL == 1:
204-
try:
205-
self._cmd.validate(ipath=spirv_path)
206-
except CalledProcessError:
207-
print("SPIR-V Validation failed...")
208-
pass
209-
else:
210-
# Optimize SPIR-V code
211-
opt_path = self._track_temp_file("optimized-spirv")
212-
self._cmd.optimize(ipath=spirv_path, opath=opt_path)
213-
214-
if config.DUMP_ASSEMBLY:
215-
# Disassemble optimized SPIR-V code
216-
dis_path = self._track_temp_file("disassembled-spirv")
217-
self._cmd.disassemble(ipath=opt_path, opath=dis_path)
218-
with open(dis_path, "rb") as fin_opt:
219-
print("ASSEMBLY".center(80, "-"))
220-
print(fin_opt.read())
221-
print("".center(80, "="))
222-
223-
# Read and return final SPIR-V (not optimized!)
166+
# Read and return final SPIR-V
224167
with open(spirv_path, "rb") as fin:
225168
spirv = fin.read()
226169

0 commit comments

Comments
 (0)