File tree Expand file tree Collapse file tree 5 files changed +15
-15
lines changed Expand file tree Collapse file tree 5 files changed +15
-15
lines changed Original file line number Diff line number Diff line change 34
34
from aesara .scalar import UnaryScalarOp , upgrade_to_float_no_complex
35
35
from aesara .tensor import gammaln
36
36
from aesara .tensor .elemwise import Elemwise
37
- from aesara .tensor .slinalg import Cholesky
38
- from aesara .tensor .slinalg import solve_lower_triangular as solve_lower
39
- from aesara .tensor .slinalg import solve_upper_triangular as solve_upper
37
+ from aesara .tensor .slinalg import Cholesky , SolveTriangular
40
38
41
39
from pymc .aesaraf import floatX
42
40
from pymc .distributions .shape_utils import to_tuple
43
41
42
+ solve_lower = SolveTriangular (lower = True )
43
+ solve_upper = SolveTriangular (lower = False )
44
+
44
45
f = floatX
45
46
c = - 0.5 * np .log (2.0 * np .pi )
46
47
_beta_clip_values = {
Original file line number Diff line number Diff line change 33
33
from aesara .tensor .random .basic import dirichlet , multinomial , multivariate_normal
34
34
from aesara .tensor .random .op import RandomVariable , default_supp_shape_from_params
35
35
from aesara .tensor .random .utils import broadcast_params , normalize_size_param
36
- from aesara .tensor .slinalg import Cholesky
37
- from aesara .tensor .slinalg import solve_lower_triangular as solve_lower
38
- from aesara .tensor .slinalg import solve_upper_triangular as solve_upper
36
+ from aesara .tensor .slinalg import Cholesky , SolveTriangular
39
37
from aesara .tensor .type import TensorType
40
38
from scipy import linalg , stats
41
39
79
77
"StickBreakingWeights" ,
80
78
]
81
79
80
+ solve_lower = SolveTriangular (lower = True )
81
+ solve_upper = SolveTriangular (lower = False )
82
+
82
83
83
84
class SimplexContinuous (Continuous ):
84
85
"""Base class for simplex continuous distributions"""
Original file line number Diff line number Diff line change 19
19
20
20
from aesara .compile import SharedVariable
21
21
from aesara .tensor .slinalg import ( # noqa: W0611; pylint: disable=unused-import
22
+ SolveTriangular ,
22
23
cholesky ,
23
24
solve ,
24
25
)
25
- from aesara .tensor .slinalg import ( # noqa: W0611; pylint: disable=unused-import
26
- solve_lower_triangular as solve_lower ,
27
- )
28
- from aesara .tensor .slinalg import ( # noqa: W0611; pylint: disable=unused-import
29
- solve_upper_triangular as solve_upper ,
30
- )
31
26
from aesara .tensor .var import TensorConstant
32
27
from scipy .cluster .vq import kmeans
33
28
41
36
42
37
JITTER_DEFAULT = 1e-6
43
38
39
+ solve_lower = SolveTriangular (lower = True )
40
+ solve_upper = SolveTriangular (lower = False )
41
+
44
42
45
43
def replace_with_values (vars_needed , replacements = None , model = None ):
46
44
R"""
Original file line number Diff line number Diff line change @@ -230,8 +230,8 @@ def kron_vector_op(v):
230
230
231
231
# Define kronecker functions that work on 1D and 2D arrays
232
232
kron_dot = partial (kron_matrix_op , op = at .dot )
233
- kron_solve_lower = partial (kron_matrix_op , op = at .slinalg .solve_lower_triangular )
234
- kron_solve_upper = partial (kron_matrix_op , op = at .slinalg .solve_upper_triangular )
233
+ kron_solve_lower = partial (kron_matrix_op , op = at .slinalg .SolveTriangular ( lower = True ) )
234
+ kron_solve_upper = partial (kron_matrix_op , op = at .slinalg .SolveTriangular ( lower = False ) )
235
235
236
236
237
237
def flat_outer (a , b ):
Original file line number Diff line number Diff line change @@ -116,7 +116,7 @@ def test_kron_solve_lower():
116
116
x = np .random .rand (tot_size ).reshape ((tot_size , 1 ))
117
117
# Construct entire kronecker product then solve
118
118
big = kronecker (* Ls )
119
- slow_ans = at .slinalg .solve_lower_triangular (big , x )
119
+ slow_ans = at .slinalg .solve_triangular (big , x , lower = True )
120
120
# Use tricks to avoid construction of entire kronecker product
121
121
fast_ans = kron_solve_lower (Ls , x )
122
122
np .testing .assert_array_almost_equal (slow_ans .eval (), fast_ans .eval ())
You can’t perform that action at this time.
0 commit comments