29
29
_KernelModule = namedtuple ("_KernelModule" , ["kernel_name" , "kernel_bitcode" ])
30
30
31
31
_KernelCompileResult = namedtuple (
32
- "_KernelCompileResult" ,
33
- ["status" , "cres_or_error" , "entry_point" ],
32
+ "_KernelCompileResult" , CompileResult ._fields + ("kernel_device_ir_module" ,)
34
33
)
35
34
36
35
@@ -96,11 +95,11 @@ def _compile_to_spirv(
96
95
)
97
96
98
97
def compile (self , args , return_type ):
99
- kcres = self ._compile_cached (args , return_type )
100
- if kcres . status :
98
+ status , kcres = self ._compile_cached (args , return_type )
99
+ if status :
101
100
return kcres
102
101
103
- raise kcres . cres_or_error
102
+ raise kcres
104
103
105
104
def _compile_cached (
106
105
self , args , return_type : types .Type
@@ -137,34 +136,45 @@ def _compile_cached(
137
136
"""
138
137
key = tuple (args ), return_type
139
138
try :
140
- return _KernelCompileResult ( False , self ._failed_cache [key ], None )
139
+ return False , self ._failed_cache [key ]
141
140
except KeyError :
142
141
pass
143
142
144
143
try :
145
- kernel_cres : CompileResult = self ._compile_core (args , return_type )
144
+ cres : CompileResult = self ._compile_core (args , return_type )
146
145
147
- kernel_library = kernel_cres .library
148
- kernel_fndesc = kernel_cres .fndesc
149
- kernel_targetctx = kernel_cres .target_context
150
-
151
- kernel_module = self ._compile_to_spirv (
152
- kernel_library , kernel_fndesc , kernel_targetctx
146
+ kernel_device_ir_module = self ._compile_to_spirv (
147
+ cres .library , cres .fndesc , cres .target_context
153
148
)
154
149
150
+ kcres_attrs = []
151
+
152
+ for cres_field in cres ._fields :
153
+ cres_attr = getattr (cres , cres_field )
154
+ if cres_field == "entry_point" :
155
+ if cres_attr is not None :
156
+ raise AssertionError (
157
+ "Compiled kernel and device_func should be "
158
+ "compiled with compile_cfunc option turned off"
159
+ )
160
+ cres_attr = cres .fndesc .qualname
161
+ kcres_attrs .append (cres_attr )
162
+
163
+ kcres_attrs .append (kernel_device_ir_module )
164
+
155
165
if config .DUMP_KERNEL_LLVM :
156
166
with open (
157
- kernel_cres .fndesc .llvm_func_name + ".ll" ,
167
+ cres .fndesc .llvm_func_name + ".ll" ,
158
168
"w" ,
159
169
encoding = "UTF-8" ,
160
170
) as f :
161
- f .write (kernel_cres .library .final_module )
171
+ f .write (cres .library .final_module )
162
172
163
173
except errors .TypingError as e :
164
174
self ._failed_cache [key ] = e
165
- return _KernelCompileResult ( False , e , None )
175
+ return False , e
166
176
167
- return _KernelCompileResult ( True , kernel_cres , kernel_module )
177
+ return True , _KernelCompileResult ( * kcres_attrs )
168
178
169
179
170
180
class KernelDispatcher (Dispatcher ):
@@ -234,7 +244,14 @@ def typeof_pyval(self, val):
234
244
235
245
def add_overload (self , cres ):
236
246
args = tuple (cres .signature .args )
237
- self .overloads [args ] = cres .entry_point
247
+ self .overloads [args ] = cres
248
+
249
+ def get_overload_device_ir (self , sig ):
250
+ """
251
+ Return the compiled device bitcode for the given signature.
252
+ """
253
+ args , _ = sigutils .normalize_signature (sig )
254
+ return self .overloads [tuple (args )].kernel_device_ir_module
238
255
239
256
def compile (self , sig ) -> _KernelCompileResult :
240
257
disp = self ._get_dispatcher_for_current_target ()
@@ -274,7 +291,7 @@ def cb_llvm(dur):
274
291
# Don't recompile if signature already exists
275
292
existing = self .overloads .get (tuple (args ))
276
293
if existing is not None :
277
- return existing
294
+ return existing . entry_point
278
295
279
296
# TODO: Enable caching
280
297
# Add code to enable on disk caching of a binary spirv kernel.
@@ -298,7 +315,11 @@ def folded(args, kws):
298
315
)[1 ]
299
316
300
317
raise e .bind_fold_arguments (folded )
301
- self .add_overload (kcres .cres_or_error )
318
+ self .add_overload (kcres )
319
+
320
+ kcres .target_context .insert_user_function (
321
+ kcres .entry_point , kcres .fndesc , [kcres .library ]
322
+ )
302
323
303
324
# TODO: enable caching of kernel_module
304
325
# https://github.com/IntelPython/numba-dpex/issues/1197
0 commit comments