Skip to content

Commit a1cdee3

Browse files
Modernize NumPy random functions, fix mypy errors for issue#756
1 parent 39e5d75 commit a1cdee3

12 files changed

+170
-91
lines changed

malariagen_data/anoph/sample_metadata.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
import io
22
from itertools import cycle
3-
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
3+
from typing import (
4+
Any,
5+
Callable,
6+
Dict,
7+
List,
8+
Mapping,
9+
Optional,
10+
Sequence,
11+
Tuple,
12+
Union,
13+
cast,
14+
)
415

516
import ipyleaflet # type: ignore
617
import numpy as np
@@ -11,6 +22,7 @@
1122
from ..util import check_types
1223
from . import base_params, map_params, plotly_params
1324
from .base import AnophelesBase
25+
from numpy.typing import NDArray
1426

1527

1628
class AnophelesSampleMetadata(AnophelesBase):
@@ -891,8 +903,11 @@ def _prep_sample_selection_cache_params(
891903
# integer indices instead.
892904
df_samples = self.sample_metadata(sample_sets=sample_sets)
893905
sample_query_options = sample_query_options or {}
894-
loc_samples = df_samples.eval(sample_query, **sample_query_options).values
895-
sample_indices = np.nonzero(loc_samples)[0].tolist()
906+
loc_samples = cast(
907+
NDArray[Any],
908+
df_samples.eval(sample_query, **sample_query_options).values,
909+
)
910+
sample_indices = cast(List[int], np.nonzero(loc_samples)[0].tolist())
896911

897912
return sample_sets, sample_indices
898913

malariagen_data/anoph/snp_frq.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from numpydoc_decorator import doc # type: ignore
88
import xarray as xr
99
import numba # type: ignore
10-
1110
from .. import veff
1211
from ..util import (
1312
check_types,
@@ -576,8 +575,8 @@ def snp_allele_frequencies_advanced(
576575
raise ValueError("No SNPs remaining after dropping invariant SNPs.")
577576

578577
df_variants = df_variants.loc[loc_variant].reset_index(drop=True)
579-
count = np.compress(loc_variant, count, axis=0)
580-
nobs = np.compress(loc_variant, nobs, axis=0)
578+
count = np.compress(loc_variant, count, axis=0).reshape(-1, count.shape[1])
579+
nobs = np.compress(loc_variant, nobs, axis=0).reshape(-1, nobs.shape[1])
581580
frequency = np.compress(loc_variant, frequency, axis=0)
582581

583582
# Set up variant effect annotator.

0 commit comments

Comments
 (0)