@@ -1,7 +1,8 @@
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from typing import Any, cast
 
 import numpy as np
+from numpy import broadcast_shapes, empty, ndindex, nditer
 
 from pytensor import config
 from pytensor.compile.builders import OpFromGraph
@@ -28,6 +29,67 @@
 from pytensor.tensor.variable import TensorVariable
 
 
+def _vectorize_node_perform(core_node, batch_ndim: int):
+    core_op_perform = core_node.op.perform
+    n_outs = len(core_node.outputs)
+
+    def vectorized_func(
+        *args,
+        core_node=core_node,
+        core_op_perform=core_op_perform,
+        batch_ndim=batch_ndim,
+        n_outs=n_outs,
+    ):
+        batch_shape = broadcast_shapes(*(arg.shape[:batch_ndim] for arg in args))
+        args = list(args)
+        for i, arg in enumerate(args):
+            if arg.shape[:batch_ndim] != batch_shape:
+                # Main logic of `np.broadcast_to`
+                it = nditer(
+                    (arg,),
+                    flags=["multi_index", "zerosize_ok"],
+                    op_flags=["readonly"],
+                    itershape=batch_shape + arg.shape[batch_ndim:],
+                    order="C",
+                )
+                with it:
+                    args[i] = it.itviews[0]
+
+        core_output_storage = [[None] for _ in range(n_outs)]
+        ndindex_iterator = ndindex(*batch_shape)
+        # Call once to get the output shapes
+        try:
+            # TODO: Pass core shape as input like BlockwiseWithCoreShape does?
+            index0 = next(ndindex_iterator)
+        except StopIteration:
+            raise NotImplementedError("vectorize with zero iterations not implemented")
+        else:
+            core_op_perform(
+                core_node,
+                [np.asarray(arg[index0]) for arg in args],
+                core_output_storage,
+            )
+            outputs = tuple(
+                empty(batch_shape + core_output[0].shape, dtype=core_output[0].dtype)
+                for core_output in core_output_storage
+            )
+            for output, core_output in zip(outputs, core_output_storage):  # noqa: B905
+                output[index0] = core_output[0]
+
+        for index in ndindex_iterator:
+            core_op_perform(
+                core_node,
+                [np.asarray(a[index]) for a in args],
+                core_output_storage,
+            )
+            for output, core_output in zip(outputs, core_output_storage):  # noqa: B905
+                output[index] = core_output[0]
+
+        return outputs
+
+    return vectorized_func
+
+
 class Blockwise(Op):
     """Generalizes a core `Op` to work with batched dimensions.
 
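The `nditer` block in `_vectorize_node_perform` reuses the trick behind NumPy's `np.broadcast_to`: iterating a single operand with an enlarged `itershape` yields a zero-copy broadcast view via `it.itviews[0]`. A minimal standalone sketch of that behavior, assuming plain NumPy arrays (the variable names are illustrative only):

```python
import numpy as np

arg = np.arange(3).reshape(1, 3)   # batch dimension of size 1
target_shape = (4, 3)              # broadcast batch shape + core shape

# Same pattern as in `_vectorize_node_perform`: ask nditer for the target
# itershape and take the broadcast view of the single operand.
it = np.nditer(
    (arg,),
    flags=["multi_index", "zerosize_ok"],
    op_flags=["readonly"],
    itershape=target_shape,
    order="C",
)
with it:
    view = it.itviews[0]

assert view.shape == target_shape
assert np.shares_memory(view, arg)  # no copy; the batch dimension has stride 0
np.testing.assert_array_equal(view, np.broadcast_to(arg, target_shape))
```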
@@ -308,46 +370,29 @@ def L_op(self, inputs, outs, ograds):
 
         return rval
 
-    def _create_node_gufunc(self, node) -> None:
+    def _create_node_gufunc(self, node: Apply) -> Callable:
         """Define (or retrieve) the node gufunc used in `perform`.
 
         If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly.
         Otherwise, we default to `np.vectorize` of the core_op `perform` method for a dummy node.
 
         The gufunc is stored in the tag of the node.
         """
-        gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
-
-        if gufunc_spec is not None:
+        if (
+            gufunc_spec := self.gufunc_spec
+            or getattr(self.core_op, "gufunc_spec", None)
+        ) is not None:
             gufunc = import_func_from_string(gufunc_spec[0])
             if gufunc is None:
                 raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
-
         else:
-            # Wrap core_op perform method in numpy vectorize
-            n_outs = len(self.outputs_sig)
             core_node = self._create_dummy_core_node(node.inputs)
-            inner_outputs_storage = [[None] for _ in range(n_outs)]
-
-            def core_func(
-                *inner_inputs,
-                core_node=core_node,
-                inner_outputs_storage=inner_outputs_storage,
-            ):
-                self.core_op.perform(
-                    core_node,
-                    [np.asarray(inp) for inp in inner_inputs],
-                    inner_outputs_storage,
-                )
-
-                if n_outs == 1:
-                    return inner_outputs_storage[0][0]
-                else:
-                    return tuple(r[0] for r in inner_outputs_storage)
-
-            gufunc = np.vectorize(core_func, signature=self.signature)
+            gufunc = _vectorize_node_perform(
+                core_node,
+                batch_ndim=self.batch_ndim(node),
+            )
 
-        node.tag.gufunc = gufunc
+        return gufunc
 
     def _check_runtime_broadcast(self, node, inputs):
         batch_ndim = self.batch_ndim(node)
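For intuition, the callable returned by `_vectorize_node_perform` behaves like `np.vectorize` with a gufunc signature, but it broadcasts only the leading `batch_ndim` dimensions, allocates each output once (after the first core call reveals its shape and dtype), and then fills it index by index with `np.ndindex`. A rough standalone sketch of the same looping scheme, using a hypothetical mat-vec core function (`core_matvec` and `batched_matvec` are illustrative names, not part of the codebase):

```python
import numpy as np

def core_matvec(A, x):
    # Hypothetical "core op" with signature (m,n),(n)->(m)
    return A @ x

def batched_matvec(A, x, batch_ndim=1):
    # Broadcast only the leading batch dimensions, as the vectorized perform does
    batch_shape = np.broadcast_shapes(A.shape[:batch_ndim], x.shape[:batch_ndim])
    A = np.broadcast_to(A, batch_shape + A.shape[batch_ndim:])
    x = np.broadcast_to(x, batch_shape + x.shape[batch_ndim:])

    out = None
    for index in np.ndindex(*batch_shape):
        res = core_matvec(A[index], x[index])
        if out is None:
            # The first iteration reveals the core output shape/dtype; allocate once
            out = np.empty(batch_shape + res.shape, dtype=res.dtype)
        out[index] = res
    return out

A = np.random.normal(size=(5, 3, 2))
x = np.random.normal(size=(1, 2))  # batch dim of size 1, broadcast against 5
expected = np.vectorize(core_matvec, signature="(m,n),(n)->(m)")(A, x)
np.testing.assert_allclose(batched_matvec(A, x), expected)
```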
@@ -375,23 +420,15 @@ def perform(self, node, inputs, output_storage):
         gufunc = getattr(node.tag, "gufunc", None)
 
         if gufunc is None:
-            # Cache it once per node
-            self._create_node_gufunc(node)
-            gufunc = node.tag.gufunc
+            gufunc = node.tag.gufunc = self._create_node_gufunc(node)
 
         self._check_runtime_broadcast(node, inputs)
 
         res = gufunc(*inputs)
         if not isinstance(res, tuple):
             res = (res,)
 
-        # strict=False because we are in a hot loop
-        for node_out, out_storage, r in zip(
-            node.outputs, output_storage, res, strict=False
-        ):
-            out_dtype = getattr(node_out, "dtype", None)
-            if out_dtype and out_dtype != r.dtype:
-                r = np.asarray(r, dtype=out_dtype)
+        for node_out, out_storage, r in zip(node.outputs, output_storage, res):  # noqa: B905
             out_storage[0] = r
 
     def __str__(self):
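The `perform` hunk keeps the same lazy, per-node caching: the vectorized callable is built on the first execution and memoized on `node.tag`, so later calls reuse it. A small sketch of that idiom under the same assumptions (`FakeTag`, `FakeNode`, and `build_callable` are illustrative names only):

```python
# Memoize an expensive-to-build callable on a node's tag, as `perform` does above.
class FakeTag:
    pass

class FakeNode:
    def __init__(self):
        self.tag = FakeTag()

def build_callable():
    print("compiled once")
    return lambda *inputs: inputs

node = FakeNode()

for _ in range(3):
    func = getattr(node.tag, "gufunc", None)
    if func is None:
        # Chained assignment stores the callable on the tag and keeps a local reference
        func = node.tag.gufunc = build_callable()
    func(1, 2)  # "compiled once" is printed only on the first pass
```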