12
12
import xarray
13
13
14
14
from pymc import join_nonshared_inputs , DictToArrayBijection
15
- from pymc .util import get_default_varnames
15
+ from pymc .util import get_default_varnames , RandomSeed
16
16
from pymc .backends .arviz import (
17
17
apply_function_over_dataset ,
18
18
PointFunc ,
27
27
def fit_deterministic_advi (
28
28
model : Optional [Model ] = None ,
29
29
n_fixed_draws : int = 30 ,
30
- random_seed : int = 2 ,
30
+ random_seed : RandomSeed = None ,
31
31
n_draws : int = 1000 ,
32
32
keep_untransformed : bool = False ,
33
- ):
33
+ ) -> az . InferenceData :
34
34
"""
35
35
Does inference using deterministic ADVI (automatic differentiation
36
36
variational inference).
@@ -101,7 +101,9 @@ def fit_deterministic_advi(
101
101
opt_means , opt_log_sds = np .split (opt_var_params , 2 )
102
102
103
103
# Make the draws:
104
- draws_raw = np .random .randn (n_draws , n_params )
104
+ generator = np .random .default_rng (seed = random_seed )
105
+ draws_raw = generator .standard_normal (size = (n_draws , n_params ))
106
+
105
107
draws = opt_means + draws_raw * np .exp (opt_log_sds )
106
108
draws_arviz = unstack_laplace_draws (draws , model , chains = 1 , draws = n_draws )
107
109
@@ -116,7 +118,7 @@ def create_dadvi_graph(
116
118
model : Model ,
117
119
n_params : int ,
118
120
n_fixed_draws : int = 30 ,
119
- random_seed : int = 2 ,
121
+ random_seed : RandomSeed = None ,
120
122
) -> Tuple [TensorVariable , TensorVariable ]:
121
123
"""
122
124
Sets up the DADVI graph in pytensor and returns it.
@@ -143,8 +145,8 @@ def create_dadvi_graph(
143
145
"""
144
146
145
147
# Make the fixed draws
146
- state = np .random .RandomState ( random_seed )
147
- draws = state . randn ( n_fixed_draws , n_params )
148
+ generator = np .random .default_rng ( seed = random_seed )
149
+ draws = generator . standard_normal ( size = ( n_fixed_draws , n_params ) )
148
150
149
151
inputs = model .continuous_value_vars + model .discrete_value_vars
150
152
initial_point_dict = model .initial_point ()
@@ -162,7 +164,7 @@ def create_dadvi_graph(
162
164
163
165
draw_matrix = pt .constant (draws )
164
166
samples = means + pt .exp (log_sds ) * draw_matrix
165
-
167
+
166
168
logp_vectorized_draws = pytensor .graph .vectorize_graph (
167
169
logp , replace = {flat_input : samples }
168
170
)
0 commit comments