1
1
import numpy as np
2
2
import pymc as pm
3
3
import pytensor .tensor as pt
4
- import pytest
5
4
6
5
from pymc_experimental .gp .pytensor_gp import GP , ExpQuad
7
6
8
7
9
8
def test_exp_quad ():
10
9
x = pt .arange (3 )[:, None ]
11
10
ls = pt .ones (())
12
- cov = ExpQuad . build_covariance (x , ls ).eval ()
11
+ cov = ExpQuad (x , ls = ls ).eval ()
13
12
expected_distance = np .array ([[0.0 , 1.0 , 4.0 ], [1.0 , 0.0 , 1.0 ], [4.0 , 1.0 , 0.0 ]])
14
13
15
14
np .testing .assert_allclose (cov , np .exp (- 0.5 * expected_distance ))
16
15
17
16
18
- @pytest .fixture (scope = "session" )
19
- def marginal_model ():
17
+ # @pytest.fixture(scope="session")
18
+ def latent_model ():
20
19
with pm .Model () as m :
21
20
X = pm .Data ("X" , np .arange (3 )[:, None ])
22
21
y = np .full (3 , np .pi )
23
22
ls = 1.0
24
- cov = ExpQuad (X , ls )
23
+ cov = ExpQuad (X , ls = ls )
25
24
gp = GP ("gp" , cov = cov )
26
25
27
26
sigma = 1.0
@@ -30,31 +29,147 @@ def marginal_model():
30
29
return m
31
30
32
31
33
- def test_marginal_sigma_rewrites_to_white_noise_cov (marginal_model ):
34
- obs = marginal_model ["obs" ]
32
+ def latent_model_old_API ():
33
+ with pm .Model () as m :
34
+ X = pm .Data ("X" , np .arange (3 )[:, None ])
35
+ y = np .full (3 , np .pi )
36
+ ls = 1.0
37
+ cov = pm .gp .cov .ExpQuad (1 , ls )
38
+ gp_class = pm .gp .Latent (cov_func = cov )
39
+ gp = gp_class .prior ("gp" , X , reparameterize = False )
40
+
41
+ sigma = 1.0
42
+ obs = pm .Normal ("obs" , mu = gp , sigma = sigma , observed = y )
35
43
36
- # TODO: Bring these checks back after we implement marginalization of the GP RV
37
- #
38
- # assert sum(isinstance(var.owner.op, pm.Normal.rv_type)
39
- # for var in ancestors([obs])
40
- # if var.owner is not None) == 1
41
- #
42
- f = pm .compile_pymc ([], obs )
43
- #
44
- # assert not any(isinstance(node.op, pm.Normal.rv_type) for node in f.maker.fgraph.apply_nodes)
44
+ return m , gp_class
45
45
46
- draws = np .stack ([f () for _ in range (10_000 )])
47
- empirical_cov = np .cov (draws .T )
48
46
49
- expected_distance = np .array ([[0.0 , 1.0 , 4.0 ], [1.0 , 0.0 , 1.0 ], [4.0 , 1.0 , 0.0 ]])
47
+ def test_latent_model_prior ():
48
+ m = latent_model ()
49
+ ref_m , _ = latent_model_old_API ()
50
+
51
+ prior = pm .draw (m ["gp" ], draws = 1000 )
52
+ prior_ref = pm .draw (ref_m ["gp" ], draws = 1000 )
53
+
54
+ np .testing .assert_allclose (
55
+ prior .mean (),
56
+ prior_ref .mean (),
57
+ atol = 0.1 ,
58
+ )
59
+
60
+ np .testing .assert_allclose (
61
+ prior .std (),
62
+ prior_ref .std (),
63
+ rtol = 0.1 ,
64
+ )
65
+
66
+
67
+ def test_latent_model_logp ():
68
+ m = latent_model ()
69
+ ip = m .initial_point ()
70
+
71
+ ref_m , _ = latent_model_old_API ()
72
+
73
+ np .testing .assert_allclose (
74
+ m .compile_logp ()(ip ),
75
+ ref_m .compile_logp ()(ip ),
76
+ rtol = 1e-6 ,
77
+ )
78
+
79
+
80
+ import arviz as az
81
+
82
+
83
+ def gp_conditional (model , gp , Xnew , jitter = 1e-6 ):
84
+ def _build_conditional (self , Xnew , f , cov , jitter ):
85
+ X , ls = cov .owner .inputs
86
+
87
+ Kxx = cov
88
+ Kxs = cov .owner .op .build_covariance (X , Xnew , ls = ls )
89
+ Kss = cov .owner .op .build_covariance (Xnew , ls = ls )
90
+
91
+ L = pt .linalg .cholesky (Kxx + pt .eye (X .shape [0 ]) * jitter )
92
+ # TODO: Use cho_solve
93
+ A = pt .linalg .solve_triangular (L , Kxs , lower = True )
94
+ v = pt .linalg .solve_triangular (L , f , lower = True )
95
+
96
+ mu = (A .mT @ v ).T # Vector?
97
+ cov = Kss - (A .mT @ A )
98
+
99
+ return mu , cov
100
+
101
+ with model .copy () as new_m :
102
+ gp = new_m [gp .name ]
103
+ _ , cov = gp .owner .op .dist_params (gp .owner )
104
+ mu_star , cov_star = _build_conditional (None , Xnew , gp , cov , jitter )
105
+ gp_star = pm .MvNormal ("gp_star" , mu_star , cov_star )
106
+ return new_m
107
+
108
+
109
+ def test_latent_model_predict_new_x ():
110
+ rng = np .random .default_rng (0 )
111
+ new_x = np .array ([3 , 4 ])[:, None ]
112
+
113
+ m = latent_model ()
114
+ ref_m , ref_gp_class = latent_model_old_API ()
115
+
116
+ posterior_idata = az .from_dict ({"gp" : rng .normal (np .pi , 1e-3 , size = (4 , 1000 , 2 ))})
117
+
118
+ # with gp_extend_to_new_x(m):
119
+ with gp_conditional (m , m ["gp" ], new_x ):
120
+ pred = (
121
+ pm .sample_posterior_predictive (posterior_idata , var_names = ["gp_star" ])
122
+ .posterior_predictiev ["gp" ]
123
+ .values
124
+ )
125
+
126
+ with ref_m :
127
+ gp_star = ref_gp_class .conditional ("gp_star" , Xnew = new_x )
128
+ pred_ref = (
129
+ pm .sample_posterior_predictive (posterior_idata , var_names = ["gp_star" ])
130
+ .posterior_predictive ["gp" ]
131
+ .values
132
+ )
133
+
134
+ np .testing .assert_allclose (
135
+ pred .mean (),
136
+ pred_ref .mean (),
137
+ atol = 0.1 ,
138
+ )
50
139
51
140
np .testing .assert_allclose (
52
- empirical_cov , np .exp (- 0.5 * expected_distance ) + np .eye (3 ), atol = 0.1 , rtol = 0.1
141
+ pred .std (),
142
+ pred_ref .std (),
143
+ rtol = 0.1 ,
53
144
)
54
145
55
146
56
- def test_marginal_gp_logp (marginal_model ):
57
- expected_logps = {"obs" : - 8.8778 }
58
- point_logps = marginal_model .point_logps (round_vals = 4 )
59
- for v1 , v2 in zip (point_logps .values (), expected_logps .values ()):
60
- np .testing .assert_allclose (v1 , v2 , atol = 1e-6 )
147
+ #
148
+ # def test_marginal_sigma_rewrites_to_white_noise_cov(marginal_model, ):
149
+ # obs = marginal_model["obs"]
150
+ #
151
+ # # TODO: Bring these checks back after we implement marginalization of the GP RV
152
+ # #
153
+ # # assert sum(isinstance(var.owner.op, pm.Normal.rv_type)
154
+ # # for var in ancestors([obs])
155
+ # # if var.owner is not None) == 1
156
+ # #
157
+ # f = pm.compile_pymc([], obs)
158
+ # #
159
+ # # assert not any(isinstance(node.op, pm.Normal.rv_type) for node in f.maker.fgraph.apply_nodes)
160
+ #
161
+ # draws = np.stack([f() for _ in range(10_000)])
162
+ # empirical_cov = np.cov(draws.T)
163
+ #
164
+ # expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
165
+ #
166
+ # np.testing.assert_allclose(
167
+ # empirical_cov, np.exp(-0.5 * expected_distance) + np.eye(3), atol=0.1, rtol=0.1
168
+ # )
169
+ #
170
+ #
171
+ # def test_marginal_gp_logp(marginal_model):
172
+ # expected_logps = {"obs": -8.8778}
173
+ # point_logps = marginal_model.point_logps(round_vals=4)
174
+ # for v1, v2 in zip(point_logps.values(), expected_logps.values()):
175
+ # np.testing.assert_allclose(v1, v2, atol=1e-6)
0 commit comments