Skip to content

Commit e7ef120

Browse files
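Update tests to consistently use the seeded NumPy random number generator (`rng`)
instead of the legacy `np.random` API or Python's `random` module, and unpin the NumPy version in pyproject.toml. Note: `rng.integers(low, high)` excludes `high` by default and returns a NumPy integer scalar, whereas `random.randint(a, b)` includes `b` and returns a Python `int`, so each replaced call now samples a range that is one value narrower at the top.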
1 parent 8e9d1a5 commit e7ef120

14 files changed

+87
-67
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ license = "MIT"
2121

2222
[tool.poetry.dependencies]
2323
python = ">=3.10,<3.13"
24-
numpy = "<2.2"
24+
numpy = "*"
2525
numba = ">=0.60.0"
2626
llvmlite = "*"
2727
scipy = "*"

tests/anoph/test_cnv_frq.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@
1919
check_plot_frequencies_interactive_map,
2020
)
2121

22+
rng = np.random.default_rng(seed=42)
23+
2224

2325
@pytest.fixture
2426
def ag3_sim_api(ag3_sim_fixture):
@@ -97,7 +99,7 @@ def test_gene_cnv_frequencies_with_str_cohorts(
9799
region = random.choice(api.contigs)
98100
all_sample_sets = api.sample_sets()["sample_set"].to_list()
99101
sample_sets = random.choice(all_sample_sets)
100-
min_cohort_size = random.randint(0, 2)
102+
min_cohort_size = rng.integers(0, 2)
101103

102104
# Set up call params.
103105
params = dict(
@@ -302,7 +304,7 @@ def test_gene_cnv_frequencies_with_dict_cohorts(
302304
):
303305
# Pick test parameters at random.
304306
sample_sets = None # all sample sets
305-
min_cohort_size = random.randint(0, 2)
307+
min_cohort_size = rng.integers(0, 2)
306308
region = random.choice(api.contigs)
307309

308310
# Create cohorts by country.
@@ -343,7 +345,7 @@ def test_gene_cnv_frequencies_without_drop_invariant(
343345
# Pick test parameters at random.
344346
all_sample_sets = api.sample_sets()["sample_set"].to_list()
345347
sample_sets = random.choice(all_sample_sets)
346-
min_cohort_size = random.randint(0, 2)
348+
min_cohort_size = rng.integers(0, 2)
347349
region = random.choice(api.contigs)
348350
cohorts = random.choice(["admin1_year", "admin2_month", "country"])
349351

@@ -398,7 +400,7 @@ def test_gene_cnv_frequencies_with_bad_region(
398400
# Pick test parameters at random.
399401
all_sample_sets = api.sample_sets()["sample_set"].to_list()
400402
sample_sets = random.choice(all_sample_sets)
401-
min_cohort_size = random.randint(0, 2)
403+
min_cohort_size = rng.integers(0, 2)
402404
cohorts = random.choice(["admin1_year", "admin2_month", "country"])
403405

404406
# Set up call params.
@@ -718,7 +720,7 @@ def check_gene_cnv_frequencies_advanced(
718720
all_sample_sets = api.sample_sets()["sample_set"].to_list()
719721
sample_sets = random.choice(all_sample_sets)
720722
if min_cohort_size is None:
721-
min_cohort_size = random.randint(0, 2)
723+
min_cohort_size = rng.integers(0, 2)
722724

723725
# Run function under test.
724726
ds = api.gene_cnv_frequencies_advanced(

tests/anoph/test_distance.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,9 @@
1111
from malariagen_data.anoph import pca_params
1212

1313

14+
rng = np.random.default_rng(seed=42)
15+
16+
1417
@pytest.fixture
1518
def ag3_sim_api(ag3_sim_fixture):
1619
return AnophelesDistanceAnalysis(
@@ -81,7 +84,7 @@ def check_biallelic_diplotype_pairwise_distance(*, api, data_params, metric):
8184
ds = api.biallelic_snp_calls(**data_params)
8285
n_samples = ds.sizes["samples"]
8386
n_snps_available = ds.sizes["variants"]
84-
n_snps = random.randint(4, n_snps_available)
87+
n_snps = rng.integers(4, n_snps_available)
8588

8689
# Run the distance computation.
8790
dist, samples, n_snps_used = api.biallelic_diplotype_pairwise_distances(
@@ -143,7 +146,7 @@ def check_njt(*, api, data_params, metric, algorithm):
143146
ds = api.biallelic_snp_calls(**data_params)
144147
n_samples = ds.sizes["samples"]
145148
n_snps_available = ds.sizes["variants"]
146-
n_snps = random.randint(4, n_snps_available)
149+
n_snps = rng.integers(4, n_snps_available)
147150

148151
# Run the distance computation.
149152
Z, samples, n_snps_used = api.njt(
@@ -232,7 +235,7 @@ def test_plot_njt(fixture, api: AnophelesDistanceAnalysis):
232235
# Check available data.
233236
ds = api.biallelic_snp_calls(**data_params)
234237
n_snps_available = ds.sizes["variants"]
235-
n_snps = random.randint(4, n_snps_available)
238+
n_snps = rng.integers(4, n_snps_available)
236239

237240
# Exercise the function.
238241
for color, symbol in zip(colors, symbols):

tests/anoph/test_frq.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
import pytest
22
import plotly.graph_objects as go # type: ignore
3-
3+
import numpy as np
44
import random
55

6+
rng = np.random.default_rng(seed=42)
7+
68

79
def check_plot_frequencies_heatmap(api, frq_df):
810
fig = api.plot_frequencies_heatmap(frq_df, show=False, max_len=None)
@@ -65,7 +67,7 @@ def check_plot_frequencies_time_series_with_areas(api, ds):
6567
# Pick a random area and areas from valid areas.
6668
cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist()
6769
area = random.choice(cohorts_areas)
68-
areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas)))
70+
areas = random.sample(cohorts_areas, rng.integers(1, len(cohorts_areas)))
6971

7072
# Plot with area.
7173
fig = api.plot_frequencies_time_series(ds, show=False, areas=area)

tests/anoph/test_fst.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,8 @@
1111
from malariagen_data import ag3 as _ag3
1212
from malariagen_data.anoph.fst import AnophelesFstAnalysis
1313

14+
rng = np.random.default_rng(seed=42)
15+
1416

1517
@pytest.fixture
1618
def ag3_sim_api(ag3_sim_fixture):
@@ -91,7 +93,7 @@ def test_fst_gwss(fixture, api: AnophelesFstAnalysis):
9193
cohort1_query=cohort1_query,
9294
cohort2_query=cohort2_query,
9395
site_mask=random.choice(api.site_mask_ids),
94-
window_size=random.randint(10, 50),
96+
window_size=rng.integers(10, 50),
9597
min_cohort_size=1,
9698
)
9799

@@ -131,7 +133,7 @@ def test_average_fst(fixture, api: AnophelesFstAnalysis):
131133
cohort2_query=cohort2_query,
132134
site_mask=random.choice(api.site_mask_ids),
133135
min_cohort_size=1,
134-
n_jack=random.randint(10, 200),
136+
n_jack=rng.integers(10, 200),
135137
)
136138

137139
# Run main gwss function under test.
@@ -229,7 +231,7 @@ def test_pairwise_average_fst_with_str_cohorts(
229231
sample_sets=all_sample_sets,
230232
site_mask=site_mask,
231233
min_cohort_size=1,
232-
n_jack=random.randint(10, 200),
234+
n_jack=rng.integers(10, 200),
233235
)
234236

235237
# Run checks.
@@ -249,7 +251,7 @@ def test_pairwise_average_fst_with_min_cohort_size(fixture, api: AnophelesFstAna
249251
sample_sets=all_sample_sets,
250252
site_mask=site_mask,
251253
min_cohort_size=15,
252-
n_jack=random.randint(10, 200),
254+
n_jack=rng.integers(10, 200),
253255
)
254256

255257
# Run checks.
@@ -270,7 +272,7 @@ def test_pairwise_average_fst_with_dict_cohorts(fixture, api: AnophelesFstAnalys
270272
sample_sets=all_sample_sets,
271273
site_mask=site_mask,
272274
min_cohort_size=1,
273-
n_jack=random.randint(10, 200),
275+
n_jack=rng.integers(10, 200),
274276
)
275277

276278
# Run checks.
@@ -294,7 +296,7 @@ def test_pairwise_average_fst_with_sample_query(fixture, api: AnophelesFstAnalys
294296
sample_query=sample_query,
295297
site_mask=site_mask,
296298
min_cohort_size=1,
297-
n_jack=random.randint(10, 200),
299+
n_jack=rng.integers(10, 200),
298300
)
299301

300302
# Run checks.

tests/anoph/test_g123.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,7 @@ def test_g123_gwss_with_default_sites(fixture, api: AnophelesG123Analysis):
108108
g123_params = dict(
109109
contig=random.choice(api.contigs),
110110
sample_sets=[random.choice(all_sample_sets)],
111-
window_size=random.randint(100, 500),
111+
window_size=rng.integers(100, 500),
112112
min_cohort_size=10,
113113
)
114114

@@ -124,7 +124,7 @@ def test_g123_gwss_with_phased_sites(fixture, api: AnophelesG123Analysis):
124124
contig=random.choice(api.contigs),
125125
sites=random.choice(api.phasing_analysis_ids),
126126
sample_sets=[random.choice(all_sample_sets)],
127-
window_size=random.randint(100, 500),
127+
window_size=rng.integers(100, 500),
128128
min_cohort_size=10,
129129
)
130130

@@ -141,7 +141,7 @@ def test_g123_gwss_with_segregating_sites(fixture, api: AnophelesG123Analysis):
141141
sites="segregating",
142142
site_mask=random.choice(api.site_mask_ids),
143143
sample_sets=[random.choice(all_sample_sets)],
144-
window_size=random.randint(100, 500),
144+
window_size=rng.integers(100, 500),
145145
min_cohort_size=10,
146146
)
147147

@@ -158,7 +158,7 @@ def test_g123_gwss_with_all_sites(fixture, api: AnophelesG123Analysis):
158158
sites="all",
159159
site_mask=None,
160160
sample_sets=[random.choice(all_sample_sets)],
161-
window_size=random.randint(100, 500),
161+
window_size=rng.integers(100, 500),
162162
min_cohort_size=10,
163163
)
164164

@@ -173,7 +173,7 @@ def test_g123_gwss_with_bad_sites(fixture, api: AnophelesG123Analysis):
173173
g123_params = dict(
174174
contig=random.choice(api.contigs),
175175
sample_sets=[random.choice(all_sample_sets)],
176-
window_size=random.randint(100, 500),
176+
window_size=rng.integers(100, 500),
177177
min_cohort_size=10,
178178
sites="foobar",
179179
)
@@ -205,7 +205,7 @@ def extract_ints(item):
205205
def test_g123_calibration(fixture, api: AnophelesG123Analysis):
206206
# Set up test parameters.
207207
all_sample_sets = api.sample_sets()["sample_set"].to_list()
208-
window_sizes = rng.integers(100, 500, size=random.randint(2, 5)).tolist()
208+
window_sizes = rng.integers(100, 500, size=rng.integers(2, 5)).tolist()
209209
window_sizes = sorted(ensure_int_list(window_sizes))
210210
g123_params = dict(
211211
contig=rng.choice(api.contigs),

tests/anoph/test_h12.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -125,7 +125,7 @@ def extract_ints(item):
125125
def test_h12_calibration(fixture, api: AnophelesH12Analysis):
126126
# Set up test parameters.
127127
all_sample_sets = api.sample_sets()["sample_set"].to_list()
128-
window_sizes = rng.integers(100, 500, size=random.randint(2, 5)).tolist()
128+
window_sizes = rng.integers(100, 500, size=rng.integers(2, 5)).tolist()
129129
# Convert window_sizes to a flattened list of integers
130130
window_sizes = sorted(set(ensure_int_list(window_sizes)))
131131
h12_params = dict(
@@ -194,7 +194,7 @@ def test_h12_gwss_with_default_analysis(fixture, api: AnophelesH12Analysis):
194194
h12_params = dict(
195195
contig=random.choice(api.contigs),
196196
sample_sets=[random.choice(all_sample_sets)],
197-
window_size=random.randint(100, 500),
197+
window_size=rng.integers(100, 500),
198198
min_cohort_size=5,
199199
)
200200

@@ -208,7 +208,7 @@ def test_h12_gwss_with_analysis(fixture, api: AnophelesH12Analysis):
208208
all_sample_sets = api.sample_sets()["sample_set"].to_list()
209209
sample_sets = [random.choice(all_sample_sets)]
210210
contig = random.choice(api.contigs)
211-
window_size = random.randint(100, 500)
211+
window_size = rng.integers(100, 500)
212212

213213
for analysis in api.phasing_analysis_ids:
214214
# Check if any samples available for the given phasing analysis.
@@ -262,7 +262,7 @@ def test_h12_gwss_multi_with_default_analysis(fixture, api: AnophelesH12Analysis
262262
h12_params = dict(
263263
contig=random.choice(api.contigs),
264264
sample_sets=all_sample_sets,
265-
window_size=random.randint(100, 500),
265+
window_size=rng.integers(100, 500),
266266
min_cohort_size=1,
267267
cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query},
268268
)
@@ -283,8 +283,8 @@ def test_h12_gwss_multi_with_window_size_dict(fixture, api: AnophelesH12Analysis
283283
contig=random.choice(api.contigs),
284284
sample_sets=all_sample_sets,
285285
window_size={
286-
"cohort1": random.randint(100, 500),
287-
"cohort2": random.randint(100, 500),
286+
"cohort1": rng.integers(100, 500),
287+
"cohort2": rng.integers(100, 500),
288288
},
289289
min_cohort_size=1,
290290
cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query},
@@ -335,7 +335,7 @@ def test_h12_gwss_multi_with_analysis(fixture, api: AnophelesH12Analysis):
335335
analysis=analysis,
336336
contig=contig,
337337
sample_sets=all_sample_sets,
338-
window_size=random.randint(100, 500),
338+
window_size=rng.integers(100, 500),
339339
min_cohort_size=min(n1, n2),
340340
cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query},
341341
)

tests/anoph/test_h1x.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,8 @@
99
from malariagen_data import ag3 as _ag3
1010
from malariagen_data.anoph.h1x import AnophelesH1XAnalysis, haplotype_joint_frequencies
1111

12+
rng = np.random.default_rng(seed=42)
13+
1214

1315
@pytest.fixture
1416
def ag3_sim_api(ag3_sim_fixture):
@@ -147,7 +149,7 @@ def test_h1x_gwss_with_default_analysis(fixture, api: AnophelesH1XAnalysis):
147149
h1x_params = dict(
148150
contig=random.choice(api.contigs),
149151
sample_sets=all_sample_sets,
150-
window_size=random.randint(100, 500),
152+
window_size=rng.integers(100, 500),
151153
min_cohort_size=1,
152154
cohort1_query=cohort1_query,
153155
cohort2_query=cohort2_query,
@@ -198,7 +200,7 @@ def test_h1x_gwss_with_analysis(fixture, api: AnophelesH1XAnalysis):
198200
analysis=analysis,
199201
contig=contig,
200202
sample_sets=all_sample_sets,
201-
window_size=random.randint(100, 500),
203+
window_size=rng.integers(100, 500),
202204
min_cohort_size=min(n1, n2),
203205
cohort1_query=cohort1_query,
204206
cohort2_query=cohort2_query,

tests/anoph/test_hap_data.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -470,7 +470,7 @@ def test_haplotypes_with_cohort_size_param(
470470
analysis = api.phasing_analysis_ids[0]
471471

472472
# Parametrize over cohort_size.
473-
parametrize_cohort_size = [random.randint(1, 10), random.randint(10, 50), 1_000]
473+
parametrize_cohort_size = [rng.integers(1, 10), rng.integers(10, 50), 1_000]
474474
for cohort_size in parametrize_cohort_size:
475475
check_haplotypes(
476476
fixture=fixture,
@@ -497,8 +497,8 @@ def test_haplotypes_with_min_cohort_size_param(
497497

498498
# Parametrize over min_cohort_size.
499499
parametrize_min_cohort_size = [
500-
random.randint(1, 10),
501-
random.randint(10, 50),
500+
rng.integers(1, 10),
501+
rng.integers(10, 50),
502502
1_000,
503503
]
504504
for min_cohort_size in parametrize_min_cohort_size:
@@ -527,8 +527,8 @@ def test_haplotypes_with_max_cohort_size_param(
527527

528528
# Parametrize over max_cohort_size.
529529
parametrize_max_cohort_size = [
530-
random.randint(1, 10),
531-
random.randint(10, 50),
530+
rng.integers(1, 10),
531+
rng.integers(10, 50),
532532
1_000,
533533
]
534534
for max_cohort_size in parametrize_max_cohort_size:

tests/anoph/test_hap_frq.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,8 @@
1717
check_plot_frequencies_interactive_map,
1818
)
1919

20+
rng = np.random.default_rng(seed=42)
21+
2022

2123
@pytest.fixture
2224
def ag3_sim_api(ag3_sim_fixture):
@@ -168,7 +170,7 @@ def test_hap_frequencies_with_str_cohorts(
168170
# Pick test parameters at random.
169171
all_sample_sets = api.sample_sets()["sample_set"].to_list()
170172
sample_sets = random.choice(all_sample_sets)
171-
min_cohort_size = random.randint(0, 2)
173+
min_cohort_size = rng.integers(0, 2)
172174
region = fixture.random_region_str()
173175

174176
# Set up call params.
@@ -210,7 +212,7 @@ def test_hap_frequencies_advanced(
210212
):
211213
all_sample_sets = api.sample_sets()["sample_set"].to_list()
212214
sample_sets = random.choice(all_sample_sets)
213-
min_cohort_size = random.randint(0, 2)
215+
min_cohort_size = rng.integers(0, 2)
214216
region = fixture.random_region_str()
215217

216218
# Set up call params.

0 commit comments

Comments
 (0)