Skip to content

Commit 2c600d9

Browse files
committed
Remove array and potrf copies
1 parent 1af5b61 commit 2c600d9

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

pytensor/tensor/slinalg.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,35 +68,38 @@ def perform(self, node, inputs, outputs):
6868
[x] = inputs
6969
[out] = outputs
7070

71+
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x,))
72+
7173
# Quick return for square empty array
7274
if x.size == 0:
73-
eye = np.eye(1, dtype=x.dtype)
74-
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (eye,))
75-
c, _ = potrf(eye, lower=False, overwrite_a=False, clean=True)
76-
out[0] = np.empty_like(x, dtype=c.dtype)
75+
out[0] = np.empty_like(x, dtype=potrf.dtype)
7776
return
7877

79-
x1 = np.asarray_chkfinite(x) if self.check_finite else x
78+
if self.check_finite and not np.isfinite(x).all():
79+
if self.on_error == "nan":
80+
out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype)
81+
return
82+
else:
83+
raise ValueError("array must not contain infs or NaNs")
8084

8185
# Squareness check
82-
if x1.shape[0] != x1.shape[1]:
86+
if x.shape[0] != x.shape[1]:
8387
raise ValueError(
84-
"Input array is expected to be square but has "
85-
f"the shape: {x1.shape}."
88+
"Input array is expected to be square but has " f"the shape: {x.shape}."
8689
)
8790

8891
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
8992
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
90-
if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
91-
x1 = x1.T
93+
c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"]
94+
if c_contiguous_input:
95+
x = x.T
9296
lower = not self.lower
9397
overwrite_a = True
9498
else:
9599
lower = self.lower
96100
overwrite_a = self.overwrite_a
97101

98-
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x1,))
99-
c, info = potrf(x1, lower=lower, overwrite_a=overwrite_a, clean=True)
102+
c, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True)
100103

101104
if info != 0:
102105
if self.on_error == "nan":
@@ -112,7 +115,7 @@ def perform(self, node, inputs, outputs):
112115
)
113116
else:
114117
# Transpose result if input was transposed
115-
out[0] = c.T if (self.overwrite_a and x.flags["C_CONTIGUOUS"]) else c
118+
out[0] = c.T if c_contiguous_input else c
116119

117120
def L_op(self, inputs, outputs, gradients):
118121
"""

0 commit comments

Comments
 (0)