     stack,
     switch,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise
 from pytensor.tensor.shape import shape, specify_broadcastable
 from pytensor.tensor.type import (
     DenseTensorType,
-    TensorType,
     complex_dtypes,
     continuous_dtypes,
     discrete_dtypes,
@@ -2868,93 +2868,7 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))
 
 
-class MatMul(Op):
-    __props__ = ("dtype",)
-
-    def __init__(self, dtype=None):
-        self.dtype = dtype
-
-    @classmethod
-    def _get_output_shape(cls, x1, x2, shapes, validate=False):
-        x1_shape, x2_shape = shapes
-
-        if x1.ndim == 1 and x2.ndim == 1:
-            if validate and x1_shape[0] != x2_shape[0]:
-                raise ValueError("1d inputs must have the same length.")
-            return ()
-        elif x1.ndim == 1 and x2.ndim > 1:
-            if validate and x1_shape[0] != x2_shape[-2]:
-                raise ValueError(
-                    "length of input 1 must be equal the length "
-                    "of the 2nd-last dimension of input 2"
-                )
-            return x2_shape[:-2] + x2_shape[-1:]
-        elif x1.ndim > 1 and x2.ndim == 1:
-            if validate and x1_shape[-1] != x2_shape[0]:
-                raise ValueError(
-                    "length of input 2 must be equal the length "
-                    "of the last dimension of input 1"
-                )
-            return x1_shape[:-1]
-        elif x1.ndim == 2 and x2.ndim == 2:
-            if validate and x1_shape[-1] != x2_shape[0]:
-                raise ValueError(
-                    "number of columns of input 1 must be equal to "
-                    "the number of rows of input 2"
-                )
-            return x1_shape[:-1] + x2_shape[-1:]
-        elif x1.ndim > 2 and x2.ndim == 2:
-            if validate and x1_shape[-1] != x2_shape[0]:
-                raise ValueError(
-                    "number of rows of input 2 must be equal to "
-                    "the length of the last dimension of input 1"
-                )
-            return x1_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
-        elif x1.ndim == 2 and x2.ndim > 2:
-            if validate and x1_shape[-1] != x2_shape[-2]:
-                raise ValueError(
-                    "number of columns of input 1 must be equal "
-                    "the length of the 2nd-last dimension of input 2"
-                )
-            return x2_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
-        else:
-            if validate:
-                from pytensor.tensor.random.basic import broadcast_shapes
-
-                bshape = broadcast_shapes(x1_shape[:-2], x2_shape[:-2])
-                if x1_shape[-1] != x2_shape[-2]:
-                    raise ValueError(
-                        "length of the last dimension of input 1 must be equal "
-                        "to the length of the 2nd-last dimension of input 2"
-                    )
-            else:
-                from pytensor.tensor.extra_ops import broadcast_shape
-
-                bshape = broadcast_shape(
-                    x1_shape[:-2], x2_shape[:-2], arrays_are_shapes=True
-                )
-            return bshape + x1_shape[-2:-1] + x2_shape[-1:]
-
-    def make_node(self, a, b):
-        a = as_tensor_variable(a)
-        b = as_tensor_variable(b)
-
-        if 0 in {a.ndim, b.ndim}:
-            raise ValueError("inputs to `matmul` cannot be scalar.")
-
-        out_shape = self._get_output_shape(
-            a, b, (a.type.shape, b.type.shape), validate=True
-        )
-        out = TensorType(dtype=self.dtype, shape=out_shape)()
-        return Apply(self, [a, b], [out])
-
-    def perform(self, node, inputs, outputs):
-        x1, x2 = inputs
-        outputs[0][0] = np.matmul(x1, x2, dtype=self.dtype)
-
-    def infer_shape(self, fgraph, node, shapes):
-        x1, x2 = node.inputs
-        return [self._get_output_shape(x1, x2, shapes)]
+_matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)")
 
 
 def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
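The roughly ninety lines of hand-rolled shape inference and validation in MatMul are replaced by a single Blockwise wrapper around the core 2-D `_dot`, which vectorizes it over any leading batch dimensions in NumPy-gufunc style: the signature "(n,k),(k,m)->(n,m)" says to broadcast all but the last two axes and apply the core op once per batch element. A minimal NumPy sketch of those semantics (blockwise_dot is a hypothetical stand-in, not PyTensor's implementation; both inputs are assumed at least 2-D):

import numpy as np

def blockwise_dot(x1, x2):
    # Hypothetical sketch of Blockwise(_dot, signature="(n,k),(k,m)->(n,m)"):
    # batch (leading) dims broadcast, the core 2-D dims are consumed by np.dot.
    batch_shape = np.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
    x1 = np.broadcast_to(x1, batch_shape + x1.shape[-2:])
    x2 = np.broadcast_to(x2, batch_shape + x2.shape[-2:])
    out = np.empty(
        batch_shape + (x1.shape[-2], x2.shape[-1]),
        dtype=np.result_type(x1, x2),
    )
    for idx in np.ndindex(*batch_shape):
        out[idx] = np.dot(x1[idx], x2[idx])  # core op, once per batch element
    return out

a = np.ones((5, 1, 2, 3))
b = np.ones((7, 3, 4))
assert blockwise_dot(a, b).shape == (5, 7, 2, 4)  # batch dims broadcast to (5, 7)
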
@@ -2999,7 +2913,23 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
     - Stacks of matrices are broadcast together as if the matrices were elements,
       respecting the signature ``(n, k), (k, m) -> (n, m)``:
     """
-    return MatMul(dtype=dtype)(x1, x2)
+    x1 = as_tensor_variable(x1)
+    x2 = as_tensor_variable(x2)
+    if x1.type.ndim == 0 or x2.type.ndim == 0:
+        raise ValueError("matmul operand cannot be scalar")
+    if x1.type.ndim == 1 and x2.type.ndim == 1:
+        out = _dot(x1, x2)
+    elif x1.type.ndim == 1:
+        out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2)
+    elif x2.type.ndim == 1:
+        out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1)
+    else:
+        out = _matrix_matrix_matmul(x1, x2)
+
+    if dtype is not None:
+        out = out.astype(dtype)
+
+    return out
 
 
 __all__ = [
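With the Blockwise core in place, matmul itself reduces to dispatch: a pure vector-vector product goes through `_dot`, a 1-D operand is promoted with a dummy axis that is squeezed away after the product, and everything else is the batched matrix-matrix case. This mirrors np.matmul's documented promotion rules, which the following NumPy check illustrates (a sketch of the equivalence only, not part of the change):

import numpy as np

v = np.arange(3.0)                  # 1-D operand
m = np.arange(12.0).reshape(4, 3)   # 2-D operand

# vector @ matrix: prepend a dummy row axis, then squeeze it back out
assert np.allclose(np.matmul(v, m.T), np.matmul(v[None], m.T).squeeze(-2))

# matrix @ vector: append a dummy column axis, then squeeze it back out
assert np.allclose(np.matmul(m, v), np.matmul(m, v[:, None]).squeeze(-1))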