Skip to content

Commit ddbbed7

Browse files
committed
test: check that normalizing flows are reproducible
1 parent f37e0ff commit ddbbed7

File tree

9 files changed

+584
-60
lines changed

9 files changed

+584
-60
lines changed

.github/workflows/ci.yml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ jobs:
5656
python3 -m venv .venv
5757
source .venv/bin/activate
5858
uv pip install 'nutpie[stan]' --find-links dist --force-reinstall
59-
uv pip install pytest pytest-timeout
60-
pytest -m "stan and not flow"
59+
uv pip install pytest pytest-timeout pytest-arraydiff
60+
pytest -m "stan and not flow" --arraydiff
6161
uv pip install 'nutpie[pymc]' --find-links dist --force-reinstall
6262
uv pip install jax
63-
pytest -m "pymc and not flow"
63+
pytest -m "pymc and not flow" --arraydiff
6464
uv pip install 'nutpie[all]' --find-links dist --force-reinstall
65-
pytest -m flow
65+
pytest -m flow --arraydiff
6666
6767
# pyarrow doesn't currently seem to work on musllinux
6868
#musllinux:
@@ -183,13 +183,13 @@ jobs:
183183
python3 -m venv .venv
184184
source .venv/Scripts/activate
185185
uv pip install "nutpie[stan]" --find-links dist --force-reinstall
186-
uv pip install pytest pytest-timeout
187-
pytest -m "stan and not flow"
186+
uv pip install pytest pytest-timeout pytest-arraydiff
187+
pytest -m "stan and not flow" --arraydiff
188188
uv pip install "nutpie[pymc]" --find-links dist --force-reinstall
189189
uv pip install jax
190-
pytest -m "pymc and not flow"
190+
pytest -m "pymc and not flow" --arraydiff
191191
uv pip install "nutpie[all]" --find-links dist --force-reinstall
192-
pytest -m flow
192+
pytest -m flow --arraydiff
193193
194194
macos:
195195
runs-on: ${{ matrix.platform.runner }}
@@ -232,13 +232,13 @@ jobs:
232232
python3 -m venv .venv
233233
source .venv/bin/activate
234234
uv pip install 'nutpie[stan]' --find-links dist --force-reinstall
235-
uv pip install pytest pytest-timeout
236-
pytest -m "stan and not flow"
235+
uv pip install pytest pytest-timeout pytest-arraydiff
236+
pytest -m "stan and not flow" --arraydiff
237237
uv pip install 'nutpie[pymc]' --find-links dist --force-reinstall
238238
uv pip install jax
239-
pytest -m "pymc and not flow"
239+
pytest -m "pymc and not flow" --arraydiff
240240
uv pip install 'nutpie[all]' --find-links dist --force-reinstall
241-
pytest -m flow
241+
pytest -m flow --arraydiff
242242
sdist:
243243
runs-on: ubuntu-latest
244244
steps:

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ name = "nutpie"
77
description = "Sample Stan or PyMC models"
88
authors = [{ name = "PyMC Developers", email = "[email protected]" }]
99
readme = "README.md"
10-
requires-python = ">=3.10,<3.14"
10+
requires-python = ">=3.10"
1111
license = { text = "MIT" }
1212
classifiers = [
1313
"Programming Language :: Rust",
@@ -41,6 +41,7 @@ dev = [
4141
"flowjax >= 17.0.2",
4242
"pytest",
4343
"pytest-timeout",
44+
"pytest-arraydiff",
4445
]
4546
all = [
4647
"bridgestan >= 2.6.1",

python/nutpie/normalizing_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1916,7 +1916,7 @@ def make_flow(
19161916
diag = jnp.sqrt(pos_std / grad_std)
19171917
mean = positions.mean(0) + gradients.mean(0) * diag * diag
19181918

1919-
key = jax.random.PRNGKey(seed % (2**63))
1919+
key = jax.random.key(seed % (2**63), impl="threefry2x32")
19201920

19211921
diag_param = Parameterize(
19221922
lambda x: x + jnp.sqrt(1 + x**2),
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
0.941959
2+
0.559649
3+
0.534203
4+
0.561444
5+
0.561444
6+
0.418685
7+
0.827896
8+
0.847014
9+
0.738508
10+
0.961291
11+
0.923931
12+
1.00584
13+
1.16386
14+
1.10065
15+
1.6348
16+
1.13139
17+
0.993458
18+
0.993458
19+
0.966241
20+
1.10922
21+
1.10922
22+
1.05723
23+
1.05723
24+
2.32492
25+
0.0700824
26+
0.0860656
27+
1.36431
28+
0.829624
29+
0.584658
30+
0.531506
31+
0.507961
32+
0.543701
33+
0.510104
34+
2.46898
35+
0.820341
36+
0.490474
37+
0.343958
38+
0.300549
39+
2.60267
40+
0.588131
41+
0.430013
42+
0.618032
43+
1.27527
44+
1.80449
45+
1.80449
46+
0.855217
47+
0.556106
48+
1.77619
49+
2.03761
50+
1.02106
51+
0.774811
52+
1.78438
53+
1.61398
54+
0.712683
55+
1.04966
56+
1.17936
57+
1.5425
58+
1.5425
59+
1.26262
60+
1.39659
61+
0.337024
62+
0.177694
63+
0.0424286
64+
0.180403
65+
0.140553
66+
0.367095
67+
0.348732
68+
0.341436
69+
1.82764
70+
0.692738
71+
0.629186
72+
0.245706
73+
0.732305
74+
0.56873
75+
0.498757
76+
0.204131
77+
0.417031
78+
0.184895
79+
0.208768
80+
0.238139
81+
1.95089
82+
1.95089
83+
0.593379
84+
0.593379
85+
0.750063
86+
0.69929
87+
0.490359
88+
0.478709
89+
0.361632
90+
0.346159
91+
0.728965
92+
1.58228
93+
0.985676
94+
1.58468
95+
0.709012
96+
0.700483
97+
0.805006
98+
1.70347
99+
1.26293
100+
1.24837
101+
0.23989
102+
0.881025
103+
1.39084
104+
1.37812
105+
0.969265
106+
0.969265
107+
0.938487
108+
0.846447
109+
1.61945
110+
0.108473
111+
0.173496
112+
0.897353
113+
0.455899
114+
0.571886
115+
0.891672
116+
0.891672
117+
0.864419
118+
0.739099
119+
1.49009
120+
1.49009
121+
0.385499
122+
0.228701
123+
1.83156
124+
1.83156
125+
0.947635
126+
0.805623
127+
0.714762
128+
0.853477
129+
1.45906
130+
0.908818
131+
0.540951
132+
1.40995
133+
1.22564
134+
0.26496
135+
0.159994
136+
0.423836
137+
0.350158
138+
0.388884
139+
1.39507
140+
0.727701
141+
1.80674
142+
0.466389
143+
1.61574
144+
1.61574
145+
0.42774
146+
0.217983
147+
0.14579
148+
1.01321
149+
1.01321
150+
1.19713
151+
0.390791
152+
0.223687
153+
0.149019
154+
0.103866
155+
0.153768
156+
0.12942
157+
0.346371
158+
0.814553
159+
2.41042
160+
0.42739
161+
0.322291
162+
0.248911
163+
0.854404
164+
1.35372
165+
1.35372
166+
2.00546
167+
0.0457881
168+
0.0415644
169+
0.0797551
170+
0.0913076
171+
0.070948
172+
0.00993872
173+
0.421448
174+
0.550377
175+
0.609387
176+
0.490487
177+
2.6607
178+
0.32804
179+
0.385999
180+
0.497294
181+
1.67109
182+
1.14328
183+
1.14328
184+
0.903063
185+
0.903063
186+
0.903063
187+
0.691269
188+
2.00151
189+
0.587672
190+
0.79679
191+
1.35563
192+
0.598471
193+
0.681826
194+
0.818296
195+
1.14265
196+
0.113094
197+
0.250861
198+
0.284491
199+
0.00420445
200+
0.00566936

0 commit comments

Comments
 (0)