22#
33# SPDX-License-Identifier: Apache-2.0
44
5- """A wrapper to connect to the SPIR-V binaries (Tools, Translator)."""
5+ """
6+ A wrapper to call dpcpp's llvm-spirv tool to generate a SPIR-V binary from
7+ a numba-dpex generated LLVM IR module.
8+ """
69
710import os
811import tempfile
@@ -28,47 +31,8 @@ def run_cmd(args, error_message=None):
2831 )
2932
3033
31- class CmdLine :
32- def disassemble (self , ipath , opath ):
33- """
34- Disassemble a spirv module.
35-
36- Args:
37- ipath: Input file path of the spirv module.
38- opath: Output file path of the disassembled spirv module.
39- """
40- flags = []
41- run_cmd (
42- ["spirv-dis" , * flags , "-o" , opath , ipath ],
43- error_message = "Error during SPIRV disassemble" ,
44- )
45-
46- def validate (self , ipath ):
47- """
48- Validate a spirv module.
49-
50- Args:
51- ipath: Input file path of the spirv module.
52- """
53- flags = []
54- run_cmd (
55- ["spirv-val" , * flags , ipath ],
56- error_message = "Error during SPIRV validation" ,
57- )
58-
59- def optimize (self , ipath , opath ):
60- """
61- Optimize a spirv module.
62-
63- Args:
64- ipath: Input file path of the spirv module.
65- opath: Output file path of the optimized spirv module.
66- """
67- flags = []
68- run_cmd (
69- ["spirv-opt" , * flags , "-o" , opath , ipath ],
70- error_message = "Error during SPIRV optimization" ,
71- )
34+ class _SpirvGenerator :
35+ """Generates a SPIR-V binary from supplied LLVM IR."""
7236
7337 def generate (self , llvm_spirv_args , ipath , opath ):
7438 """
@@ -115,7 +79,7 @@ def __init__(self, context, llvmir, llvmbc):
11579 """
11680 self ._tmpdir = tempfile .mkdtemp ()
11781 self ._tempfiles = []
118- self ._cmd = CmdLine ()
82+ self ._generator = _SpirvGenerator ()
11983 self ._finalized = False
12084 self .context = context
12185
@@ -182,7 +146,7 @@ def finalize(self):
182146 print ("generated_llvm.bc" )
183147 print ("" .center (80 , "=" ))
184148
185- self ._cmd .generate (
149+ self ._generator .generate (
186150 llvm_spirv_args = llvm_spirv_args ,
187151 ipath = self ._llvmfile ,
188152 opath = spirv_path ,
@@ -199,28 +163,7 @@ def finalize(self):
199163 print ("generated_spirv.spir" )
200164 print ("" .center (80 , "=" ))
201165
202- # Validate the SPIR-V code
203- if config .SPIRV_VAL == 1 :
204- try :
205- self ._cmd .validate (ipath = spirv_path )
206- except CalledProcessError :
207- print ("SPIR-V Validation failed..." )
208- pass
209- else :
210- # Optimize SPIR-V code
211- opt_path = self ._track_temp_file ("optimized-spirv" )
212- self ._cmd .optimize (ipath = spirv_path , opath = opt_path )
213-
214- if config .DUMP_ASSEMBLY :
215- # Disassemble optimized SPIR-V code
216- dis_path = self ._track_temp_file ("disassembled-spirv" )
217- self ._cmd .disassemble (ipath = opt_path , opath = dis_path )
218- with open (dis_path , "rb" ) as fin_opt :
219- print ("ASSEMBLY" .center (80 , "-" ))
220- print (fin_opt .read ())
221- print ("" .center (80 , "=" ))
222-
223- # Read and return final SPIR-V (not optimized!)
166+ # Read and return final SPIR-V
224167 with open (spirv_path , "rb" ) as fin :
225168 spirv = fin .read ()
226169
0 commit comments