|
1 | | -from collections.abc import Sequence |
| 1 | +from collections.abc import Callable, Sequence |
2 | 2 | from typing import Any, cast |
3 | 3 |
|
4 | 4 | import numpy as np |
| 5 | +from numpy import broadcast_shapes, empty |
5 | 6 |
|
6 | 7 | from pytensor import config |
7 | 8 | from pytensor.compile.builders import OpFromGraph |
|
22 | 23 | from pytensor.tensor.utils import ( |
23 | 24 | _parse_gufunc_signature, |
24 | 25 | broadcast_static_dim_lengths, |
| 26 | + faster_broadcast_to, |
| 27 | + faster_ndindex, |
25 | 28 | import_func_from_string, |
26 | 29 | safe_signature, |
27 | 30 | ) |
28 | 31 | from pytensor.tensor.variable import TensorVariable |
29 | 32 |
|
30 | 33 |
|
| 34 | +def _vectorize_node_perform( |
| 35 | + core_node, batch_bcast_patterns, batch_ndim: int, impl: str | None = None |
| 36 | +) -> Callable: |
| 37 | + """self, |
| 38 | + node: Apply, |
| 39 | + storage_map: StorageMapType, |
| 40 | + compute_map: ComputeMapType, |
| 41 | + no_recycling: list[Variable], |
| 42 | + impl: str | None = None,""" |
| 43 | + |
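| | + # Single-element storage cells shared with the core thunk: the thunk reads |
| | + # its inputs from and writes its outputs to these cells on every call. |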
| 44 | + storage_map = {var: [None] for var in core_node.inputs + core_node.outputs} |
| 45 | + core_thunk = core_node.op.make_thunk(core_node, storage_map, None, [], impl=impl) |
| 46 | + single_in = len(core_node.inputs) == 1 |
| 47 | + core_input_storage = [storage_map[inp] for inp in core_node.inputs] |
| 48 | + core_output_storage = [storage_map[out] for out in core_node.outputs] |
| 49 | + core_storage = core_input_storage + core_output_storage |
| 50 | + |
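| | + # Bind everything the hot loop needs as default arguments so the closure |
| | + # avoids repeated attribute and global lookups at call time. |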
| 51 | + def vectorized_perform( |
| 52 | + *args, |
| 53 | + batch_bcast_patterns=batch_bcast_patterns, |
| 54 | + batch_ndim=batch_ndim, |
| 55 | + single_in=single_in, |
| 56 | + core_thunk=core_thunk, |
| 57 | + core_input_storage=core_input_storage, |
| 58 | + core_output_storage=core_output_storage, |
| 59 | + core_storage=core_storage, |
| 60 | + ): |
| 61 | + if single_in: |
| 62 | + batch_shape = args[0].shape[:batch_ndim] |
| 63 | + else: |
| 64 | + _check_runtime_broadcast(args, batch_bcast_patterns, batch_ndim) |
| 65 | + batch_shape = broadcast_shapes(*(arg.shape[:batch_ndim] for arg in args)) |
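| | + # Pre-broadcast any input whose batch shape differs, so plain indexing |
| | + # with the batch index works inside the loops below. |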
| 66 | + args = list(args) |
| 67 | + for i, arg in enumerate(args): |
| 68 | + if arg.shape[:batch_ndim] != batch_shape: |
| 69 | + args[i] = faster_broadcast_to( |
| 70 | + arg, batch_shape + arg.shape[batch_ndim:] |
| 71 | + ) |
| 72 | + |
| 73 | + ndindex_iterator = faster_ndindex(batch_shape) |
| 74 | + # Call once to get the output shapes |
| 75 | + try: |
| 76 | + # TODO: Pass core shape as input like BlockwiseWithCoreShape does? |
| 77 | + index0 = next(ndindex_iterator) |
| 78 | + except StopIteration: |
| 79 | + raise NotImplementedError("vectorize with zero size not implemented") |
| 80 | + else: |
| 81 | + for core_input, arg in zip(core_input_storage, args): |
| 82 | + core_input[0] = np.asarray(arg[index0]) |
| 83 | + core_thunk() |
| 84 | + outputs = tuple( |
| 85 | + empty(batch_shape + core_output[0].shape, dtype=core_output[0].dtype) |
| 86 | + for core_output in core_output_storage |
| 87 | + ) |
| 88 | + for output, core_output in zip(outputs, core_output_storage): |
| 89 | + output[index0] = core_output[0] |
| 90 | + |
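| | + # Remaining iterations reuse the thunk and write into the preallocated outputs |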
| 91 | + for index in ndindex_iterator: |
| 92 | + for core_input, arg in zip(core_input_storage, args): |
| 93 | + core_input[0] = np.asarray(arg[index]) |
| 94 | + core_thunk() |
| 95 | + for output, core_output in zip(outputs, core_output_storage): |
| 96 | + output[index] = core_output[0] |
| 97 | + |
| 98 | + # Clear storage |
| 99 | + for core_val in core_storage: |
| 100 | + core_val[0] = None |
| 101 | + return outputs |
| 102 | + |
| 103 | + return vectorized_perform |
| 104 | + |
| 105 | + |
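| | +# Shared by the vectorized perform closure above and Blockwise._check_runtime_broadcast below. |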
| 106 | +def _check_runtime_broadcast(numerical_inputs, batch_bcast_patterns, batch_ndim): |
| 107 | + # zip without strict=True because we are in a hot loop |
| 108 | + # We zip together the dimension lengths of each input and their broadcast patterns |
| 109 | + for dim_lengths_and_bcast in zip( |
| 110 | + *[ |
| 111 | + zip(input.shape[:batch_ndim], batch_bcast_pattern) |
| 112 | + for input, batch_bcast_pattern in zip( |
| 113 | + numerical_inputs, batch_bcast_patterns |
| 114 | + ) |
| 115 | + ], |
| 116 | + ): |
| 117 | + # If along any batch dimension one input has dim_length != 1 while another |
| 118 | + # has dim_length == 1 and broadcastable=False, we have runtime broadcasting. |
| 119 | + if ( |
| 120 | + any(d != 1 for d, _ in dim_lengths_and_bcast) |
| 121 | + and (1, False) in dim_lengths_and_bcast |
| 122 | + ): |
| 123 | + raise ValueError( |
| 124 | + "Runtime broadcasting not allowed. " |
| 125 | + "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n" |
| 126 | + "If broadcasting was intended, use `specify_broadcastable` on the relevant input." |
| 127 | + ) |
| 128 | + |
| 129 | + |
31 | 130 | class Blockwise(Op): |
32 | 131 | """Generalizes a core `Op` to work with batched dimensions. |
33 | 132 |
|
@@ -308,91 +407,62 @@ def L_op(self, inputs, outs, ograds): |
308 | 407 |
|
309 | 408 | return rval |
310 | 409 |
|
311 | | - def _create_node_gufunc(self, node) -> None: |
| 410 | + def _create_node_gufunc(self, node: Apply, impl) -> Callable: |
312 | 411 | """Define (or retrieve) the node gufunc used in `perform`. |
313 | 412 |
|
314 | 413 | If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly. |
315 | 414 | Otherwise, we default to `np.vectorize` of the core_op `perform` method for a dummy node. |
316 | 415 |
|
317 | 416 | The gufunc is stored in the tag of the node. |
318 | 417 | """ |
319 | | - gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None) |
320 | | - |
321 | | - if gufunc_spec is not None: |
322 | | - gufunc = import_func_from_string(gufunc_spec[0]) |
323 | | - if gufunc is None: |
| 418 | + batch_ndim = self.batch_ndim(node) |
| 419 | + batch_bcast_patterns = [ |
| 420 | + inp.type.broadcastable[:batch_ndim] for inp in node.inputs |
| 421 | + ] |
| 422 | + if ( |
| 423 | + gufunc_spec := self.gufunc_spec |
| 424 | + or getattr(self.core_op, "gufunc_spec", None) |
| 425 | + ) is not None: |
| 426 | + core_func = import_func_from_string(gufunc_spec[0]) |
| 427 | + if core_func is None: |
324 | 428 | raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}") |
325 | 429 |
|
326 | | - else: |
327 | | - # Wrap core_op perform method in numpy vectorize |
328 | | - n_outs = len(self.outputs_sig) |
329 | | - core_node = self._create_dummy_core_node(node.inputs) |
330 | | - inner_outputs_storage = [[None] for _ in range(n_outs)] |
331 | | - |
332 | | - def core_func( |
333 | | - *inner_inputs, |
334 | | - core_node=core_node, |
335 | | - inner_outputs_storage=inner_outputs_storage, |
336 | | - ): |
337 | | - self.core_op.perform( |
338 | | - core_node, |
339 | | - [np.asarray(inp) for inp in inner_inputs], |
340 | | - inner_outputs_storage, |
341 | | - ) |
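| | + # Single-output gufuncs return a bare array rather than a tuple, |
| | + # so wrap the result to keep the output handling in `perform` uniform. |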
| 430 | + if len(node.outputs) == 1: |
342 | 431 |
|
343 | | - if n_outs == 1: |
344 | | - return inner_outputs_storage[0][0] |
345 | | - else: |
346 | | - return tuple(r[0] for r in inner_outputs_storage) |
| 432 | + def gufunc(*inputs): |
| 433 | + _check_runtime_broadcast(inputs, batch_bcast_patterns, batch_ndim) |
| 434 | + return (core_func(*inputs),) |
| 435 | + else: |
347 | 436 |
|
348 | | - gufunc = np.vectorize(core_func, signature=self.signature) |
| 437 | + def gufunc(*inputs): |
| 438 | + _check_runtime_broadcast(inputs, batch_bcast_patterns, batch_ndim) |
| 439 | + return core_func(*inputs) |
| 440 | + else: |
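| | + # No registered gufunc: loop the core node's own thunk over the batch dims |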
| 441 | + core_node = self._create_dummy_core_node(node.inputs) # type: ignore |
| 442 | + gufunc = _vectorize_node_perform( |
| 443 | + core_node, |
| 444 | + batch_bcast_patterns=batch_bcast_patterns, |
| 445 | + batch_ndim=batch_ndim, |
| 446 | + impl=impl, |
| 447 | + ) |
349 | 448 |
|
350 | | - node.tag.gufunc = gufunc |
| 449 | + return gufunc |
351 | 450 |
|
352 | 451 | def _check_runtime_broadcast(self, node, inputs): |
353 | 452 | batch_ndim = self.batch_ndim(node) |
| 453 | + batch_bcast = [pt_inp.type.broadcastable[:batch_ndim] for pt_inp in node.inputs] |
| 454 | + _check_runtime_broadcast(inputs, batch_bcast, batch_ndim) |
354 | 455 |
|
355 | | - # strict=False because we are in a hot loop |
356 | | - for dims_and_bcast in zip( |
357 | | - *[ |
358 | | - zip( |
359 | | - input.shape[:batch_ndim], |
360 | | - sinput.type.broadcastable[:batch_ndim], |
361 | | - strict=False, |
362 | | - ) |
363 | | - for input, sinput in zip(inputs, node.inputs, strict=False) |
364 | | - ], |
365 | | - strict=False, |
366 | | - ): |
367 | | - if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast: |
368 | | - raise ValueError( |
369 | | - "Runtime broadcasting not allowed. " |
370 | | - "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n" |
371 | | - "If broadcasting was intended, use `specify_broadcastable` on the relevant input." |
372 | | - ) |
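| | + # Build and cache the node's gufunc ahead of time so `perform` does not |
| | + # pay the construction cost on its first call. |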
| 456 | + def prepare_node(self, node, storage_map, compute_map, impl=None): |
| 457 | + node.tag.gufunc = self._create_node_gufunc(node, impl=impl) |
373 | 458 |
|
374 | 459 | def perform(self, node, inputs, output_storage): |
375 | | - gufunc = getattr(node.tag, "gufunc", None) |
376 | | - |
377 | | - if gufunc is None: |
378 | | - # Cache it once per node |
379 | | - self._create_node_gufunc(node) |
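| | + # prepare_node normally caches the gufunc in node.tag; build it lazily |
| | + # here if perform is called without the node having been prepared. |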
| 460 | + try: |
380 | 461 | gufunc = node.tag.gufunc |
381 | | - |
382 | | - self._check_runtime_broadcast(node, inputs) |
383 | | - |
384 | | - res = gufunc(*inputs) |
385 | | - if not isinstance(res, tuple): |
386 | | - res = (res,) |
387 | | - |
388 | | - # strict=False because we are in a hot loop |
389 | | - for node_out, out_storage, r in zip( |
390 | | - node.outputs, output_storage, res, strict=False |
391 | | - ): |
392 | | - out_dtype = getattr(node_out, "dtype", None) |
393 | | - if out_dtype and out_dtype != r.dtype: |
394 | | - r = np.asarray(r, dtype=out_dtype) |
395 | | - out_storage[0] = r |
| 462 | + except AttributeError: |
| 463 | + gufunc = node.tag.gufunc = self._create_node_gufunc(node, impl=None) |
| 464 | + for out_storage, result in zip(output_storage, gufunc(*inputs)): |
| 465 | + out_storage[0] = result |
396 | 466 |
|
397 | 467 | def __str__(self): |
398 | 468 | if self.name is None: |
|