13
13
# limitations under the License.
14
14
15
15
16
- import aesara .tensor as at
17
16
import numpy as np
18
17
import pymc as pm
18
+ import pytensor .tensor as pt
19
19
from pymc .gp .util import JITTER_DEFAULT , cholesky , solve_lower , solve_upper , stabilize
20
20
21
21
@@ -42,7 +42,7 @@ def _build_prior(self, name, X, Xu, jitter=JITTER_DEFAULT, **kwargs):
42
42
u = pm .Deterministic (name + "_u" , L @ v )
43
43
44
44
Kfu = self .cov_func (X , Xu )
45
- Kuuiu = solve_upper (at .transpose (L ), solve_lower (L , u ))
45
+ Kuuiu = solve_upper (pt .transpose (L ), solve_lower (L , u ))
46
46
47
47
return pm .Deterministic (name , mu + Kfu @ Kuuiu ), Kuuiu , L
48
48
@@ -62,8 +62,8 @@ def prior(self, name, X, Xu=None, jitter=JITTER_DEFAULT, **kwargs):
62
62
def _build_conditional (self , name , Xnew , Xu , L , Kuuiu , jitter , ** kwargs ):
63
63
Ksu = self .cov_func (Xnew , Xu )
64
64
mu = self .mean_func (Xnew ) + Ksu @ Kuuiu
65
- tmp = solve_lower (L , at .transpose (Ksu ))
66
- Qss = at .transpose (tmp ) @ tmp # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T)
65
+ tmp = solve_lower (L , pt .transpose (Ksu ))
66
+ Qss = pt .transpose (tmp ) @ tmp # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T)
67
67
Kss = self .cov_func (Xnew )
68
68
Lss = cholesky (stabilize (Kss - Qss , jitter ))
69
69
return mu , Lss
@@ -100,42 +100,42 @@ def prior(self, name, X, **kwargs):
100
100
return f
101
101
102
102
def _generate_basis (self , X , L ):
103
- indices = at .arange (1 , self .M + 1 )
104
- m1 = (np .pi / (2.0 * L )) * at .tile (L + X , self .M )
105
- m2 = at .diag (indices )
106
- Phi = at .sin (m1 @ m2 ) / at .sqrt (L )
103
+ indices = pt .arange (1 , self .M + 1 )
104
+ m1 = (np .pi / (2.0 * L )) * pt .tile (L + X , self .M )
105
+ m2 = pt .diag (indices )
106
+ Phi = pt .sin (m1 @ m2 ) / pt .sqrt (L )
107
107
omega = (np .pi * indices ) / (2.0 * L )
108
108
return Phi , omega
109
109
110
110
def _build_prior (self , name , X , ** kwargs ):
111
111
n_obs = np .shape (X )[0 ]
112
112
113
113
# standardize input scale
114
- X = at .as_tensor_variable (X )
115
- Xmu = at .mean (X , axis = 0 )
116
- Xsd = at .std (X , axis = 0 )
114
+ X = pt .as_tensor_variable (X )
115
+ Xmu = pt .mean (X , axis = 0 )
116
+ Xsd = pt .std (X , axis = 0 )
117
117
Xz = (X - Xmu ) / Xsd
118
118
119
119
# define L using Xz and c
120
- La = at .abs (at .min (Xz )) # .eval()?
121
- Lb = at .max (Xz )
122
- L = self .c * at .max ([La , Lb ])
120
+ La = pt .abs (pt .min (Xz )) # .eval()?
121
+ Lb = pt .max (Xz )
122
+ L = self .c * pt .max ([La , Lb ])
123
123
124
124
# make basis and omega, spectral density
125
125
Phi , omega = self ._generate_basis (Xz , L )
126
126
scale , ls , spectral_density = self ._validate_cov_func (self .cov_func )
127
127
spd = scale * spectral_density (omega , ls / Xsd ).flatten ()
128
128
129
129
beta = pm .Normal (f"{ name } _coeffs_" , size = self .M )
130
- f = pm .Deterministic (name , self .mean_func (X ) + at .dot (Phi * at .sqrt (spd ), beta ))
130
+ f = pm .Deterministic (name , self .mean_func (X ) + pt .dot (Phi * pt .sqrt (spd ), beta ))
131
131
return f , Phi , L , spd , beta , Xmu , Xsd
132
132
133
133
def _build_conditional (self , Xnew , Xmu , Xsd , L , beta ):
134
134
Xnewz = (Xnew - Xmu ) / Xsd
135
135
Phi , omega = self ._generate_basis (Xnewz , L )
136
136
scale , ls , spectral_density = self ._validate_cov_func (self .cov_func )
137
137
spd = scale * spectral_density (omega , ls / Xsd ).flatten ()
138
- return self .mean_func (Xnew ) + at .dot (Phi * at .sqrt (spd ), beta )
138
+ return self .mean_func (Xnew ) + pt .dot (Phi * pt .sqrt (spd ), beta )
139
139
140
140
def conditional (self , name , Xnew ):
141
141
# warn about extrapolation
@@ -147,15 +147,15 @@ class ExpQuad(pm.gp.cov.ExpQuad):
147
147
@staticmethod
148
148
def spectral_density (omega , ls ):
149
149
# univariate spectral denisty, implement multi
150
- return at .sqrt (2 * np .pi ) * ls * at .exp (- 0.5 * ls ** 2 * omega ** 2 )
150
+ return pt .sqrt (2 * np .pi ) * ls * pt .exp (- 0.5 * ls ** 2 * omega ** 2 )
151
151
152
152
153
153
class Matern52 (pm .gp .cov .Matern52 ):
154
154
@staticmethod
155
155
def spectral_density (omega , ls ):
156
156
# univariate spectral denisty, implement multi
157
157
# https://arxiv.org/pdf/1611.06740.pdf
158
- lam = at .sqrt (5 ) * (1.0 / ls )
158
+ lam = pt .sqrt (5 ) * (1.0 / ls )
159
159
return (16.0 / 3.0 ) * lam ** 5 * (1.0 / (lam ** 2 + omega ** 2 ) ** 3 )
160
160
161
161
@@ -165,7 +165,7 @@ def spectral_density(omega, ls):
165
165
# univariate spectral denisty, implement multi
166
166
# https://arxiv.org/pdf/1611.06740.pdf
167
167
lam = np .sqrt (3.0 ) * (1.0 / ls )
168
- return 4.0 * lam ** 3 * (1.0 / at .square (lam ** 2 + omega ** 2 ))
168
+ return 4.0 * lam ** 3 * (1.0 / pt .square (lam ** 2 + omega ** 2 ))
169
169
170
170
171
171
class Matern12 (pm .gp .cov .Matern12 ):
@@ -193,7 +193,7 @@ def __init__(
193
193
def _build_prior (self , name , X , jitter = 1e-6 , ** kwargs ):
194
194
mu = self .mean_func (X )
195
195
Kxx = pm .gp .util .stabilize (self .cov_func (X ), jitter )
196
- vals , vecs = at .linalg .eigh (Kxx )
196
+ vals , vecs = pt .linalg .eigh (Kxx )
197
197
## NOTE: REMOVED PRECISION CUTOFF
198
198
if self .variance_limit is None :
199
199
n_eigs = self .n_eigs
@@ -204,7 +204,7 @@ def _build_prior(self, name, X, jitter=1e-6, **kwargs):
204
204
n_eigs = ((vals [::- 1 ].cumsum () / vals .sum ()) > self .variance_limit ).nonzero ()[0 ][0 ]
205
205
U = vecs [:, - n_eigs :]
206
206
s = vals [- n_eigs :]
207
- basis = U * at .sqrt (s )
207
+ basis = U * pt .sqrt (s )
208
208
209
209
coefs_raw = pm .Normal (f"_gp_{ name } _coefs" , mu = 0 , sigma = 1 , size = n_eigs )
210
210
# weight = pm.HalfNormal(f"_gp_{name}_sd")
@@ -222,7 +222,7 @@ def prior(self, name, X, jitter=1e-6, **kwargs):
222
222
def _build_conditional (self , Xnew , X , f , U , s , jitter ):
223
223
Kxs = self .cov_func (X , Xnew )
224
224
Kss = self .cov_func (Xnew )
225
- Kxxpinv = U @ at .diag (1.0 / s ) @ U .T
225
+ Kxxpinv = U @ pt .diag (1.0 / s ) @ U .T
226
226
mus = Kxs .T @ Kxxpinv @ f
227
227
K = Kss - Kxs .T @ Kxxpinv @ Kxs
228
228
L = pm .gp .util .cholesky (pm .gp .util .stabilize (K , jitter ))
0 commit comments