2
2
#
3
3
# SPDX-License-Identifier: Apache-2.0
4
4
5
- """A wrapper to connect to the SPIR-V binaries (Tools, Translator)."""
5
+ """
6
+ A wrapper to call dpcpp's llvm-spirv tool to generate a SPIR-V binary from
7
+ a numba-dpex generated LLVM IR module.
8
+ """
6
9
7
10
import os
8
11
import tempfile
@@ -28,47 +31,8 @@ def run_cmd(args, error_message=None):
28
31
)
29
32
30
33
31
- class CmdLine :
32
- def disassemble (self , ipath , opath ):
33
- """
34
- Disassemble a spirv module.
35
-
36
- Args:
37
- ipath: Input file path of the spirv module.
38
- opath: Output file path of the disassembled spirv module.
39
- """
40
- flags = []
41
- run_cmd (
42
- ["spirv-dis" , * flags , "-o" , opath , ipath ],
43
- error_message = "Error during SPIRV disassemble" ,
44
- )
45
-
46
- def validate (self , ipath ):
47
- """
48
- Validate a spirv module.
49
-
50
- Args:
51
- ipath: Input file path of the spirv module.
52
- """
53
- flags = []
54
- run_cmd (
55
- ["spirv-val" , * flags , ipath ],
56
- error_message = "Error during SPIRV validation" ,
57
- )
58
-
59
- def optimize (self , ipath , opath ):
60
- """
61
- Optimize a spirv module.
62
-
63
- Args:
64
- ipath: Input file path of the spirv module.
65
- opath: Output file path of the optimized spirv module.
66
- """
67
- flags = []
68
- run_cmd (
69
- ["spirv-opt" , * flags , "-o" , opath , ipath ],
70
- error_message = "Error during SPIRV optimization" ,
71
- )
34
+ class _SpirvGenerator :
35
+ """Generates a SPIR-V binary from supplied LLVM IR."""
72
36
73
37
def generate (self , llvm_spirv_args , ipath , opath ):
74
38
"""
@@ -115,7 +79,7 @@ def __init__(self, context, llvmir, llvmbc):
115
79
"""
116
80
self ._tmpdir = tempfile .mkdtemp ()
117
81
self ._tempfiles = []
118
- self ._cmd = CmdLine ()
82
+ self ._generator = _SpirvGenerator ()
119
83
self ._finalized = False
120
84
self .context = context
121
85
@@ -182,7 +146,7 @@ def finalize(self):
182
146
print ("generated_llvm.bc" )
183
147
print ("" .center (80 , "=" ))
184
148
185
- self ._cmd .generate (
149
+ self ._generator .generate (
186
150
llvm_spirv_args = llvm_spirv_args ,
187
151
ipath = self ._llvmfile ,
188
152
opath = spirv_path ,
@@ -199,28 +163,7 @@ def finalize(self):
199
163
print ("generated_spirv.spir" )
200
164
print ("" .center (80 , "=" ))
201
165
202
- # Validate the SPIR-V code
203
- if config .SPIRV_VAL == 1 :
204
- try :
205
- self ._cmd .validate (ipath = spirv_path )
206
- except CalledProcessError :
207
- print ("SPIR-V Validation failed..." )
208
- pass
209
- else :
210
- # Optimize SPIR-V code
211
- opt_path = self ._track_temp_file ("optimized-spirv" )
212
- self ._cmd .optimize (ipath = spirv_path , opath = opt_path )
213
-
214
- if config .DUMP_ASSEMBLY :
215
- # Disassemble optimized SPIR-V code
216
- dis_path = self ._track_temp_file ("disassembled-spirv" )
217
- self ._cmd .disassemble (ipath = opt_path , opath = dis_path )
218
- with open (dis_path , "rb" ) as fin_opt :
219
- print ("ASSEMBLY" .center (80 , "-" ))
220
- print (fin_opt .read ())
221
- print ("" .center (80 , "=" ))
222
-
223
- # Read and return final SPIR-V (not optimized!)
166
+ # Read and return final SPIR-V
224
167
with open (spirv_path , "rb" ) as fin :
225
168
spirv = fin .read ()
226
169
0 commit comments