@@ -59,6 +59,8 @@
 
 import numpy as np
 
+from pytensor.tensor.rewriting.basic import register_specialize
+
 
 try:
     import numpy.__config__  # noqa
@@ -79,12 +81,12 @@
 )
 from pytensor.graph.rewriting.db import SequenceDB
 from pytensor.graph.utils import InconsistencyError
-from pytensor.printing import debugprint
 from pytensor.tensor import basic as at
 from pytensor.tensor.blas import (
     Dot22,
     _dot22,
     _dot22scalar,
+    batched_dot,
     gemm_inplace,
     gemm_no_inplace,
     gemv_inplace,
@@ -94,7 +96,7 @@
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.math import Dot, add, mul, neg, sub
+from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub
 from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
 from pytensor.tensor.type import (
     DenseTensorType,
@@ -899,9 +901,32 @@ def local_dot22_to_dot22scalar(fgraph, node):
     )
 
 
-# from opt import register_specialize, register_canonicalize
-# @register_specialize
-@node_rewriter([sub, add])
-def local_print_as_we_go_along(fgraph, node):
-    if node.op in (sub, add):
-        debugprint(node)
+@register_specialize
+@node_rewriter([_matrix_matrix_matmul])
+def specialize_matmul_to_batched_dot(fgraph, node):
+    """Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot.
+
+    TODO: Do the same for Blockwise BatchedDot
+    """
+    x, y = node.inputs
+
+    # BatchedDot does not allow implicit broadcasting of the batch dimensions
+    # We do not want to explicitly broadcast as it may result in huge arrays
+    if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]:
+        return None
+
+    x_shape = tuple(x.shape)
+    y_shape = tuple(y.shape)
+    if len(x_shape) > 3:
+        # If we have more than one batch dim, ravel it
+        x = x.reshape((-1, x_shape[-2], x_shape[-1]))
+        y = y.reshape((-1, y_shape[-2], y_shape[-1]))
+
+    new_out = batched_dot(x, y)
+
+    if len(x_shape) > 3:
+        # And then unravel it
+        new_out = new_out.reshape((*x_shape[:-2], x_shape[-2], y_shape[-1]))
+
+    copy_stack_trace(node.outputs, [new_out])
+    return [new_out]
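
For context, a minimal usage sketch (not part of this diff) of how the new specialization is intended to fire. It assumes PyTensor's public API (pytensor.tensor.tensor3, pt.matmul, pytensor.function, pytensor.dprint) and that the "specialize" rewrites run in the default compilation mode:

import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")  # shape (batch, n, k)
y = pt.tensor3("y")  # shape (batch, k, m); same non-broadcastable batch dim as x

# matmul on two 3D inputs creates the Blockwise matrix-matrix product the rewrite targets
z = pt.matmul(x, y)

f = pytensor.function([x, y], z)
# After specialization the Blockwise node should be replaced by BatchedDot
# (possibly further specialized to an inplace/BLAS-backed variant).
pytensor.dprint(f)

Because both inputs share identical, non-broadcastable batch dimensions, the broadcastable check in the rewrite passes and the graph can be handed to BatchedDot directly, without materializing any broadcasted batch arrays.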