|
7 | 7 | import numpy as np |
8 | 8 | import scipy.linalg as scipy_linalg |
9 | 9 | from numpy.exceptions import ComplexWarning |
10 | | -from scipy.linalg._misc import _datacopied |
11 | 10 |
|
12 | 11 | import pytensor |
13 | 12 | import pytensor.tensor as pt |
@@ -65,63 +64,55 @@ def make_node(self, x): |
65 | 64 | dtype = scipy_linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype |
66 | 65 | return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)]) |
67 | 66 |
|
68 | | - def _cholesky( |
69 | | - self, a, lower=False, overwrite_a=False, clean=True, check_finite=False |
70 | | - ): |
71 | | - a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a) |
| 67 | + def perform(self, node, inputs, outputs): |
| 68 | + [x] = inputs |
| 69 | + [out] = outputs |
| 70 | + |
| 71 | + # Quick return for square empty array |
| 72 | + if x.size == 0: |
| 73 | + eye = np.eye(1, dtype=x.dtype) |
| 74 | + (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (eye,)) |
| 75 | + c, _ = potrf(eye, lower=False, overwrite_a=False, clean=True) |
| 76 | + out[0] = np.empty_like(x, dtype=c.dtype) |
| 77 | + return |
| 78 | + |
| 79 | + x1 = np.asarray_chkfinite(x) if self.check_finite else x |
72 | 80 |
|
73 | 81 | # Squareness check |
74 | | - if a1.shape[0] != a1.shape[1]: |
| 82 | + if x1.shape[0] != x1.shape[1]: |
75 | 83 | raise ValueError( |
76 | 84 | "Input array is expected to be square but has " |
77 | | - f"the shape: {a1.shape}." |
| 85 | + f"the shape: {x1.shape}." |
78 | 86 | ) |
79 | 87 |
|
80 | | - # Quick return for square empty array |
81 | | - if a1.size == 0: |
82 | | - dt = self._cholesky(np.eye(1, dtype=a1.dtype)).dtype |
83 | | - return np.empty_like(a1, dtype=dt), lower |
84 | | - |
85 | | - overwrite_a = overwrite_a or _datacopied(a1, a) |
86 | | - (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (a1,)) |
87 | | - c, info = potrf(a1, lower=lower, overwrite_a=overwrite_a, clean=clean) |
88 | | - if info > 0: |
89 | | - raise scipy_linalg.LinAlgError( |
90 | | - f"{info}-th leading minor of the array is not positive definite" |
91 | | - ) |
92 | | - if info < 0: |
93 | | - raise ValueError( |
94 | | - f"LAPACK reported an illegal value in {-info}-th argument " |
95 | | - f'on entry to "POTRF".' |
96 | | - ) |
97 | | - return c |
| 88 | + # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS |
| 89 | + # If we have a `C_CONTIGUOUS` array we transpose to benefit from it |
| 90 | + if self.overwrite_a and x.flags["C_CONTIGUOUS"]: |
| 91 | + x1 = x1.T |
| 92 | + lower = not self.lower |
| 93 | + overwrite_a = True |
| 94 | + else: |
| 95 | + lower = self.lower |
| 96 | + overwrite_a = self.overwrite_a |
98 | 97 |
|
99 | | - def perform(self, node, inputs, outputs): |
100 | | - [x] = inputs |
101 | | - [out] = outputs |
102 | | - try: |
103 | | - # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS |
104 | | - # If we have a `C_CONTIGUOUS` array we transpose to benefit from it |
105 | | - if self.overwrite_a and x.flags["C_CONTIGUOUS"]: |
106 | | - out[0] = self._cholesky( |
107 | | - x.T, |
108 | | - lower=not self.lower, |
109 | | - check_finite=self.check_finite, |
110 | | - overwrite_a=True, |
111 | | - ).T |
112 | | - else: |
113 | | - out[0] = self._cholesky( |
114 | | - x, |
115 | | - lower=self.lower, |
116 | | - check_finite=self.check_finite, |
117 | | - overwrite_a=self.overwrite_a, |
118 | | - ) |
| 98 | + (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x1,)) |
| 99 | + c, info = potrf(x1, lower=lower, overwrite_a=overwrite_a, clean=True) |
119 | 100 |
|
120 | | - except scipy_linalg.LinAlgError: |
121 | | - if self.on_error == "raise": |
122 | | - raise |
123 | | - else: |
| 101 | + if info != 0: |
| 102 | + if self.on_error == "nan": |
124 | 103 | out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype) |
| 104 | + elif info > 0: |
| 105 | + raise scipy_linalg.LinAlgError( |
| 106 | + f"{info}-th leading minor of the array is not positive definite" |
| 107 | + ) |
| 108 | + elif info < 0: |
| 109 | + raise ValueError( |
| 110 | + f"LAPACK reported an illegal value in {-info}-th argument " |
| 111 | + f'on entry to "POTRF".' |
| 112 | + ) |
| 113 | + else: |
| 114 | + # Transpose result if input was transposed |
| 115 | + out[0] = c.T if (self.overwrite_a and x.flags["C_CONTIGUOUS"]) else c |
125 | 116 |
|
126 | 117 | def L_op(self, inputs, outputs, gradients): |
127 | 118 | """ |
|
0 commit comments