Skip to content

Commit fa9fe3a

Browse files
committed
Merge branch 'fix-numpy-random-tests-756-clean' of github.com:mohamed-laarej/malariagen-data-python into GH756-mohamed-laarej-shadow
2 parents 717136d + e7ef120 commit fa9fe3a

14 files changed

+87
-67
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ license = "MIT"
2121

2222
[tool.poetry.dependencies]
2323
python = ">=3.10,<3.13"
24-
numpy = ">=2.2"
24+
numpy = "*"
2525
numba = ">=0.60.0"
2626
llvmlite = "*"
2727
scipy = "*"

tests/anoph/test_cnv_frq.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
check_plot_frequencies_interactive_map,
2020
)
2121

22+
rng = np.random.default_rng(seed=42)
23+
2224

2325
@pytest.fixture
2426
def ag3_sim_api(ag3_sim_fixture):
@@ -97,7 +99,7 @@ def test_gene_cnv_frequencies_with_str_cohorts(
9799
region = random.choice(api.contigs)
98100
all_sample_sets = api.sample_sets()["sample_set"].to_list()
99101
sample_sets = random.choice(all_sample_sets)
100-
min_cohort_size = random.randint(0, 2)
102+
min_cohort_size = rng.integers(0, 2)
101103

102104
# Set up call params.
103105
params = dict(
@@ -302,7 +304,7 @@ def test_gene_cnv_frequencies_with_dict_cohorts(
302304
):
303305
# Pick test parameters at random.
304306
sample_sets = None # all sample sets
305-
min_cohort_size = random.randint(0, 2)
307+
min_cohort_size = rng.integers(0, 2)
306308
region = random.choice(api.contigs)
307309

308310
# Create cohorts by country.
@@ -343,7 +345,7 @@ def test_gene_cnv_frequencies_without_drop_invariant(
343345
# Pick test parameters at random.
344346
all_sample_sets = api.sample_sets()["sample_set"].to_list()
345347
sample_sets = random.choice(all_sample_sets)
346-
min_cohort_size = random.randint(0, 2)
348+
min_cohort_size = rng.integers(0, 2)
347349
region = random.choice(api.contigs)
348350
cohorts = random.choice(["admin1_year", "admin2_month", "country"])
349351

@@ -398,7 +400,7 @@ def test_gene_cnv_frequencies_with_bad_region(
398400
# Pick test parameters at random.
399401
all_sample_sets = api.sample_sets()["sample_set"].to_list()
400402
sample_sets = random.choice(all_sample_sets)
401-
min_cohort_size = random.randint(0, 2)
403+
min_cohort_size = rng.integers(0, 2)
402404
cohorts = random.choice(["admin1_year", "admin2_month", "country"])
403405

404406
# Set up call params.
@@ -718,7 +720,7 @@ def check_gene_cnv_frequencies_advanced(
718720
all_sample_sets = api.sample_sets()["sample_set"].to_list()
719721
sample_sets = random.choice(all_sample_sets)
720722
if min_cohort_size is None:
721-
min_cohort_size = random.randint(0, 2)
723+
min_cohort_size = rng.integers(0, 2)
722724

723725
# Run function under test.
724726
ds = api.gene_cnv_frequencies_advanced(

tests/anoph/test_distance.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from malariagen_data.anoph import pca_params
1212

1313

14+
rng = np.random.default_rng(seed=42)
15+
16+
1417
@pytest.fixture
1518
def ag3_sim_api(ag3_sim_fixture):
1619
return AnophelesDistanceAnalysis(
@@ -81,7 +84,7 @@ def check_biallelic_diplotype_pairwise_distance(*, api, data_params, metric):
8184
ds = api.biallelic_snp_calls(**data_params)
8285
n_samples = ds.sizes["samples"]
8386
n_snps_available = ds.sizes["variants"]
84-
n_snps = random.randint(4, n_snps_available)
87+
n_snps = rng.integers(4, n_snps_available)
8588

8689
# Run the distance computation.
8790
dist, samples, n_snps_used = api.biallelic_diplotype_pairwise_distances(
@@ -143,7 +146,7 @@ def check_njt(*, api, data_params, metric, algorithm):
143146
ds = api.biallelic_snp_calls(**data_params)
144147
n_samples = ds.sizes["samples"]
145148
n_snps_available = ds.sizes["variants"]
146-
n_snps = random.randint(4, n_snps_available)
149+
n_snps = rng.integers(4, n_snps_available)
147150

148151
# Run the distance computation.
149152
Z, samples, n_snps_used = api.njt(
@@ -232,7 +235,7 @@ def test_plot_njt(fixture, api: AnophelesDistanceAnalysis):
232235
# Check available data.
233236
ds = api.biallelic_snp_calls(**data_params)
234237
n_snps_available = ds.sizes["variants"]
235-
n_snps = random.randint(4, n_snps_available)
238+
n_snps = rng.integers(4, n_snps_available)
236239

237240
# Exercise the function.
238241
for color, symbol in zip(colors, symbols):

tests/anoph/test_frq.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import pytest
22
import plotly.graph_objects as go # type: ignore
3-
3+
import numpy as np
44
import random
55

6+
rng = np.random.default_rng(seed=42)
7+
68

79
def check_plot_frequencies_heatmap(api, frq_df):
810
fig = api.plot_frequencies_heatmap(frq_df, show=False, max_len=None)
@@ -65,7 +67,7 @@ def check_plot_frequencies_time_series_with_areas(api, ds):
6567
# Pick a random area and areas from valid areas.
6668
cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist()
6769
area = random.choice(cohorts_areas)
68-
areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas)))
70+
areas = random.sample(cohorts_areas, rng.integers(1, len(cohorts_areas)))
6971

7072
# Plot with area.
7173
fig = api.plot_frequencies_time_series(ds, show=False, areas=area)

tests/anoph/test_fst.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from malariagen_data import ag3 as _ag3
1212
from malariagen_data.anoph.fst import AnophelesFstAnalysis
1313

14+
rng = np.random.default_rng(seed=42)
15+
1416

1517
@pytest.fixture
1618
def ag3_sim_api(ag3_sim_fixture):
@@ -91,7 +93,7 @@ def test_fst_gwss(fixture, api: AnophelesFstAnalysis):
9193
cohort1_query=cohort1_query,
9294
cohort2_query=cohort2_query,
9395
site_mask=random.choice(api.site_mask_ids),
94-
window_size=random.randint(10, 50),
96+
window_size=rng.integers(10, 50),
9597
min_cohort_size=1,
9698
)
9799

@@ -131,7 +133,7 @@ def test_average_fst(fixture, api: AnophelesFstAnalysis):
131133
cohort2_query=cohort2_query,
132134
site_mask=random.choice(api.site_mask_ids),
133135
min_cohort_size=1,
134-
n_jack=random.randint(10, 200),
136+
n_jack=rng.integers(10, 200),
135137
)
136138

137139
# Run main gwss function under test.
@@ -229,7 +231,7 @@ def test_pairwise_average_fst_with_str_cohorts(
229231
sample_sets=all_sample_sets,
230232
site_mask=site_mask,
231233
min_cohort_size=1,
232-
n_jack=random.randint(10, 200),
234+
n_jack=rng.integers(10, 200),
233235
)
234236

235237
# Run checks.
@@ -249,7 +251,7 @@ def test_pairwise_average_fst_with_min_cohort_size(fixture, api: AnophelesFstAna
249251
sample_sets=all_sample_sets,
250252
site_mask=site_mask,
251253
min_cohort_size=15,
252-
n_jack=random.randint(10, 200),
254+
n_jack=rng.integers(10, 200),
253255
)
254256

255257
# Run checks.
@@ -270,7 +272,7 @@ def test_pairwise_average_fst_with_dict_cohorts(fixture, api: AnophelesFstAnalys
270272
sample_sets=all_sample_sets,
271273
site_mask=site_mask,
272274
min_cohort_size=1,
273-
n_jack=random.randint(10, 200),
275+
n_jack=rng.integers(10, 200),
274276
)
275277

276278
# Run checks.
@@ -294,7 +296,7 @@ def test_pairwise_average_fst_with_sample_query(fixture, api: AnophelesFstAnalys
294296
sample_query=sample_query,
295297
site_mask=site_mask,
296298
min_cohort_size=1,
297-
n_jack=random.randint(10, 200),
299+
n_jack=rng.integers(10, 200),
298300
)
299301

300302
# Run checks.

tests/anoph/test_g123.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_g123_gwss_with_default_sites(fixture, api: AnophelesG123Analysis):
108108
g123_params = dict(
109109
contig=random.choice(api.contigs),
110110
sample_sets=[random.choice(all_sample_sets)],
111-
window_size=random.randint(100, 500),
111+
window_size=rng.integers(100, 500),
112112
min_cohort_size=10,
113113
)
114114

@@ -124,7 +124,7 @@ def test_g123_gwss_with_phased_sites(fixture, api: AnophelesG123Analysis):
124124
contig=random.choice(api.contigs),
125125
sites=random.choice(api.phasing_analysis_ids),
126126
sample_sets=[random.choice(all_sample_sets)],
127-
window_size=random.randint(100, 500),
127+
window_size=rng.integers(100, 500),
128128
min_cohort_size=10,
129129
)
130130

@@ -141,7 +141,7 @@ def test_g123_gwss_with_segregating_sites(fixture, api: AnophelesG123Analysis):
141141
sites="segregating",
142142
site_mask=random.choice(api.site_mask_ids),
143143
sample_sets=[random.choice(all_sample_sets)],
144-
window_size=random.randint(100, 500),
144+
window_size=rng.integers(100, 500),
145145
min_cohort_size=10,
146146
)
147147

@@ -158,7 +158,7 @@ def test_g123_gwss_with_all_sites(fixture, api: AnophelesG123Analysis):
158158
sites="all",
159159
site_mask=None,
160160
sample_sets=[random.choice(all_sample_sets)],
161-
window_size=random.randint(100, 500),
161+
window_size=rng.integers(100, 500),
162162
min_cohort_size=10,
163163
)
164164

@@ -173,7 +173,7 @@ def test_g123_gwss_with_bad_sites(fixture, api: AnophelesG123Analysis):
173173
g123_params = dict(
174174
contig=random.choice(api.contigs),
175175
sample_sets=[random.choice(all_sample_sets)],
176-
window_size=random.randint(100, 500),
176+
window_size=rng.integers(100, 500),
177177
min_cohort_size=10,
178178
sites="foobar",
179179
)
@@ -187,8 +187,8 @@ def test_g123_gwss_with_bad_sites(fixture, api: AnophelesG123Analysis):
187187
def test_g123_calibration(fixture, api: AnophelesG123Analysis):
188188
# Set up test parameters.
189189
all_sample_sets = api.sample_sets()["sample_set"].to_list()
190-
window_sizes = rng.integers(100, 500, size=random.randint(2, 5)).tolist()
191-
window_sizes = sorted([int(x) for x in window_sizes])
190+
window_sizes = rng.integers(100, 500, size=rng.integers(2, 5)).tolist()
191+
window_sizes = sorted(int(window_sizes))
192192
g123_params = dict(
193193
contig=rng.choice(api.contigs),
194194
sites=rng.choice(api.phasing_analysis_ids),

tests/anoph/test_h12.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def test_h12_gwss_with_default_analysis(fixture, api: AnophelesH12Analysis):
176176
h12_params = dict(
177177
contig=random.choice(api.contigs),
178178
sample_sets=[random.choice(all_sample_sets)],
179-
window_size=random.randint(100, 500),
179+
window_size=rng.integers(100, 500),
180180
min_cohort_size=5,
181181
)
182182

@@ -190,7 +190,7 @@ def test_h12_gwss_with_analysis(fixture, api: AnophelesH12Analysis):
190190
all_sample_sets = api.sample_sets()["sample_set"].to_list()
191191
sample_sets = [random.choice(all_sample_sets)]
192192
contig = random.choice(api.contigs)
193-
window_size = random.randint(100, 500)
193+
window_size = rng.integers(100, 500)
194194

195195
for analysis in api.phasing_analysis_ids:
196196
# Check if any samples available for the given phasing analysis.
@@ -244,7 +244,7 @@ def test_h12_gwss_multi_with_default_analysis(fixture, api: AnophelesH12Analysis
244244
h12_params = dict(
245245
contig=random.choice(api.contigs),
246246
sample_sets=all_sample_sets,
247-
window_size=random.randint(100, 500),
247+
window_size=rng.integers(100, 500),
248248
min_cohort_size=1,
249249
cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query},
250250
)
@@ -265,8 +265,8 @@ def test_h12_gwss_multi_with_window_size_dict(fixture, api: AnophelesH12Analysis
265265
contig=random.choice(api.contigs),
266266
sample_sets=all_sample_sets,
267267
window_size={
268-
"cohort1": random.randint(100, 500),
269-
"cohort2": random.randint(100, 500),
268+
"cohort1": rng.integers(100, 500),
269+
"cohort2": rng.integers(100, 500),
270270
},
271271
min_cohort_size=1,
272272
cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query},
@@ -317,7 +317,7 @@ def test_h12_gwss_multi_with_analysis(fixture, api: AnophelesH12Analysis):
317317
analysis=analysis,
318318
contig=contig,
319319
sample_sets=all_sample_sets,
320-
window_size=random.randint(100, 500),
320+
window_size=rng.integers(100, 500),
321321
min_cohort_size=min(n1, n2),
322322
cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query},
323323
)

tests/anoph/test_h1x.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from malariagen_data import ag3 as _ag3
1010
from malariagen_data.anoph.h1x import AnophelesH1XAnalysis, haplotype_joint_frequencies
1111

12+
rng = np.random.default_rng(seed=42)
13+
1214

1315
@pytest.fixture
1416
def ag3_sim_api(ag3_sim_fixture):
@@ -147,7 +149,7 @@ def test_h1x_gwss_with_default_analysis(fixture, api: AnophelesH1XAnalysis):
147149
h1x_params = dict(
148150
contig=random.choice(api.contigs),
149151
sample_sets=all_sample_sets,
150-
window_size=random.randint(100, 500),
152+
window_size=rng.integers(100, 500),
151153
min_cohort_size=1,
152154
cohort1_query=cohort1_query,
153155
cohort2_query=cohort2_query,
@@ -198,7 +200,7 @@ def test_h1x_gwss_with_analysis(fixture, api: AnophelesH1XAnalysis):
198200
analysis=analysis,
199201
contig=contig,
200202
sample_sets=all_sample_sets,
201-
window_size=random.randint(100, 500),
203+
window_size=rng.integers(100, 500),
202204
min_cohort_size=min(n1, n2),
203205
cohort1_query=cohort1_query,
204206
cohort2_query=cohort2_query,

tests/anoph/test_hap_data.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ def test_haplotypes_with_cohort_size_param(
470470
analysis = api.phasing_analysis_ids[0]
471471

472472
# Parametrize over cohort_size.
473-
parametrize_cohort_size = [random.randint(1, 10), random.randint(10, 50), 1_000]
473+
parametrize_cohort_size = [rng.integers(1, 10), rng.integers(10, 50), 1_000]
474474
for cohort_size in parametrize_cohort_size:
475475
check_haplotypes(
476476
fixture=fixture,
@@ -497,8 +497,8 @@ def test_haplotypes_with_min_cohort_size_param(
497497

498498
# Parametrize over min_cohort_size.
499499
parametrize_min_cohort_size = [
500-
random.randint(1, 10),
501-
random.randint(10, 50),
500+
rng.integers(1, 10),
501+
rng.integers(10, 50),
502502
1_000,
503503
]
504504
for min_cohort_size in parametrize_min_cohort_size:
@@ -527,8 +527,8 @@ def test_haplotypes_with_max_cohort_size_param(
527527

528528
# Parametrize over max_cohort_size.
529529
parametrize_max_cohort_size = [
530-
random.randint(1, 10),
531-
random.randint(10, 50),
530+
rng.integers(1, 10),
531+
rng.integers(10, 50),
532532
1_000,
533533
]
534534
for max_cohort_size in parametrize_max_cohort_size:

tests/anoph/test_hap_frq.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
check_plot_frequencies_interactive_map,
1818
)
1919

20+
rng = np.random.default_rng(seed=42)
21+
2022

2123
@pytest.fixture
2224
def ag3_sim_api(ag3_sim_fixture):
@@ -168,7 +170,7 @@ def test_hap_frequencies_with_str_cohorts(
168170
# Pick test parameters at random.
169171
all_sample_sets = api.sample_sets()["sample_set"].to_list()
170172
sample_sets = random.choice(all_sample_sets)
171-
min_cohort_size = random.randint(0, 2)
173+
min_cohort_size = rng.integers(0, 2)
172174
region = fixture.random_region_str()
173175

174176
# Set up call params.
@@ -210,7 +212,7 @@ def test_hap_frequencies_advanced(
210212
):
211213
all_sample_sets = api.sample_sets()["sample_set"].to_list()
212214
sample_sets = random.choice(all_sample_sets)
213-
min_cohort_size = random.randint(0, 2)
215+
min_cohort_size = rng.integers(0, 2)
214216
region = fixture.random_region_str()
215217

216218
# Set up call params.

0 commit comments

Comments
 (0)