3333THE SOFTWARE.
3434"""
3535
36- from typing import TYPE_CHECKING , Any , cast
36+ from typing import TYPE_CHECKING , Any , ClassVar , cast
3737
3838from typing_extensions import override
3939
4343if TYPE_CHECKING :
4444 from collections .abc import Callable , Sequence
4545
46+ import pytest
47+
4648 import pyopencl as cl
4749
4850 from arraycontext .context import ArrayContext
@@ -66,7 +68,7 @@ class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory):
6668 """
6769 device : cl .Device
6870
69- def __init__ (self , device : cl .Device ):
71+ def __init__ (self , device : cl .Device ) -> None :
7072 """
7173 :arg device: a :class:`pyopencl.Device`.
7274 """
@@ -76,12 +78,12 @@ def __init__(self, device: cl.Device):
7678 @override
7779 def is_available (cls ) -> bool :
7880 try :
79- import pyopencl # noqa: F401
81+ import pyopencl # noqa: F401 # pyright: ignore[reportUnusedImport]
8082 return True
8183 except ImportError :
8284 return False
8385
84- def get_command_queue (self ):
86+ def get_command_queue (self ) -> tuple [ cl . Context , cl . CommandQueue ] :
8587 # Get rid of leftovers from past tests.
8688 # CL implementations are surprisingly limited in how many
8789 # simultaneous contexts they allow...
@@ -101,22 +103,24 @@ def get_command_queue(self):
101103
102104class _PytestPyOpenCLArrayContextFactoryWithClass (PytestPyOpenCLArrayContextFactory ):
103105 # Deprecated, remove in 2025.
104- _force_device_scalars = True
106+ _force_device_scalars : ClassVar [ bool ] = True
105107
106108 @property
107- def force_device_scalars (self ):
109+ def force_device_scalars (self ) -> bool :
108110 from warnings import warn
109111 warn (
110112 "force_device_scalars is deprecated and will be removed in 2025." ,
111113 DeprecationWarning , stacklevel = 2 )
114+
112115 return self ._force_device_scalars
113116
114117 @property
115- def actx_class (self ):
118+ def actx_class (self ) -> type [ ArrayContext ] :
116119 from arraycontext import PyOpenCLArrayContext
117120 return PyOpenCLArrayContext
118121
119- def __call__ (self ):
122+ @override
123+ def __call__ (self ) -> ArrayContext :
120124 # The ostensibly pointless assignment to *ctx* keeps the CL context alive
121125 # long enough to create the array context, which will then start
122126 # holding a reference to the context to keep it alive in turn.
@@ -125,7 +129,6 @@ def __call__(self):
125129 _ctx , queue = self .get_command_queue ()
126130
127131 alloc = None
128-
129132 if queue .device .platform .name == "NVIDIA CUDA" :
130133 from pyopencl .tools import ImmediateAllocator
131134 alloc = ImmediateAllocator (queue )
@@ -136,34 +139,33 @@ def __call__(self):
136139 "See https://github.com/inducer/arraycontext/issues/196" ,
137140 stacklevel = 1 )
138141
139- return self .actx_class (
140- queue ,
141- allocator = alloc )
142+ return self .actx_class (queue , allocator = alloc )
142143
143- def __str__ (self ):
144+ @override
145+ def __str__ (self ) -> str :
144146 return (f"<{ self .actx_class .__name__ } "
145147 f"for <pyopencl.Device '{ self .device .name .strip ()} ' "
146148 f"on '{ self .device .platform .name .strip ()} '>>" )
147149
148150
149151class _PytestPytatoPyOpenCLArrayContextFactory (PytestPyOpenCLArrayContextFactory ):
150152 @classmethod
153+ @override
151154 def is_available (cls ) -> bool :
152155 try :
153- import pyopencl # noqa: F401
154- import pytato # noqa: F401
156+ import pyopencl # noqa: F401 # pyright: ignore[reportUnusedImport]
157+ import pytato # noqa: F401 # pyright: ignore[reportUnusedImport]
155158 return True
156159 except ImportError :
157160 return False
158161
159162 @property
160- def actx_class (self ):
163+ def actx_class (self ) -> type [ ArrayContext ] :
161164 from arraycontext import PytatoPyOpenCLArrayContext
162- actx_cls = PytatoPyOpenCLArrayContext
163- return actx_cls
165+ return PytatoPyOpenCLArrayContext
164166
165167 @override
166- def __call__ (self ):
168+ def __call__ (self ) -> ArrayContext :
167169 # The ostensibly pointless assignment to *ctx* keeps the CL context alive
168170 # long enough to create the array context, which will then start
169171 # holding a reference to the context to keep it alive in turn.
@@ -186,75 +188,79 @@ def __call__(self):
186188 return self .actx_class (queue , allocator = alloc )
187189
188190 @override
189- def __str__ (self ):
190- return ("<PytatoPyOpenCLArrayContext for "
191+ def __str__ (self ) -> str :
192+ return (f"< { self . actx_class . __name__ } for "
191193 f"<pyopencl.Device '{ self .device .name .strip ()} ' "
192194 f"on '{ self .device .platform .name .strip ()} '>>" )
193195
194196
195197class _PytestEagerJaxArrayContextFactory (PytestArrayContextFactory ):
196- def __init__ (self , * args , ** kwargs ):
198+ def __init__ (self , * args , ** kwargs ) -> None :
197199 pass
198200
199201 @classmethod
200202 @override
201203 def is_available (cls ) -> bool :
202204 try :
203- import jax # noqa: F401
205+ import jax # noqa: F401 # pyright: ignore[reportUnusedImport]
204206 return True
205207 except ImportError :
206208 return False
207209
208210 @override
209- def __call__ (self ):
210- from jax import config
211+ def __call__ (self ) -> ArrayContext :
212+ import jax
211213
212214 from arraycontext import EagerJAXArrayContext
213- config .update ("jax_enable_x64" , True )
215+
216+ jax .config .update ("jax_enable_x64" , True )
214217 return EagerJAXArrayContext ()
215218
216219 @override
217- def __str__ (self ):
220+ def __str__ (self ) -> str :
218221 return "<EagerJAXArrayContext>"
219222
220223
221224class _PytestPytatoJaxArrayContextFactory (PytestArrayContextFactory ):
222- def __init__ (self , * args , ** kwargs ):
225+ def __init__ (self , * args , ** kwargs ) -> None :
223226 pass
224227
225228 @classmethod
229+ @override
226230 def is_available (cls ) -> bool :
227231 try :
228- import jax # noqa: F401
229- import pytato # noqa: F401
232+ import jax # noqa: F401 # pyright: ignore[reportUnusedImport]
233+ import pytato # noqa: F401 # pyright: ignore[reportUnusedImport]
230234 return True
231235 except ImportError :
232236 return False
233237
234- def __call__ (self ):
235- from jax import config
238+ @override
239+ def __call__ (self ) -> ArrayContext :
240+ import jax
236241
237242 from arraycontext import PytatoJAXArrayContext
238- config .update ("jax_enable_x64" , True )
243+
244+ jax .config .update ("jax_enable_x64" , True )
239245 return PytatoJAXArrayContext ()
240246
241247 @override
242- def __str__ (self ):
248+ def __str__ (self ) -> str :
243249 return "<PytatoJAXArrayContext>"
244250
245251
246252# {{{ _PytestArrayContextFactory
247253
248254class _PytestNumpyArrayContextFactory (PytestArrayContextFactory ):
249- def __init__ (self , * args , ** kwargs ):
255+ def __init__ (self , * args , ** kwargs ) -> None :
250256 super ().__init__ ()
251257
252258 @override
253259 def __call__ (self ) -> NumpyArrayContext :
254260 return NumpyArrayContext ()
255261
256262 @override
257- def __str__ (self ):
263+ def __str__ (self ) -> str :
258264 return "<NumpyArrayContext>"
259265
260266# }}}
@@ -322,6 +328,7 @@ def pytest_generate_tests_for_array_contexts(
322328 import os
323329 env_factory_string = os .environ .get ("ARRAYCONTEXT_TEST" , None )
324330
331+ unique_factories : set [str | type [PytestArrayContextFactory ]]
325332 if env_factory_string is not None :
326333 unique_factories = set (env_factory_string .split ("," ))
327334 else :
@@ -345,7 +352,8 @@ def pytest_generate_tests_for_array_contexts(
345352 raise ValueError (f"unknown array contexts: { unknown_factories } " )
346353
347354 available_factories = {
348- factory for key in unique_factories
355+ factory
356+ for key in unique_factories
349357 for factory in [_ARRAY_CONTEXT_FACTORY_REGISTRY .get (key , key )]
350358 if (
351359 not isinstance (factory , str )
@@ -360,7 +368,7 @@ def pytest_generate_tests_for_array_contexts(
360368
361369 # }}}
362370
363- def inner (metafunc ) :
371+ def inner (metafunc : pytest . Metafunc ) -> None :
364372 # {{{ get pyopencl devices
365373
366374 import pyopencl .tools as cl_tools
@@ -383,7 +391,7 @@ def inner(metafunc):
383391 f"Cannot use both an '{ factory_arg_name } ' and a "
384392 "'ctx_factory' / 'ctx_getter' as arguments." )
385393
386- arg_values_with_actx = []
394+ arg_values_with_actx : list [ dict [ str , Any ]] = []
387395
388396 if pyopencl_factories :
389397 for arg_dict in arg_values :
0 commit comments