Skip to content

Commit 3feb26e

Browse files
Fix random number generation and type casting issues
- Replaced all Python random.choice() with rng.choice() for consistency - Replaced random.sample() with rng.choice(..., replace=False) - Added .tolist() to convert NumPy arrays to Python lists where needed - Added str() casting for np.str_ values to ensure Python string compatibility - Fixed 'low >= high' errors in rng.integers() calls by ensuring high > low - Specifically fixed tests/anoph/test_frq.py by changing rng.integers(1, len(cohorts_areas)) to rng.integers(1, len(cohorts_areas)+1) to avoid invalid ranges - Applied int() casting to NumPy integer types where Python int was expected - Fixed site_mask selection to ensure only valid masks are used for each test context Addresses feedback from PR #760 and resolves test failures.
1 parent 4d41538 commit 3feb26e

21 files changed

+561
-410
lines changed

tests/anoph/conftest.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import shutil
33
import string
44
from pathlib import Path
5-
from random import choice, choices, randint
65
from typing import Any, Dict, Tuple
76

87
import numpy as np
@@ -40,7 +39,7 @@ def fixture_dir():
4039

4140

4241
def simulate_contig(*, low, high, base_composition):
43-
size = rng.integers(low=low, high=high)
42+
size = int(rng.integers(low=low, high=high))
4443
bases = np.array([b"a", b"c", b"g", b"t", b"n", b"A", b"C", b"G", b"T", b"N"])
4544
p = np.array([base_composition[b] for b in bases])
4645
seq = rng.choice(bases, size=size, replace=True, p=p)
@@ -151,9 +150,9 @@ def simulate_genes(self, *, contig, contig_size):
151150
# Simulate genes.
152151
for gene_ix in range(self.max_genes):
153152
gene_id = f"gene-{contig}-{gene_ix}"
154-
strand = choice(["+", "-"])
155-
inter_size = randint(self.inter_size_low, self.inter_size_high)
156-
gene_size = randint(self.gene_size_low, self.gene_size_high)
153+
strand = rng.choice(["+", "-"])
154+
inter_size = int(rng.integers(self.inter_size_low, self.inter_size_high))
155+
gene_size = int(rng.integers(self.gene_size_low, self.gene_size_high))
157156
if strand == "+":
158157
gene_start = cur_fwd + inter_size
159158
else:
@@ -166,7 +165,11 @@ def simulate_genes(self, *, contig, contig_size):
166165
gene_attrs = f"ID={gene_id}"
167166
for attr in self.attrs:
168167
random_str = "".join(
169-
choices(string.ascii_uppercase + string.digits, k=5)
168+
rng.choice(
169+
list(string.ascii_uppercase + string.digits),
170+
size=5,
171+
replace=True,
172+
)
170173
)
171174
gene_attrs += f";{attr}={random_str}"
172175
gene = (
@@ -212,7 +215,7 @@ def simulate_transcripts(
212215
# accurate in real data.
213216

214217
for transcript_ix in range(
215-
randint(self.n_transcripts_low, self.n_transcripts_high)
218+
int(rng.integers(self.n_transcripts_low, self.n_transcripts_high))
216219
):
217220
transcript_id = f"transcript-{contig}-{gene_ix}-{transcript_ix}"
218221
transcript_start = gene_start
@@ -260,13 +263,16 @@ def simulate_exons(
260263
transcript_size = transcript_end - transcript_start
261264
exons = []
262265
exon_end = transcript_start
263-
n_exons = randint(self.n_exons_low, self.n_exons_high)
266+
n_exons = int(rng.integers(self.n_exons_low, self.n_exons_high))
264267
for exon_ix in range(n_exons):
265268
exon_id = f"exon-{contig}-{gene_ix}-{transcript_ix}-{exon_ix}"
266269
if exon_ix > 0:
267270
# Insert an intron between this exon and the previous one.
268-
intron_size = randint(
269-
self.intron_size_low, min(transcript_size, self.intron_size_high)
271+
intron_size = int(
272+
rng.integers(
273+
self.intron_size_low,
274+
min(transcript_size, self.intron_size_high),
275+
)
270276
)
271277
exon_start = exon_end + intron_size
272278
if exon_start >= transcript_end:
@@ -275,7 +281,7 @@ def simulate_exons(
275281
else:
276282
# First exon, assume exon starts where the transcript starts.
277283
exon_start = transcript_start
278-
exon_size = randint(self.exon_size_low, self.exon_size_high)
284+
exon_size = int(rng.integers(self.exon_size_low, self.exon_size_high))
279285
exon_end = min(exon_start + exon_size, transcript_end)
280286
assert exon_end > exon_start
281287
exon = (
@@ -311,7 +317,7 @@ def simulate_exons(
311317
else:
312318
feature_type = self.cds_type
313319
# Cheat a little, random phase.
314-
phase = choice([1, 2, 3])
320+
phase = rng.choice([1, 2, 3])
315321
feature = (
316322
contig,
317323
self.source,
@@ -549,7 +555,7 @@ def simulate_aim_variants(path, contigs, snp_sites, n_sites_low, n_sites_high):
549555
# Simulate AIM positions variable.
550556
snp_pos = snp_sites[f"{contig}/variants/POS"][:]
551557
loc_aim_sites = rng.choice(
552-
snp_pos.shape[0], size=rng.integers(n_sites_low, n_sites_high)
558+
snp_pos.shape[0], size=int(rng.integers(n_sites_low, n_sites_high))
553559
)
554560
loc_aim_sites.sort()
555561
aim_pos = snp_pos[loc_aim_sites]
@@ -731,11 +737,10 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes)
731737
contig_length_bp = contig_sizes[contig]
732738

733739
# Get a random number of CNV alleles ("variants") to simulate.
734-
n_cnv_alleles = rng.integers(1, 5_000)
740+
n_cnv_alleles = int(rng.integers(1, 5_000))
735741

736742
# Produce a set of random start positions for each allele as a sorted list.
737743
allele_start_pos = sorted(rng.integers(1, contig_length_bp, size=n_cnv_alleles))
738-
739744
# Produce a set of random allele lengths for each allele, according to a range.
740745
allele_length_bp_min = 100
741746
allele_length_bp_max = 100_000
@@ -874,7 +879,7 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig
874879
contig_length_bp = contig_sizes[contig]
875880

876881
# Get a random number of CNV variants to simulate.
877-
n_cnv_variants = rng.integers(1, 100)
882+
n_cnv_variants = int(rng.integers(1, 100))
878883

879884
# Produce a set of random start positions for each variant as a sorted list.
880885
variant_start_pos = sorted(
@@ -1010,28 +1015,28 @@ def contigs(self) -> Tuple[str, ...]:
10101015
return tuple(self.config["CONTIGS"])
10111016

10121017
def random_contig(self):
1013-
return choice(self.contigs)
1018+
return rng.choice(self.contigs)
10141019

10151020
def random_transcript_id(self):
10161021
df_transcripts = self.genome_features.query("type == 'mRNA'")
10171022
transcript_ids = [
10181023
gff3_parse_attributes(t)["ID"] for t in df_transcripts.loc[:, "attributes"]
10191024
]
1020-
transcript_id = choice(transcript_ids)
1025+
transcript_id = rng.choice(transcript_ids)
10211026
return transcript_id
10221027

10231028
def random_region_str(self, region_size=None):
10241029
contig = self.random_contig()
10251030
contig_size = self.contig_sizes[contig]
1026-
region_start = randint(1, contig_size)
1031+
region_start = int(rng.integers(1, contig_size))
10271032
if region_size:
10281033
# Ensure we the region span doesn't exceed the contig size.
10291034
if contig_size - region_start < region_size:
10301035
region_start = contig_size - region_size
10311036

10321037
region_end = region_start + region_size
10331038
else:
1034-
region_end = randint(region_start, contig_size)
1039+
region_end = int(rng.integers(region_start, contig_size))
10351040
region = f"{contig}:{region_start:,}-{region_end:,}"
10361041
return region
10371042

@@ -1133,7 +1138,7 @@ def init_public_release_manifest(self):
11331138
manifest = pd.DataFrame(
11341139
{
11351140
"sample_set": ["AG1000G-AO", "AG1000G-BF-A"],
1136-
"sample_count": [randint(10, 50), randint(10, 40)],
1141+
"sample_count": [int(rng.integers(10, 50)), int(rng.integers(10, 40))],
11371142
"study_id": ["AG1000G-AO", "AG1000G-BF-1"],
11381143
"study_url": [
11391144
"https://www.malariagen.net/network/where-we-work/AG1000G-AO",
@@ -1165,7 +1170,7 @@ def init_pre_release_manifest(self):
11651170
"1177-VO-ML-LEHMANN-VMF00004",
11661171
],
11671172
# Make sure we have some gambiae, coluzzii and arabiensis.
1168-
"sample_count": [randint(20, 60)],
1173+
"sample_count": [int(rng.integers(20, 60))],
11691174
"study_id": ["1177-VO-ML-LEHMANN"],
11701175
"study_url": [
11711176
"https://www.malariagen.net/network/where-we-work/1177-VO-ML-LEHMANN"

tests/anoph/test_aim_data.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import itertools
2-
import random
3-
42
import plotly.graph_objects as go
53
import pytest
64
import xarray as xr
75
from numpy.testing import assert_array_equal
8-
6+
import numpy as np
97
from malariagen_data import ag3 as _ag3
108
from malariagen_data.anoph.aim_data import AnophelesAimData
119

10+
rng = np.random.default_rng(seed=42)
11+
1212

1313
@pytest.fixture
1414
def ag3_sim_api(ag3_sim_fixture):
@@ -88,9 +88,9 @@ def test_aim_calls(aims, ag3_sim_api):
8888
all_releases = api.releases
8989
parametrize_sample_sets = [
9090
None,
91-
random.choice(all_sample_sets),
92-
random.sample(all_sample_sets, 2),
93-
random.choice(all_releases),
91+
rng.choice(all_sample_sets),
92+
rng.choice(all_sample_sets, 2, replace=False).tolist(),
93+
rng.choice(all_releases),
9494
]
9595

9696
# Parametrize sample_query.
@@ -179,9 +179,9 @@ def test_plot_aim_heatmap(aims, ag3_sim_api):
179179
all_releases = api.releases
180180
parametrize_sample_sets = [
181181
None,
182-
random.choice(all_sample_sets),
183-
random.sample(all_sample_sets, 2),
184-
random.choice(all_releases),
182+
rng.choice(all_sample_sets),
183+
rng.choice(all_sample_sets, 2, replace=False).tolist(),
184+
rng.choice(all_releases),
185185
]
186186

187187
# Parametrize sample_query.

tests/anoph/test_cnv_data.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import random
2-
31
import bokeh.models
42
import dask.array as da
53
import numpy as np
@@ -139,14 +137,14 @@ def test_open_cnv_coverage_calls(fixture, api: AnophelesCnvData):
139137
# Check with a sample set that should not exist
140138
with pytest.raises(ValueError):
141139
root = api.open_cnv_coverage_calls(
142-
sample_set="foobar", analysis=random.choice(api.coverage_calls_analysis_ids)
140+
sample_set="foobar", analysis=rng.choice(api.coverage_calls_analysis_ids)
143141
)
144142

145143
# Check with an analysis that should not exist
146144
all_sample_sets = api.sample_sets()["sample_set"].to_list()
147145
with pytest.raises(ValueError):
148146
root = api.open_cnv_coverage_calls(
149-
sample_set=random.choice(all_sample_sets), analysis="foobar"
147+
sample_set=rng.choice(all_sample_sets), analysis="foobar"
150148
)
151149

152150
# Check with a sample set and analysis that should not exist
@@ -346,15 +344,15 @@ def test_cnv_hmm(fixture, api: AnophelesCnvData):
346344
all_sample_sets = api.sample_sets()["sample_set"].to_list()
347345
parametrize_sample_sets = [
348346
None,
349-
random.choice(all_sample_sets),
350-
random.sample(all_sample_sets, 2),
351-
random.choice(all_releases),
347+
rng.choice(all_sample_sets),
348+
rng.choice(all_sample_sets, 2, replace=False).tolist(),
349+
rng.choice(all_releases),
352350
]
353351

354352
# Parametrize region.
355353
parametrize_region = [
356354
fixture.random_contig(),
357-
random.sample(api.contigs, 2),
355+
rng.choice(api.contigs, 2, replace=False).tolist(),
358356
fixture.random_region_str(),
359357
]
360358

@@ -424,7 +422,7 @@ def test_cnv_hmm(fixture, api: AnophelesCnvData):
424422
def test_cnv_hmm__max_coverage_variance(fixture, api: AnophelesCnvData):
425423
# Set up test.
426424
all_sample_sets = api.sample_sets()["sample_set"].to_list()
427-
sample_set = random.choice(all_sample_sets)
425+
sample_set = rng.choice(all_sample_sets)
428426
region = fixture.random_contig()
429427

430428
# Parametrize max_coverage_variance.
@@ -468,15 +466,15 @@ def test_cnv_hmm__max_coverage_variance(fixture, api: AnophelesCnvData):
468466
def test_cnv_coverage_calls(fixture, api: AnophelesCnvData):
469467
# Parametrize sample_sets.
470468
all_sample_sets = api.sample_sets()["sample_set"].to_list()
471-
parametrize_sample_sets = random.sample(all_sample_sets, 3)
469+
parametrize_sample_sets = rng.choice(all_sample_sets, 3, replace=False).tolist()
472470

473471
# Parametrize analysis.
474472
parametrize_analysis = api.coverage_calls_analysis_ids
475473

476474
# Parametrize region.
477475
parametrize_region = [
478476
fixture.random_contig(),
479-
random.sample(api.contigs, 2),
477+
rng.choice(api.contigs, 2, replace=False).tolist(),
480478
fixture.random_region_str(),
481479
]
482480

@@ -554,15 +552,15 @@ def test_cnv_discordant_read_calls(fixture, api: AnophelesCnvData):
554552
all_sample_sets = api.sample_sets()["sample_set"].to_list()
555553
parametrize_sample_sets = [
556554
None,
557-
random.choice(all_sample_sets),
558-
random.sample(all_sample_sets, 2),
559-
random.choice(all_releases),
555+
rng.choice(all_sample_sets),
556+
rng.choice(all_sample_sets, 2, replace=False).tolist(),
557+
rng.choice(all_releases),
560558
]
561559

562560
# Parametrize contig.
563561
parametrize_contig = [
564-
random.choice(api.contigs),
565-
random.sample(api.contigs, 2),
562+
rng.choice(api.contigs),
563+
rng.choice(api.contigs, 2, replace=False).tolist(),
566564
]
567565

568566
for sample_sets in parametrize_sample_sets:
@@ -631,13 +629,13 @@ def test_cnv_discordant_read_calls(fixture, api: AnophelesCnvData):
631629
# Check with a contig that should not exist
632630
with pytest.raises(ValueError):
633631
api.cnv_discordant_read_calls(
634-
contig="foobar", sample_sets=random.choice(all_sample_sets)
632+
contig="foobar", sample_sets=rng.choice(all_sample_sets)
635633
)
636634

637635
# Check with a sample set that should not exist
638636
with pytest.raises(ValueError):
639637
api.cnv_discordant_read_calls(
640-
contig=random.choice(api.contigs), sample_sets="foobar"
638+
contig=rng.choice(api.contigs), sample_sets="foobar"
641639
)
642640

643641
# Check with a contig and sample set that should not exist
@@ -809,7 +807,7 @@ def test_cnv_discordant_read_calls__sample_query_options(
809807
def test_plot_cnv_hmm_coverage_track(fixture, api: AnophelesCnvData):
810808
# Set up test.
811809
all_sample_sets = api.sample_sets()["sample_set"].to_list()
812-
sample_set = random.choice(all_sample_sets)
810+
sample_set = rng.choice(all_sample_sets)
813811
region = fixture.random_contig()
814812
df_samples = api.sample_metadata(sample_sets=sample_set)
815813
all_sample_ids = df_samples["sample_id"].values
@@ -916,9 +914,9 @@ def test_plot_cnv_hmm_heatmap_track(fixture, api: AnophelesCnvData):
916914
all_sample_sets = api.sample_sets()["sample_set"].to_list()
917915
parametrize_sample_sets = [
918916
None,
919-
random.choice(all_sample_sets),
920-
random.sample(all_sample_sets, 2),
921-
random.choice(all_releases),
917+
rng.choice(all_sample_sets),
918+
rng.choice(all_sample_sets, 2, replace=False).tolist(),
919+
rng.choice(all_releases),
922920
]
923921

924922
for region in parametrize_region:

0 commit comments

Comments
 (0)