Skip to content

Commit ea5b3bd

Browse files
Authored by Diptorup Deb
Merge pull request #1247 from IntelPython/feature/typing_improvement
Feature/typing improvement
2 parents 3f8cdf9 + 2ec9e7e commit ea5b3bd

File tree

4 files changed: +36 additions, -22 deletions

numba_dpex/core/kernel_interface/spirv_kernel.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
from numba_dpex import config, spirv_generator
1111
from numba_dpex.core.compiler import compile_with_dpex
1212
from numba_dpex.core.exceptions import UncompiledKernelError, UnreachableError
13+
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
1314

1415
from .kernel_base import KernelInterface
1516

@@ -133,9 +134,8 @@ def compile(
133134
)
134135

135136
func = cres.library.get_function(cres.fndesc.llvm_func_name)
136-
kernel = cres.target_context.prepare_spir_kernel(
137-
func, cres.signature.args
138-
)
137+
kernel_targetctx: DpexKernelTargetContext = cres.target_context
138+
kernel = kernel_targetctx.prepare_spir_kernel(func, cres.signature.args)
139139

140140
# XXX: Setting the inline_threshold in the following way is a temporary
141141
# workaround till the JitKernel dispatcher is replaced by

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,7 @@ def _submit_parfor_kernel(
185185
kl_builder.set_arguments(
186186
kernel_fn.kernel_arg_types, kernel_args=kernel_args
187187
)
188-
kl_builder.set_dependant_event_list(dep_events=[])
188+
kl_builder.set_dependant_event_list([])
189189
event_ref = kl_builder.submit()
190190

191191
sycl.dpctl_event_wait(lowerer.builder, event_ref)

numba_dpex/core/targets/kernel_target.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -351,9 +351,8 @@ def mangler(self, name, argtypes, abi_tags=(), uid=None):
351351
)
352352

353353
def prepare_spir_kernel(self, func, argtypes):
354-
module = func.module
355354
func.linkage = "linkonce_odr"
356-
module.data_layout = codegen.SPIR_DATA_LAYOUT[self.address_size]
355+
func.module.data_layout = codegen.SPIR_DATA_LAYOUT[self.address_size]
357356
wrapper = self._generate_spir_kernel_wrapper(func, argtypes)
358357
return wrapper
359358

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 31 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -10,28 +10,33 @@
1010
from typing import Tuple
1111

1212
import numba.core.event as ev
13+
from llvmlite.binding.value import ValueRef
1314
from numba.core import errors, sigutils, types
1415
from numba.core.compiler import CompileResult, Flags
1516
from numba.core.compiler_lock import global_compiler_lock
1617
from numba.core.dispatcher import Dispatcher, _FunctionCompiler
18+
from numba.core.funcdesc import PythonFunctionDescriptor
1719
from numba.core.target_extension import dispatcher_registry, target_registry
1820
from numba.core.types import void
1921
from numba.core.typing.typeof import Purpose, typeof
2022

2123
from numba_dpex import config, spirv_generator
24+
from numba_dpex.core.codegen import SPIRVCodeLibrary
2225
from numba_dpex.core.exceptions import (
2326
ExecutionQueueInferenceError,
2427
KernelHasReturnValueError,
2528
UnsupportedKernelArgumentError,
2629
)
2730
from numba_dpex.core.pipelines import kernel_compiler
28-
from numba_dpex.core.targets.kernel_target import CompilationMode
31+
from numba_dpex.core.targets.kernel_target import (
32+
CompilationMode,
33+
DpexKernelTargetContext,
34+
)
2935
from numba_dpex.core.types import DpnpNdArray
36+
from numba_dpex.core.utils import kernel_launcher as kl
3037

3138
from .target import DPEX_KERNEL_EXP_TARGET_NAME, dpex_exp_kernel_target
3239

33-
_KernelModule = namedtuple("_KernelModule", ["kernel_name", "kernel_bitcode"])
34-
3540
_KernelCompileResult = namedtuple(
3641
"_KernelCompileResult", CompileResult._fields + ("kernel_device_ir_module",)
3742
)
@@ -76,9 +81,14 @@ def check_queue_equivalence_of_args(
7681
)
7782

7883
def _compile_to_spirv(
79-
self, kernel_library, kernel_fndesc, kernel_targetctx
84+
self,
85+
kernel_library: SPIRVCodeLibrary,
86+
kernel_fndesc: PythonFunctionDescriptor,
87+
kernel_targetctx: DpexKernelTargetContext,
8088
):
81-
kernel_func = kernel_library.get_function(kernel_fndesc.llvm_func_name)
89+
kernel_func: ValueRef = kernel_library.get_function(
90+
kernel_fndesc.llvm_func_name
91+
)
8292

8393
# Create a spir_kernel wrapper function
8494
kernel_fn = kernel_targetctx.prepare_spir_kernel(
@@ -103,11 +113,11 @@ def _compile_to_spirv(
103113
kernel_library.final_module,
104114
kernel_library.final_module.as_bitcode(),
105115
)
106-
return _KernelModule(
116+
return kl.SPIRVKernelModule(
107117
kernel_name=kernel_fn.name, kernel_bitcode=kernel_spirv_module
108118
)
109119

110-
def compile(self, args, return_type):
120+
def compile(self, args, return_type) -> _KernelCompileResult:
111121
status, kcres = self._compile_cached(args, return_type)
112122
if status:
113123
return kcres
@@ -160,8 +170,10 @@ def _compile_cached(
160170
self.targetoptions["_compilation_mode"]
161171
== CompilationMode.KERNEL
162172
):
163-
kernel_device_ir_module: _KernelModule = self._compile_to_spirv(
164-
cres.library, cres.fndesc, cres.target_context
173+
kernel_device_ir_module: kl.SPIRVKernelModule = (
174+
self._compile_to_spirv(
175+
cres.library, cres.fndesc, cres.target_context
176+
)
165177
)
166178
else:
167179
kernel_device_ir_module = None
@@ -329,14 +341,17 @@ def cb_llvm(dur):
329341
# Add code to enable on disk caching of a binary spirv kernel.
330342
# Refer: https://github.com/IntelPython/numba-dpex/issues/1197
331343
self._cache_misses[sig] += 1
332-
ev_details = {
333-
"dispatcher": self,
334-
"args": args,
335-
"return_type": return_type,
336-
}
337-
with ev.trigger_event("numba_dpex:compile", data=ev_details):
344+
with ev.trigger_event(
345+
"numba_dpex:compile",
346+
data={
347+
"dispatcher": self,
348+
"args": args,
349+
"return_type": return_type,
350+
},
351+
):
338352
try:
339-
kcres: _KernelCompileResult = self._compiler.compile(
353+
compiler: _KernelCompiler = self._compiler
354+
kcres: _KernelCompileResult = compiler.compile(
340355
args, return_type
341356
)
342357
except errors.ForceLiteralArg as e:

0 commit comments

Comments
 (0)