Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5,326 changes: 290 additions & 5,036 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions doc/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,10 @@ Generating Arrays of Random Numbers
-----------------------------------

.. automodule:: pyopencl.clrandom

Type Aliases
------------

.. class:: cl.Device

See :class:`pyopencl.Device`.
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
nitpick_ignore = [
("py:class", r"numpy._typing._dtype_like._SupportsDType"),
("py:class", r"numpy._typing._dtype_like._DTypeDict"),
("py:class", r"pytest.Metafunc"),
]

intersphinx_mapping = {
Expand Down
12 changes: 6 additions & 6 deletions pyopencl/_cl.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -591,12 +591,12 @@ class kernel_arg_type_qualifier(IntEnum): # noqa: N801
to_string = classmethod(pyopencl._monkeypatch.to_string)

class kernel_work_group_info(IntEnum): # noqa: N801
WORK_GROUP_SIZE = auto()
COMPILE_WORK_GROUP_SIZE = auto()
LOCAL_MEM_SIZE = auto()
PREFERRED_WORK_GROUP_SIZE_MULTIPLE = auto()
PRIVATE_MEM_SIZE = auto()
GLOBAL_WORK_SIZE = auto()
WORK_GROUP_SIZE = 0x11B0
COMPILE_WORK_GROUP_SIZE = 0x11B1
LOCAL_MEM_SIZE = 0x11B2
PREFERRED_WORK_GROUP_SIZE_MULTIPLE = 0x11B3
PRIVATE_MEM_SIZE = 0x11B4
GLOBAL_WORK_SIZE = 0x11B5
to_string = classmethod(pyopencl._monkeypatch.to_string)

class kernel_sub_group_info(IntEnum): # noqa: N801
Expand Down
41 changes: 39 additions & 2 deletions pyopencl/_monkeypatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
from typing import (
TYPE_CHECKING,
Any,
Literal,
TextIO,
TypeVar,
cast,
overload,
)
from warnings import warn

Expand Down Expand Up @@ -221,7 +223,7 @@ def __getattr__(self, name: str):
kernel_old_get_work_group_info = _cl.Kernel.get_work_group_info


def kernel_set_arg_types(self: _cl.Kernel, arg_types):
def kernel_set_arg_types(self: _cl.Kernel, arg_types) -> None:
arg_types = tuple(arg_types)

# {{{ arg counting bug handling
Expand Down Expand Up @@ -259,7 +261,42 @@ def kernel_set_arg_types(self: _cl.Kernel, arg_types):
devs=self.context.devices))


def kernel_get_work_group_info(self: _cl.Kernel, param: int, device: _cl.Device):
@overload
def kernel_get_work_group_info(
self: _cl.Kernel,
param: Literal[
_cl.kernel_work_group_info.WORK_GROUP_SIZE,
_cl.kernel_work_group_info.PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
_cl.kernel_work_group_info.LOCAL_MEM_SIZE,
_cl.kernel_work_group_info.PRIVATE_MEM_SIZE,
],
device: _cl.Device
) -> int: ...

@overload
def kernel_get_work_group_info(
self: _cl.Kernel,
param: Literal[
_cl.kernel_work_group_info.COMPILE_WORK_GROUP_SIZE,
_cl.kernel_work_group_info.GLOBAL_WORK_SIZE,
],
device: _cl.Device
) -> Sequence[int]: ...


@overload
def kernel_get_work_group_info(
self: _cl.Kernel,
param: int,
device: _cl.Device
) -> object: ...


def kernel_get_work_group_info(
self: _cl.Kernel,
param: int,
device: _cl.Device
) -> object:
try:
wg_info_cache = self._wg_info_cache
except AttributeError:
Expand Down
3 changes: 1 addition & 2 deletions pyopencl/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,9 @@ def kernel_runner(out: Array, *args: P.args, **kwargs: P.kwargs) -> cl.Event:
assert queue is not None

knl = kernel_getter(out, *args, **kwargs)
work_group_info = cast("int", knl.get_work_group_info(
gs, ls = out._get_sizes(queue, knl.get_work_group_info(
cl.kernel_work_group_info.WORK_GROUP_SIZE,
queue.device))
gs, ls = out._get_sizes(queue, work_group_info)

knl_args = (out, *args, out.size)
if ARRAY_KERNEL_EXEC_HOOK is not None:
Expand Down
6 changes: 2 additions & 4 deletions pyopencl/characterize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
"""


from typing import cast

from pytools import memoize

import pyopencl as cl
Expand Down Expand Up @@ -70,9 +68,9 @@ def reasonable_work_group_size_multiple(
}
""")
prg.build()
return cast("int", prg.knl.get_work_group_info(
return prg.knl.get_work_group_info(
cl.kernel_work_group_info.PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
dev))
dev)


def nv_compute_capability(dev: cl.Device):
Expand Down
69 changes: 42 additions & 27 deletions pyopencl/cltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@
"""

import warnings
from typing import Any
from typing import TYPE_CHECKING, Any, cast

import numpy as np

from pyopencl.tools import get_or_register_dtype


if TYPE_CHECKING:
import builtins
from collections.abc import MutableSequence

if __file__.endswith("array.py"):
warnings.warn(
"pyopencl.array.vec is deprecated. Please use pyopencl.cltypes.",
Expand All @@ -53,37 +57,41 @@

# {{{ vector types

def _create_vector_types():
def _create_vector_types() -> tuple[
dict[tuple[np.dtype[Any], builtins.int], np.dtype[Any]],
dict[np.dtype[Any], tuple[np.dtype[Any], builtins.int]]]:
mapping = [(k, globals()[k]) for k in
["char", "uchar", "short", "ushort", "int",
"uint", "long", "ulong", "float", "double"]]

def set_global(key, val):
def set_global(key: str, val: np.dtype[Any]) -> None:
globals()[key] = val

vec_types = {}
vec_type_to_scalar_and_count = {}
vec_types: dict[tuple[np.dtype[Any], builtins.int], np.dtype[Any]] = {}
vec_type_to_scalar_and_count: dict[np.dtype[Any],
tuple[np.dtype[Any], builtins.int]] = {}

field_names = ["x", "y", "z", "w"]

counts = [2, 3, 4, 8, 16]

for base_name, base_type in mapping:
for count in counts:
name = "%s%d" % (base_name, count)

titles = field_names[:count]
name = f"{base_name}{count}"
titles = cast("MutableSequence[str | None]", field_names[:count])

padded_count = count
if count == 3:
padded_count = 4

names = ["s%d" % i for i in range(count)]
names = [f"s{i}" for i in range(count)]
while len(names) < padded_count:
names.append("padding%d" % (len(names) - count))
pad = len(names) - count
names.append(f"padding{pad}")

if len(titles) < len(names):
titles.extend((len(names) - len(titles)) * [None])
pad = len(names) - len(titles)
titles.extend([None] * pad)

try:
dtype = np.dtype({
Expand All @@ -96,14 +104,16 @@ def set_global(key, val):
for (n, title)
in zip(names, titles, strict=True)])
except TypeError:
dtype = np.dtype([(n, base_type) for (n, title)
in zip(names, titles, strict=True)])
dtype = np.dtype([(n, base_type) for n in names])

assert isinstance(dtype, np.dtype)
get_or_register_dtype(name, dtype)

set_global(name, dtype)

def create_array(dtype, count, padded_count, *args, **kwargs):
def create_array(dtype: np.dtype[Any],
count: int,
padded_count: int,
*args: Any, **kwargs: Any) -> dict[str, Any]:
if len(args) < count:
from warnings import warn
warn("default values for make_xxx are deprecated;"
Expand All @@ -116,21 +126,26 @@ def create_array(dtype, count, padded_count, *args, **kwargs):
{"array": np.array,
"padded_args": padded_args,
"dtype": dtype})
for key, val in list(kwargs.items()):

for key, val in kwargs.items():
array[key] = val

return array

set_global("make_" + name, eval(
"lambda *args, **kwargs: create_array(dtype, %i, %i, "
"*args, **kwargs)" % (count, padded_count),
{"create_array": create_array, "dtype": dtype}))
set_global("filled_" + name, eval(
"lambda val: make_%s(*[val]*%i)" % (name, count)))
set_global("zeros_" + name, eval("lambda: filled_%s(0)" % (name)))
set_global("ones_" + name, eval("lambda: filled_%s(1)" % (name)))

vec_types[np.dtype(base_type), count] = dtype
vec_type_to_scalar_and_count[dtype] = np.dtype(base_type), count
set_global(
f"make_{name}",
eval("lambda *args, **kwargs: "
f"create_array(dtype, {count}, {padded_count}, *args, **kwargs)",
{"create_array": create_array, "dtype": dtype}))
set_global(
f"filled_{name}",
eval(f"lambda val: make_{name}(*[val]*{count})"))
set_global(f"zeros_{name}", eval(f"lambda: filled_{name}(0)"))
set_global(f"ones_{name}", eval(f"lambda: filled_{name}(1)"))

base_dtype = np.dtype(base_type)
vec_types[base_dtype, count] = dtype
vec_type_to_scalar_and_count[dtype] = base_dtype, count

return vec_types, vec_type_to_scalar_and_count

Expand Down
Loading
Loading