Skip to content

Commit a455e46

Browse files
author
Diptorup Deb
committed
Move printimpl into kernel_api_impl.spirv
1 parent e0d0772 commit a455e46

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

numba_dpex/printimpl.py renamed to numba_dpex/kernel_api_impl/spirv/printimpl.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
"""
6+
An implementation of ``print`` for use in a kernel for the SPIRVKernelTarget.
7+
"""
8+
59
from functools import singledispatch
610

711
import llvmlite.ir as llvmir
@@ -14,7 +18,16 @@
1418
lower = registry.lower
1519

1620

17-
def declare_print(lmod):
21+
def declare_print(lmod: llvmir.Module):
22+
"""Inserts declaration for C printf into the given LLVM module
23+
24+
Args:
25+
lmod (llvmir.Module): LLVM module into which the function declaration
26+
needs to be inserted.
27+
28+
Returns:
29+
An LLVM IR Function object for the inserted C printf function.
30+
"""
1831
voidptrty = llvmir.PointerType(
1932
llvmir.IntType(8), addrspace=address_space.GENERIC.value
2033
)
@@ -32,33 +45,34 @@ def print_item(ty, context, builder, val):
3245
A (format string, [list of arguments]) is returned that will allow
3346
forming the final printf()-like call.
3447
"""
35-
raise NotImplementedError(
36-
"printing unimplemented for values of type %s" % (ty,)
37-
)
48+
raise NotImplementedError(f"printing unimplemented for values of type {ty}")
3849

3950

4051
@print_item.register(types.Integer)
4152
@print_item.register(types.IntegerLiteral)
4253
def int_print_impl(ty, context, builder, val):
54+
"""Implements printing an integer value."""
4355
if ty in types.unsigned_domain:
4456
rawfmt = "%llu"
4557
dsttype = types.uint64
4658
else:
4759
rawfmt = "%lld"
4860
dsttype = types.int64
49-
fmt = context.insert_const_string(builder.module, rawfmt) # noqa
61+
context.insert_const_string(builder.module, rawfmt)
5062
lld = context.cast(builder, val, ty, dsttype)
5163
return rawfmt, [lld]
5264

5365

5466
@print_item.register(types.Float)
5567
def real_print_impl(ty, context, builder, val):
68+
"""Implements printing a real number value."""
5669
lld = context.cast(builder, val, ty, types.float64)
5770
return "%f", [lld]
5871

5972

6073
@print_item.register(types.StringLiteral)
6174
def const_print_impl(ty, context, builder, sigval):
75+
"""Implements printing a string value."""
6276
pyval = ty.literal_value
6377
assert isinstance(pyval, str) # Ensured by lowering
6478
rawfmt = "%s"
@@ -76,7 +90,7 @@ def print_varargs(context, builder, sig, args):
7690
values = []
7791

7892
only_str = True
79-
for i, (argtype, argval) in enumerate(zip(sig.args, args)):
93+
for _, (argtype, argval) in enumerate(zip(sig.args, args)):
8094
argfmt, argvals = print_item(argtype, context, builder, argval)
8195
formats.append(argfmt)
8296
values.extend(argvals)

numba_dpex/kernel_api_impl/spirv/target.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,9 @@ def load_additional_registries(self):
291291
292292
"""
293293
# pylint: disable=import-outside-toplevel
294-
from numba_dpex import printimpl
295294
from numba_dpex.dpctl_iface import dpctlimpl
296295
from numba_dpex.dpnp_iface import dpnpimpl
296+
from numba_dpex.kernel_api_impl.spirv import printimpl
297297
from numba_dpex.ocl import mathimpl
298298

299299
self.insert_func_defn(mathimpl.registry.functions)

0 commit comments

Comments
 (0)