Skip to content

Commit 1719c3c

Browse files
committed
Fix forced import of pyopencl in cl.fake_numpy
1 parent b57a7f0 commit 1719c3c

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,6 @@
4747
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
4848

4949

50-
try:
51-
import pyopencl as cl # noqa: F401
52-
import pyopencl.array as cl_array
53-
except ImportError:
54-
pass
55-
56-
5750
# {{{ fake numpy
5851

5952
class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
@@ -121,6 +114,7 @@ def _copy(subary):
121114
return self._array_context._rec_map_container(_copy, ary)
122115

123116
def arange(self, *args, **kwargs):
117+
import pyopencl.array as cl_array
124118
return cl_array.arange(self._array_context.queue, *args, **kwargs)
125119

126120
# }}}
@@ -155,13 +149,15 @@ def _rec_ravel(a):
155149
return rec_map_array_container(_rec_ravel, a)
156150

157151
def concatenate(self, arrays, axis=0):
152+
import pyopencl.array as cl_array
158153
return cl_array.concatenate(
159154
arrays, axis,
160155
self._array_context.queue,
161156
self._array_context.allocator
162157
)
163158

164159
def stack(self, arrays, axis=0):
160+
import pyopencl.array as cl_array
165161
return rec_multimap_array_container(
166162
lambda *args: cl_array.stack(arrays=args, axis=axis,
167163
queue=self._array_context.queue),
@@ -172,6 +168,7 @@ def stack(self, arrays, axis=0):
172168
# {{{ linear algebra
173169

174170
def vdot(self, x, y, dtype=None):
171+
import pyopencl.array as cl_array
175172
return rec_multimap_reduce_array_container(
176173
sum,
177174
partial(cl_array.vdot, dtype=dtype, queue=self._array_context.queue),
@@ -189,6 +186,7 @@ def _all(ary):
189186
return np.int8(all([ary]))
190187
return ary.all(queue=queue)
191188

189+
import pyopencl.array as cl_array
192190
return rec_map_reduce_array_container(
193191
partial(reduce, partial(cl_array.minimum, queue=queue)),
194192
_all,
@@ -202,6 +200,7 @@ def _any(ary):
202200
return np.int8(any([ary]))
203201
return ary.any(queue=queue)
204202

203+
import pyopencl.array as cl_array
205204
return rec_map_reduce_array_container(
206205
partial(reduce, partial(cl_array.maximum, queue=queue)),
207206
_any,
@@ -215,6 +214,8 @@ def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
215214
true_ary = actx.from_numpy(np.int8(True))
216215
false_ary = actx.from_numpy(np.int8(False))
217216

217+
import pyopencl.array as cl_array
218+
218219
def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array:
219220
if type(x) is not type(y):
220221
return false_ary
@@ -270,12 +271,15 @@ def not_equal(self, x, y):
270271
return rec_multimap_array_container(operator.ne, x, y)
271272

272273
def logical_or(self, x, y):
274+
import pyopencl.array as cl_array
273275
return rec_multimap_array_container(cl_array.logical_or, x, y)
274276

275277
def logical_and(self, x, y):
278+
import pyopencl.array as cl_array
276279
return rec_multimap_array_container(cl_array.logical_and, x, y)
277280

278281
def logical_not(self, x):
282+
import pyopencl.array as cl_array
279283
return rec_map_array_container(cl_array.logical_not, x)
280284

281285
# }}}
@@ -290,11 +294,13 @@ def _rec_sum(ary):
290294
if axis not in [None, tuple(range(ary.ndim))]:
291295
raise NotImplementedError(f"Sum over '{axis}' axes not supported.")
292296

297+
import pyopencl.array as cl_array
293298
return cl_array.sum(ary, dtype=dtype, queue=self._array_context.queue)
294299

295300
return rec_map_reduce_array_container(sum, _rec_sum, a)
296301

297302
def maximum(self, x, y):
303+
import pyopencl.array as cl_array
298304
return rec_multimap_array_container(
299305
partial(cl_array.maximum, queue=self._array_context.queue),
300306
x, y)
@@ -308,8 +314,10 @@ def amax(self, a, axis=None):
308314
def _rec_max(ary):
309315
if axis not in [None, tuple(range(ary.ndim))]:
310316
raise NotImplementedError(f"Max. over '{axis}' axes not supported.")
317+
import pyopencl.array as cl_array
311318
return cl_array.max(ary, queue=queue)
312319

320+
import pyopencl.array as cl_array
313321
return rec_map_reduce_array_container(
314322
partial(reduce, partial(cl_array.maximum, queue=queue)),
315323
_rec_max,
@@ -318,6 +326,7 @@ def _rec_max(ary):
318326
max = amax
319327

320328
def minimum(self, x, y):
329+
import pyopencl.array as cl_array
321330
return rec_multimap_array_container(
322331
partial(cl_array.minimum, queue=self._array_context.queue),
323332
x, y)
@@ -331,8 +340,10 @@ def amin(self, a, axis=None):
331340
def _rec_min(ary):
332341
if axis not in [None, tuple(range(ary.ndim))]:
333342
raise NotImplementedError(f"Min. over '{axis}' axes not supported.")
343+
import pyopencl.array as cl_array
334344
return cl_array.min(ary, queue=queue)
335345

346+
import pyopencl.array as cl_array
336347
return rec_map_reduce_array_container(
337348
partial(reduce, partial(cl_array.minimum, queue=queue)),
338349
_rec_min,
@@ -351,6 +362,7 @@ def where(self, criterion, then, else_):
351362
def where_inner(inner_crit, inner_then, inner_else):
352363
if isinstance(inner_crit, bool | np.bool_):
353364
return inner_then if inner_crit else inner_else
365+
import pyopencl.array as cl_array
354366
return cl_array.if_positive(inner_crit != 0, inner_then, inner_else,
355367
queue=self._array_context.queue)
356368

0 commit comments

Comments
 (0)