Skip to content

Commit 2ec9e7e

Browse files
committed
Improve python typing
1 parent 7eb3e9a commit 2ec9e7e

File tree

3 files changed

+5
-6
lines changed

3 files changed

+5
-6
lines changed

numba_dpex/core/kernel_interface/spirv_kernel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from numba_dpex import config, spirv_generator
1111
from numba_dpex.core.compiler import compile_with_dpex
1212
from numba_dpex.core.exceptions import UncompiledKernelError, UnreachableError
13+
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
1314

1415
from .kernel_base import KernelInterface
1516

@@ -133,9 +134,8 @@ def compile(
133134
)
134135

135136
func = cres.library.get_function(cres.fndesc.llvm_func_name)
136-
kernel = cres.target_context.prepare_spir_kernel(
137-
func, cres.signature.args
138-
)
137+
kernel_targetctx: DpexKernelTargetContext = cres.target_context
138+
kernel = kernel_targetctx.prepare_spir_kernel(func, cres.signature.args)
139139

140140
# XXX: Setting the inline_threshold in the following way is a temporary
141141
# workaround till the JitKernel dispatcher is replaced by

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _submit_parfor_kernel(
185185
kl_builder.set_arguments(
186186
kernel_fn.kernel_arg_types, kernel_args=kernel_args
187187
)
188-
kl_builder.set_dependant_event_list(dep_events=[])
188+
kl_builder.set_dependant_event_list([])
189189
event_ref = kl_builder.submit()
190190

191191
sycl.dpctl_event_wait(lowerer.builder, event_ref)

numba_dpex/core/targets/kernel_target.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,8 @@ def mangler(self, name, argtypes, abi_tags=(), uid=None):
351351
)
352352

353353
def prepare_spir_kernel(self, func, argtypes):
354-
module = func.module
355354
func.linkage = "linkonce_odr"
356-
module.data_layout = codegen.SPIR_DATA_LAYOUT[self.address_size]
355+
func.module.data_layout = codegen.SPIR_DATA_LAYOUT[self.address_size]
357356
wrapper = self._generate_spir_kernel_wrapper(func, argtypes)
358357
return wrapper
359358

0 commit comments

Comments
 (0)