Skip to content

Commit d682554

Browse files
committed
feat(perf): faster reverse complementing and option to pass pre-alloc output
1 parent cb1afe5 commit d682554

File tree

6 files changed

+204
-60
lines changed

6 files changed

+204
-60
lines changed

python/seqpro/_numba.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Optional, Union, overload
24

35
import numba as nb
@@ -132,3 +134,25 @@ def gufunc_translate(
132134
if (seq_kmers == kmer_keys[i]).all():
133135
res[0] = kmer_values[i] # type: ignore
134136
break
137+
138+
139+
@nb.guvectorize(
140+
["(u1, u1[:], u1[:])"],
141+
"(),(n)->()",
142+
nopython=True,
143+
cache=True,
144+
)
145+
def gufunc_complement_bytes(
146+
seq: NDArray[np.uint8],
147+
complement_map: NDArray[np.uint8],
148+
res: NDArray[np.uint8] | None = None,
149+
) -> NDArray[np.uint8]: # type: ignore
150+
res[0] = complement_map[seq] # type: ignore
151+
152+
153+
_COMP = np.frombuffer(bytes.maketrans(b"ACGT", b"TGCA"), np.uint8)
154+
155+
156+
@nb.vectorize(["u1(u1)"], nopython=True, cache=True)
157+
def ufunc_comp_dna(seq: NDArray[np.uint8]) -> NDArray[np.uint8]:
158+
return _COMP[seq]

python/seqpro/_utils.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

3-
from typing import Optional, TypeVar, Union, cast, overload
3+
from typing import TypeVar, Union, cast, overload
44

55
import numpy as np
66
from numpy.typing import NDArray
7+
from typing_extensions import TypeGuard
78

89
NestedStr = Union[bytes, str, list["NestedStr"]]
910
"""String or nested list of strings"""
@@ -13,20 +14,22 @@
1314

1415
SeqType = Union[NestedStr, NDArray[Union[np.str_, np.object_, np.bytes_, np.uint8]]]
1516

17+
DTYPE = TypeVar("DTYPE", bound=np.generic)
1618

17-
@overload
18-
def cast_seqs(seqs: NDArray[np.uint8]) -> NDArray[np.uint8]: ...
1919

20+
def is_dtype(
21+
obj: object, dtype: DTYPE | np.dtype[DTYPE] | type[DTYPE]
22+
) -> TypeGuard[NDArray[DTYPE]]:
23+
return isinstance(obj, np.ndarray) and np.issubdtype(obj.dtype, dtype)
2024

25+
26+
@overload
27+
def cast_seqs(seqs: NDArray[np.uint8]) -> NDArray[np.uint8]: ...
2128
@overload
2229
def cast_seqs(seqs: StrSeqType) -> NDArray[np.bytes_]: ...
23-
24-
2530
@overload
26-
def cast_seqs(seqs: SeqType) -> NDArray[Union[np.bytes_, np.uint8]]: ...
27-
28-
29-
def cast_seqs(seqs: SeqType) -> NDArray[Union[np.bytes_, np.uint8]]:
31+
def cast_seqs(seqs: SeqType) -> NDArray[np.bytes_ | np.uint8]: ...
32+
def cast_seqs(seqs: SeqType) -> NDArray[np.bytes_ | np.uint8]:
3033
"""Cast any sequence type to be a NumPy array of ASCII characters (or left alone as
3134
8-bit unsigned integers if the input is OHE).
3235
@@ -54,25 +57,25 @@ def cast_seqs(seqs: SeqType) -> NDArray[Union[np.bytes_, np.uint8]]:
5457

5558
def check_axes(
5659
seqs: SeqType,
57-
length_axis: Optional[Union[int, bool]] = None,
58-
ohe_axis: Optional[Union[int, bool]] = None,
60+
length_axis: int | bool | None = None,
61+
ohe_axis: int | bool | None = None,
5962
):
6063
"""Raise errors if length_axis or ohe_axis is missing when they're needed. Pass
6164
False to corresponding axis to not check for it.
6265
6366
- ndarray with itemsize == 1 => length axis required.
6467
- OHE array => length and OHE axis required.
6568
"""
69+
# OHE
70+
if ohe_axis is None and is_dtype(seqs, np.uint8):
71+
raise ValueError("Need an one hot encoding axis to process OHE sequences.")
72+
6673
# bytes or OHE
67-
if length_axis is None and isinstance(seqs, np.ndarray) and seqs.itemsize == 1:
74+
if length_axis is None and is_dtype(seqs, np.bytes_) and seqs.itemsize == 1:
6875
raise ValueError(
6976
"Need a length axis to process an ndarray with itemsize == 1 (S1, u1)."
7077
)
7178

72-
# OHE
73-
if ohe_axis is None and isinstance(seqs, np.ndarray) and seqs.dtype == np.uint8:
74-
raise ValueError("Need an one hot encoding axis to process OHE sequences.")
75-
7679
# length_axis != ohe_axis
7780
if (
7881
isinstance(length_axis, int)
@@ -82,9 +85,6 @@ def check_axes(
8285
raise ValueError("Length and OHE axis must be different.")
8386

8487

85-
DTYPE = TypeVar("DTYPE", bound=np.generic)
86-
87-
8888
def array_slice(a: NDArray[DTYPE], axis: int, slice_: slice) -> NDArray[DTYPE]:
8989
"""Slice an array from a dynamic axis."""
9090
return a[(slice(None),) * (axis % a.ndim) + (slice_,)]

python/seqpro/alphabets/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._alphabets import AminoAlphabet, NucleotideAlphabet
1+
from ._alphabets import DNA, AminoAlphabet, NucleotideAlphabet
22

33
# NOTE the "*" character is termination i.e. STOP codon
44
canonical_codons_to_aas = {
@@ -69,7 +69,6 @@
6969
}
7070

7171

72-
DNA = NucleotideAlphabet(alphabet="ACGT", complement="TGCA")
7372
RNA = NucleotideAlphabet(alphabet="ACGU", complement="UGCA")
7473
AA = AminoAlphabet(*map(list, zip(*canonical_codons_to_aas.items())))
7574

python/seqpro/alphabets/_alphabets.py

Lines changed: 72 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,32 @@
1+
from __future__ import annotations
2+
3+
from types import MethodType
14
from typing import Dict, List, Optional, Union, cast, overload
25

36
import numpy as np
47
from numpy.typing import NDArray
8+
from typing_extensions import assert_never
59

6-
from .._numba import gufunc_ohe, gufunc_ohe_char_idx, gufunc_translate
7-
from .._utils import SeqType, StrSeqType, cast_seqs, check_axes
10+
from .._numba import (
11+
gufunc_complement_bytes,
12+
gufunc_ohe,
13+
gufunc_ohe_char_idx,
14+
gufunc_translate,
15+
ufunc_comp_dna,
16+
)
17+
from .._utils import SeqType, StrSeqType, cast_seqs, check_axes, is_dtype
818

919

1020
class NucleotideAlphabet:
1121
alphabet: str
1222
"""Alphabet excluding ambiguous characters e.g. "N" for DNA."""
1323
complement: str
1424
array: NDArray[np.bytes_]
15-
complement_map: Dict[str, str]
16-
complement_map_bytes: Dict[bytes, bytes]
17-
str_comp_table: Dict[int, str]
25+
complement_map: dict[str, str]
26+
complement_map_bytes: dict[bytes, bytes]
27+
str_comp_table: dict[int, str]
1828
bytes_comp_table: bytes
29+
bytes_comp_array: NDArray[np.bytes_]
1930

2031
def __init__(self, alphabet: str, complement: str) -> None:
2132
"""Parse and validate sequence alphabets.
@@ -36,16 +47,15 @@ def __init__(self, alphabet: str, complement: str) -> None:
3647
self.array = cast(
3748
NDArray[np.bytes_], np.frombuffer(self.alphabet.encode("ascii"), "|S1")
3849
)
39-
self.complement_map: Dict[str, str] = dict(
40-
zip(list(self.alphabet), list(self.complement))
41-
)
50+
self.complement_map = dict(zip(list(self.alphabet), list(self.complement)))
4251
self.complement_map_bytes = {
4352
k.encode("ascii"): v.encode("ascii") for k, v in self.complement_map.items()
4453
}
4554
self.str_comp_table = str.maketrans(self.complement_map)
4655
self.bytes_comp_table = bytes.maketrans(
4756
self.alphabet.encode("ascii"), self.complement.encode("ascii")
4857
)
58+
self.bytes_comp_array = np.frombuffer(self.bytes_comp_table, "S1")
4959

5060
def __len__(self):
5161
return len(self.alphabet)
@@ -109,31 +119,37 @@ def decode_ohe(
109119

110120
return _alphabet[idx].reshape(shape)
111121

112-
def complement_bytes(self, byte_arr: NDArray[np.bytes_]) -> NDArray[np.bytes_]:
122+
def complement_bytes(
123+
self, byte_arr: NDArray[np.bytes_], out: NDArray[np.bytes_] | None = None
124+
) -> NDArray[np.bytes_]:
113125
"""Get reverse complement of byte (S1) array.
114126
115127
Parameters
116128
----------
117129
byte_arr : ndarray[bytes]
118130
"""
119-
# * a vectorized implementation using np.unique or np.char.translate is NOT
120-
# * faster even for longer alphabets like IUPAC DNA/RNA. Another optimization to
121-
# * try would be using vectorized bit manipulations.
122-
out = byte_arr.copy()
123-
for nuc, comp in self.complement_map_bytes.items():
124-
out[byte_arr == nuc] = comp
125-
return out
131+
if out is None:
132+
_out = out
133+
else:
134+
_out = out.view(np.uint8)
135+
_out = gufunc_complement_bytes(
136+
byte_arr.view(np.uint8), self.bytes_comp_array.view(np.uint8), _out
137+
)
138+
return _out.view("S1")
126139

127140
def rev_comp_byte(
128-
self, byte_arr: NDArray[np.bytes_], length_axis: int
141+
self,
142+
byte_arr: NDArray[np.bytes_],
143+
length_axis: int,
144+
out: NDArray[np.bytes_] | None = None,
129145
) -> NDArray[np.bytes_]:
130146
"""Get reverse complement of byte (S1) array.
131147
132148
Parameters
133149
----------
134150
byte_arr : ndarray[bytes]
135151
"""
136-
out = self.complement_bytes(byte_arr)
152+
out = self.complement_bytes(byte_arr, out)
137153
return np.flip(out, length_axis)
138154

139155
def rev_comp_string(self, string: str):
@@ -150,27 +166,31 @@ def reverse_complement(
150166
seqs: StrSeqType,
151167
length_axis: Optional[int] = None,
152168
ohe_axis: Optional[int] = None,
169+
out: NDArray[np.bytes_] | None = None,
153170
) -> NDArray[np.bytes_]: ...
154171
@overload
155172
def reverse_complement(
156173
self,
157174
seqs: NDArray[np.uint8],
158175
length_axis: Optional[int] = None,
159176
ohe_axis: Optional[int] = None,
177+
out: NDArray[np.bytes_] | None = None,
160178
) -> NDArray[np.uint8]: ...
161179
@overload
162180
def reverse_complement(
163181
self,
164182
seqs: SeqType,
165183
length_axis: Optional[int] = None,
166184
ohe_axis: Optional[int] = None,
185+
out: NDArray[np.bytes_] | None = None,
167186
) -> NDArray[Union[np.bytes_, np.uint8]]: ...
168187
def reverse_complement(
169188
self,
170189
seqs: SeqType,
171190
length_axis: Optional[int] = None,
172191
ohe_axis: Optional[int] = None,
173-
) -> NDArray[Union[np.bytes_, np.uint8]]:
192+
out: NDArray[np.bytes_] | None = None,
193+
) -> NDArray[np.bytes_ | np.uint8]:
174194
"""Reverse complement a sequence.
175195
176196
Parameters
@@ -190,14 +210,20 @@ def reverse_complement(
190210

191211
seqs = cast_seqs(seqs)
192212

193-
if seqs.dtype == np.uint8: # OHE
213+
if is_dtype(seqs, np.bytes_):
214+
if length_axis is None:
215+
length_axis = -1
216+
return self.rev_comp_byte(seqs, length_axis, out)
217+
elif is_dtype(seqs, np.uint8): # OHE
194218
assert length_axis is not None
195219
assert ohe_axis is not None
196-
return np.flip(seqs, axis=(length_axis, ohe_axis))
220+
_out = np.flip(seqs, axis=(length_axis, ohe_axis))
221+
if out is not None:
222+
out[:] = _out
223+
_out = out
224+
return _out
197225
else:
198-
if length_axis is None:
199-
length_axis = -1
200-
return self.rev_comp_byte(seqs, length_axis) # type: ignore
226+
assert_never(seqs) # type: ignore
201227

202228

203229
class AminoAlphabet:
@@ -334,3 +360,25 @@ def decode_ohe(
334360
_alphabet = np.concatenate([self.aa_array, [unknown_char.encode("ascii")]])
335361

336362
return _alphabet[idx].reshape(shape)
363+
364+
365+
DNA = NucleotideAlphabet("ACGT", "TGCA")
366+
367+
368+
# * Monkey patch DNA instance with a faster complement function using
369+
# * a static, const lookup table. The base method is slower because it uses a
370+
# * dynamic lookup table.
371+
def complement_bytes(
372+
self: NucleotideAlphabet,
373+
byte_arr: NDArray[np.bytes_],
374+
out: NDArray[np.bytes_] | None = None,
375+
) -> NDArray[np.bytes_]:
376+
if out is None:
377+
_out = out
378+
else:
379+
_out = out.view(np.uint8)
380+
_out = ufunc_comp_dna(byte_arr.view(np.uint8), _out) # type: ignore
381+
return _out.view("S1")
382+
383+
384+
DNA.complement_bytes = MethodType(complement_bytes, DNA)

python/seqpro/rag/_array.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(
6161
if isinstance(data, RagParts):
6262
content = _parts_to_content(data)
6363
else:
64-
content = _with_ragged(data, highlevel=False)
64+
content = _as_ragged(data, highlevel=False)
6565
super().__init__(content, behavior=deepcopy(ak.behavior))
6666
self._parts = unbox(self)
6767
type_parts: list[str] = []
@@ -232,7 +232,7 @@ def __getitem__(self, where):
232232
if _n_var(arr) == 1:
233233
return type(self)(arr)
234234
else:
235-
return _without_ragged(arr)
235+
return _as_ak(arr)
236236
else:
237237
return arr
238238

@@ -293,7 +293,7 @@ def reshape(self, *shape: int | None | tuple[int | None, ...]) -> Self:
293293

294294
def to_ak(self):
295295
"""Convert to an Awkward array."""
296-
arr = _without_ragged(self)
296+
arr = _as_ak(self)
297297
arr.behavior = None
298298
return arr
299299

@@ -331,12 +331,12 @@ def _n_var(arr: ak.Array) -> int:
331331

332332

333333
@overload
334-
def _with_ragged(
334+
def _as_ragged(
335335
arr: ak.Array | Content, highlevel: Literal[True] = True
336336
) -> ak.Array: ...
337337
@overload
338-
def _with_ragged(arr: ak.Array | Content, highlevel: Literal[False]) -> Content: ...
339-
def _with_ragged(arr: ak.Array | Content, highlevel: bool = True) -> ak.Array | Content:
338+
def _as_ragged(arr: ak.Array | Content, highlevel: Literal[False]) -> Content: ...
339+
def _as_ragged(arr: ak.Array | Content, highlevel: bool = True) -> ak.Array | Content:
340340
def fn(layout: Content, **kwargs):
341341
if isinstance(layout, (ListArray, ListOffsetArray)):
342342
return ak.with_parameter(
@@ -350,16 +350,12 @@ def fn(layout: Content, **kwargs):
350350

351351

352352
@overload
353-
def _without_ragged(
353+
def _as_ak(
354354
arr: ak.Array | Ragged[DTYPE], highlevel: Literal[True] = True
355355
) -> ak.Array: ...
356356
@overload
357-
def _without_ragged(
358-
arr: ak.Array | Ragged[DTYPE], highlevel: Literal[False]
359-
) -> Content: ...
360-
def _without_ragged(
361-
arr: ak.Array | Ragged[DTYPE], highlevel: bool = True
362-
) -> ak.Array | Content:
357+
def _as_ak(arr: ak.Array | Ragged[DTYPE], highlevel: Literal[False]) -> Content: ...
358+
def _as_ak(arr: ak.Array | Ragged[DTYPE], highlevel: bool = True) -> ak.Array | Content:
363359
def fn(layout, **kwargs):
364360
if isinstance(layout, (ListArray, ListOffsetArray)):
365361
return ak.with_parameter(layout, "__list__", None, highlevel=False)

0 commit comments

Comments (0)