Skip to content

Commit a8a53f3

Browse files
author
Martin Ingram
committed
Respond to comments
1 parent 9ab2e1e commit a8a53f3

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

pymc_extras/inference/deterministic_advi/dadvi.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import xarray
1313

1414
from pymc import join_nonshared_inputs, DictToArrayBijection
15-
from pymc.util import get_default_varnames
15+
from pymc.util import get_default_varnames, RandomSeed
1616
from pymc.backends.arviz import (
1717
apply_function_over_dataset,
1818
PointFunc,
@@ -27,10 +27,10 @@
2727
def fit_deterministic_advi(
2828
model: Optional[Model] = None,
2929
n_fixed_draws: int = 30,
30-
random_seed: int = 2,
30+
random_seed: RandomSeed = None,
3131
n_draws: int = 1000,
3232
keep_untransformed: bool = False,
33-
):
33+
) -> az.InferenceData:
3434
"""
3535
Does inference using deterministic ADVI (automatic differentiation
3636
variational inference).
@@ -101,7 +101,9 @@ def fit_deterministic_advi(
101101
opt_means, opt_log_sds = np.split(opt_var_params, 2)
102102

103103
# Make the draws:
104-
draws_raw = np.random.randn(n_draws, n_params)
104+
generator = np.random.default_rng(seed=random_seed)
105+
draws_raw = generator.standard_normal(size=(n_draws, n_params))
106+
105107
draws = opt_means + draws_raw * np.exp(opt_log_sds)
106108
draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)
107109

@@ -116,7 +118,7 @@ def create_dadvi_graph(
116118
model: Model,
117119
n_params: int,
118120
n_fixed_draws: int = 30,
119-
random_seed: int = 2,
121+
random_seed: RandomSeed = None,
120122
) -> Tuple[TensorVariable, TensorVariable]:
121123
"""
122124
Sets up the DADVI graph in pytensor and returns it.
@@ -143,8 +145,8 @@ def create_dadvi_graph(
143145
"""
144146

145147
# Make the fixed draws
146-
state = np.random.RandomState(random_seed)
147-
draws = state.randn(n_fixed_draws, n_params)
148+
generator = np.random.default_rng(seed=random_seed)
149+
draws = generator.standard_normal(size=(n_fixed_draws, n_params))
148150

149151
inputs = model.continuous_value_vars + model.discrete_value_vars
150152
initial_point_dict = model.initial_point()
@@ -162,7 +164,7 @@ def create_dadvi_graph(
162164

163165
draw_matrix = pt.constant(draws)
164166
samples = means + pt.exp(log_sds) * draw_matrix
165-
167+
166168
logp_vectorized_draws = pytensor.graph.vectorize_graph(
167169
logp, replace={flat_input: samples}
168170
)

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ dependencies = [
4040
"better-optimize>=0.1.5",
4141
"pydantic>=2.0.0",
4242
"preliz>=0.20.0",
43-
"jax>=0.7.0"
4443
]
4544

4645
[project.optional-dependencies]

0 commit comments

Comments
 (0)