Skip to content

Commit 96660b5

Browse files
committed
feat(typing): improve types in arraycontext.pytest
1 parent 8996ad0 commit 96660b5

File tree

1 file changed

+47
-39
lines changed

1 file changed

+47
-39
lines changed

arraycontext/pytest.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
THE SOFTWARE.
3434
"""
3535

36-
from typing import TYPE_CHECKING, Any, cast
36+
from typing import TYPE_CHECKING, Any, ClassVar, cast
3737

3838
from typing_extensions import override
3939

@@ -43,6 +43,8 @@
4343
if TYPE_CHECKING:
4444
from collections.abc import Callable, Sequence
4545

46+
import pytest
47+
4648
import pyopencl as cl
4749

4850
from arraycontext.context import ArrayContext
@@ -66,7 +68,7 @@ class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory):
6668
"""
6769
device: cl.Device
6870

69-
def __init__(self, device: cl.Device):
71+
def __init__(self, device: cl.Device) -> None:
7072
"""
7173
:arg device: a :class:`pyopencl.Device`.
7274
"""
@@ -76,12 +78,12 @@ def __init__(self, device: cl.Device):
7678
@override
7779
def is_available(cls) -> bool:
7880
try:
79-
import pyopencl # noqa: F401
81+
import pyopencl # noqa: F401 # pyright: ignore[reportUnusedImport]
8082
return True
8183
except ImportError:
8284
return False
8385

84-
def get_command_queue(self):
86+
def get_command_queue(self) -> tuple[cl.Context, cl.CommandQueue]:
8587
# Get rid of leftovers from past tests.
8688
# CL implementations are surprisingly limited in how many
8789
# simultaneous contexts they allow...
@@ -101,22 +103,24 @@ def get_command_queue(self):
101103

102104
class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFactory):
103105
# Deprecated, remove in 2025.
104-
_force_device_scalars = True
106+
_force_device_scalars: ClassVar[bool] = True
105107

106108
@property
107-
def force_device_scalars(self):
109+
def force_device_scalars(self) -> bool:
108110
from warnings import warn
109111
warn(
110112
"force_device_scalars is deprecated and will be removed in 2025.",
111113
DeprecationWarning, stacklevel=2)
114+
112115
return self._force_device_scalars
113116

114117
@property
115-
def actx_class(self):
118+
def actx_class(self) -> type[ArrayContext]:
116119
from arraycontext import PyOpenCLArrayContext
117120
return PyOpenCLArrayContext
118121

119-
def __call__(self):
122+
@override
123+
def __call__(self) -> ArrayContext:
120124
# The ostensibly pointless assignment to *ctx* keeps the CL context alive
121125
# long enough to create the array context, which will then start
122126
# holding a reference to the context to keep it alive in turn.
@@ -125,7 +129,6 @@ def __call__(self):
125129
_ctx, queue = self.get_command_queue()
126130

127131
alloc = None
128-
129132
if queue.device.platform.name == "NVIDIA CUDA":
130133
from pyopencl.tools import ImmediateAllocator
131134
alloc = ImmediateAllocator(queue)
@@ -136,34 +139,33 @@ def __call__(self):
136139
"See https://github.com/inducer/arraycontext/issues/196",
137140
stacklevel=1)
138141

139-
return self.actx_class(
140-
queue,
141-
allocator=alloc)
142+
return self.actx_class(queue, allocator=alloc)
142143

143-
def __str__(self):
144+
@override
145+
def __str__(self) -> str:
144146
return (f"<{self.actx_class.__name__} "
145147
f"for <pyopencl.Device '{self.device.name.strip()}' "
146148
f"on '{self.device.platform.name.strip()}'>>")
147149

148150

149151
class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory):
150152
@classmethod
153+
@override
151154
def is_available(cls) -> bool:
152155
try:
153-
import pyopencl # noqa: F401
154-
import pytato # noqa: F401
156+
import pyopencl # noqa: F401 # pyright: ignore[reportUnusedImport]
157+
import pytato # noqa: F401 # pyright: ignore[reportUnusedImport]
155158
return True
156159
except ImportError:
157160
return False
158161

159162
@property
160-
def actx_class(self):
163+
def actx_class(self) -> type[ArrayContext]:
161164
from arraycontext import PytatoPyOpenCLArrayContext
162-
actx_cls = PytatoPyOpenCLArrayContext
163-
return actx_cls
165+
return PytatoPyOpenCLArrayContext
164166

165167
@override
166-
def __call__(self):
168+
def __call__(self) -> ArrayContext:
167169
# The ostensibly pointless assignment to *ctx* keeps the CL context alive
168170
# long enough to create the array context, which will then start
169171
# holding a reference to the context to keep it alive in turn.
@@ -186,75 +188,79 @@ def __call__(self):
186188
return self.actx_class(queue, allocator=alloc)
187189

188190
@override
189-
def __str__(self):
190-
return ("<PytatoPyOpenCLArrayContext for "
191+
def __str__(self) -> str:
192+
return (f"<{self.actx_class.__name__} for "
191193
f"<pyopencl.Device '{self.device.name.strip()}' "
192194
f"on '{self.device.platform.name.strip()}'>>")
193195

194196

195197
class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory):
196-
def __init__(self, *args, **kwargs):
198+
def __init__(self, *args, **kwargs) -> None:
197199
pass
198200

199201
@classmethod
200202
@override
201203
def is_available(cls) -> bool:
202204
try:
203-
import jax # noqa: F401
205+
import jax # noqa: F401 # pyright: ignore[reportUnusedImport]
204206
return True
205207
except ImportError:
206208
return False
207209

208210
@override
209-
def __call__(self):
210-
from jax import config
211+
def __call__(self) -> ArrayContext:
212+
import jax
211213

212214
from arraycontext import EagerJAXArrayContext
213-
config.update("jax_enable_x64", True)
215+
216+
jax.config.update("jax_enable_x64", True)
214217
return EagerJAXArrayContext()
215218

216219
@override
217-
def __str__(self):
220+
def __str__(self) -> str:
218221
return "<EagerJAXArrayContext>"
219222

220223

221224
class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory):
222-
def __init__(self, *args, **kwargs):
225+
def __init__(self, *args, **kwargs) -> None:
223226
pass
224227

225228
@classmethod
229+
@override
226230
def is_available(cls) -> bool:
227231
try:
228-
import jax # noqa: F401
229-
import pytato # noqa: F401
232+
import jax # noqa: F401 # pyright: ignore[reportUnusedImport]
233+
import pytato # noqa: F401 # pyright: ignore[reportUnusedImport]
230234
return True
231235
except ImportError:
232236
return False
233237

234-
def __call__(self):
235-
from jax import config
238+
@override
239+
def __call__(self) -> ArrayContext:
240+
import jax
236241

237242
from arraycontext import PytatoJAXArrayContext
238-
config.update("jax_enable_x64", True)
243+
244+
jax.config.update("jax_enable_x64", True)
239245
return PytatoJAXArrayContext()
240246

241247
@override
242-
def __str__(self):
248+
def __str__(self) -> str:
243249
return "<PytatoJAXArrayContext>"
244250

245251

246252
# {{{ _PytestArrayContextFactory
247253

248254
class _PytestNumpyArrayContextFactory(PytestArrayContextFactory):
249-
def __init__(self, *args, **kwargs):
255+
def __init__(self, *args, **kwargs) -> None:
250256
super().__init__()
251257

252258
@override
253259
def __call__(self) -> NumpyArrayContext:
254260
return NumpyArrayContext()
255261

256262
@override
257-
def __str__(self):
263+
def __str__(self) -> str:
258264
return "<NumpyArrayContext>"
259265

260266
# }}}
@@ -322,6 +328,7 @@ def pytest_generate_tests_for_array_contexts(
322328
import os
323329
env_factory_string = os.environ.get("ARRAYCONTEXT_TEST", None)
324330

331+
unique_factories: set[str | type[PytestArrayContextFactory]]
325332
if env_factory_string is not None:
326333
unique_factories = set(env_factory_string.split(","))
327334
else:
@@ -345,7 +352,8 @@ def pytest_generate_tests_for_array_contexts(
345352
raise ValueError(f"unknown array contexts: {unknown_factories}")
346353

347354
available_factories = {
348-
factory for key in unique_factories
355+
factory
356+
for key in unique_factories
349357
for factory in [_ARRAY_CONTEXT_FACTORY_REGISTRY.get(key, key)]
350358
if (
351359
not isinstance(factory, str)
@@ -360,7 +368,7 @@ def pytest_generate_tests_for_array_contexts(
360368

361369
# }}}
362370

363-
def inner(metafunc):
371+
def inner(metafunc: pytest.Metafunc) -> None:
364372
# {{{ get pyopencl devices
365373

366374
import pyopencl.tools as cl_tools
@@ -383,7 +391,7 @@ def inner(metafunc):
383391
f"Cannot use both an '{factory_arg_name}' and a "
384392
"'ctx_factory' / 'ctx_getter' as arguments.")
385393

386-
arg_values_with_actx = []
394+
arg_values_with_actx: list[dict[str, Any]] = []
387395

388396
if pyopencl_factories:
389397
for arg_dict in arg_values:

0 commit comments

Comments
 (0)