1+ from __future__ import annotations
2+
3+ from types import MethodType
14from typing import Dict , List , Optional , Union , cast , overload
25
36import numpy as np
47from numpy .typing import NDArray
8+ from typing_extensions import assert_never
59
6- from .._numba import gufunc_ohe , gufunc_ohe_char_idx , gufunc_translate
7- from .._utils import SeqType , StrSeqType , cast_seqs , check_axes
10+ from .._numba import (
11+ gufunc_complement_bytes ,
12+ gufunc_ohe ,
13+ gufunc_ohe_char_idx ,
14+ gufunc_translate ,
15+ ufunc_comp_dna ,
16+ )
17+ from .._utils import SeqType , StrSeqType , cast_seqs , check_axes , is_dtype
818
919
1020class NucleotideAlphabet :
1121 alphabet : str
1222 """Alphabet excluding ambiguous characters e.g. "N" for DNA."""
1323 complement : str
1424 array : NDArray [np .bytes_ ]
15- complement_map : Dict [str , str ]
16- complement_map_bytes : Dict [bytes , bytes ]
17- str_comp_table : Dict [int , str ]
25+ complement_map : dict [str , str ]
26+ complement_map_bytes : dict [bytes , bytes ]
27+ str_comp_table : dict [int , str ]
1828 bytes_comp_table : bytes
29+ bytes_comp_array : NDArray [np .bytes_ ]
1930
2031 def __init__ (self , alphabet : str , complement : str ) -> None :
2132 """Parse and validate sequence alphabets.
@@ -36,16 +47,15 @@ def __init__(self, alphabet: str, complement: str) -> None:
3647 self .array = cast (
3748 NDArray [np .bytes_ ], np .frombuffer (self .alphabet .encode ("ascii" ), "|S1" )
3849 )
39- self .complement_map : Dict [str , str ] = dict (
40- zip (list (self .alphabet ), list (self .complement ))
41- )
50+ self .complement_map = dict (zip (list (self .alphabet ), list (self .complement )))
4251 self .complement_map_bytes = {
4352 k .encode ("ascii" ): v .encode ("ascii" ) for k , v in self .complement_map .items ()
4453 }
4554 self .str_comp_table = str .maketrans (self .complement_map )
4655 self .bytes_comp_table = bytes .maketrans (
4756 self .alphabet .encode ("ascii" ), self .complement .encode ("ascii" )
4857 )
58+ self .bytes_comp_array = np .frombuffer (self .bytes_comp_table , "S1" )
4959
5060 def __len__ (self ):
5161 return len (self .alphabet )
@@ -109,31 +119,37 @@ def decode_ohe(
109119
110120 return _alphabet [idx ].reshape (shape )
111121
112- def complement_bytes (self , byte_arr : NDArray [np .bytes_ ]) -> NDArray [np .bytes_ ]:
122+ def complement_bytes (
123+ self , byte_arr : NDArray [np .bytes_ ], out : NDArray [np .bytes_ ] | None = None
124+ ) -> NDArray [np .bytes_ ]:
113125 """Get reverse complement of byte (S1) array.
114126
115127 Parameters
116128 ----------
117129 byte_arr : ndarray[bytes]
118130 """
119- # * a vectorized implementation using np.unique or np.char.translate is NOT
120- # * faster even for longer alphabets like IUPAC DNA/RNA. Another optimization to
121- # * try would be using vectorized bit manipulations.
122- out = byte_arr .copy ()
123- for nuc , comp in self .complement_map_bytes .items ():
124- out [byte_arr == nuc ] = comp
125- return out
131+ if out is None :
132+ _out = out
133+ else :
134+ _out = out .view (np .uint8 )
135+ _out = gufunc_complement_bytes (
136+ byte_arr .view (np .uint8 ), self .bytes_comp_array .view (np .uint8 ), _out
137+ )
138+ return _out .view ("S1" )
126139
127140 def rev_comp_byte (
128- self , byte_arr : NDArray [np .bytes_ ], length_axis : int
141+ self ,
142+ byte_arr : NDArray [np .bytes_ ],
143+ length_axis : int ,
144+ out : NDArray [np .bytes_ ] | None = None ,
129145 ) -> NDArray [np .bytes_ ]:
130146 """Get reverse complement of byte (S1) array.
131147
132148 Parameters
133149 ----------
134150 byte_arr : ndarray[bytes]
135151 """
136- out = self .complement_bytes (byte_arr )
152+ out = self .complement_bytes (byte_arr , out )
137153 return np .flip (out , length_axis )
138154
139155 def rev_comp_string (self , string : str ):
@@ -150,27 +166,31 @@ def reverse_complement(
150166 seqs : StrSeqType ,
151167 length_axis : Optional [int ] = None ,
152168 ohe_axis : Optional [int ] = None ,
169+ out : NDArray [np .bytes_ ] | None = None ,
153170 ) -> NDArray [np .bytes_ ]: ...
154171 @overload
155172 def reverse_complement (
156173 self ,
157174 seqs : NDArray [np .uint8 ],
158175 length_axis : Optional [int ] = None ,
159176 ohe_axis : Optional [int ] = None ,
177+ out : NDArray [np .bytes_ ] | None = None ,
160178 ) -> NDArray [np .uint8 ]: ...
161179 @overload
162180 def reverse_complement (
163181 self ,
164182 seqs : SeqType ,
165183 length_axis : Optional [int ] = None ,
166184 ohe_axis : Optional [int ] = None ,
185+ out : NDArray [np .bytes_ ] | None = None ,
167186 ) -> NDArray [Union [np .bytes_ , np .uint8 ]]: ...
168187 def reverse_complement (
169188 self ,
170189 seqs : SeqType ,
171190 length_axis : Optional [int ] = None ,
172191 ohe_axis : Optional [int ] = None ,
173- ) -> NDArray [Union [np .bytes_ , np .uint8 ]]:
192+ out : NDArray [np .bytes_ ] | None = None ,
193+ ) -> NDArray [np .bytes_ | np .uint8 ]:
174194 """Reverse complement a sequence.
175195
176196 Parameters
@@ -190,14 +210,20 @@ def reverse_complement(
190210
191211 seqs = cast_seqs (seqs )
192212
193- if seqs .dtype == np .uint8 : # OHE
213+ if is_dtype (seqs , np .bytes_ ):
214+ if length_axis is None :
215+ length_axis = - 1
216+ return self .rev_comp_byte (seqs , length_axis , out )
217+ elif is_dtype (seqs , np .uint8 ): # OHE
194218 assert length_axis is not None
195219 assert ohe_axis is not None
196- return np .flip (seqs , axis = (length_axis , ohe_axis ))
220+ _out = np .flip (seqs , axis = (length_axis , ohe_axis ))
221+ if out is not None :
222+ out [:] = _out
223+ _out = out
224+ return _out
197225 else :
198- if length_axis is None :
199- length_axis = - 1
200- return self .rev_comp_byte (seqs , length_axis ) # type: ignore
226+ assert_never (seqs ) # type: ignore
201227
202228
203229class AminoAlphabet :
@@ -334,3 +360,25 @@ def decode_ohe(
334360 _alphabet = np .concatenate ([self .aa_array , [unknown_char .encode ("ascii" )]])
335361
336362 return _alphabet [idx ].reshape (shape )
363+
364+
365+ DNA = NucleotideAlphabet ("ACGT" , "TGCA" )
366+
367+
368+ # * Monkey patch DNA instance with a faster complement function using
369+ # * a static, const lookup table. The base method is slower because it uses a
370+ # * dynamic lookup table.
371+ def complement_bytes (
372+ self : NucleotideAlphabet ,
373+ byte_arr : NDArray [np .bytes_ ],
374+ out : NDArray [np .bytes_ ] | None = None ,
375+ ) -> NDArray [np .bytes_ ]:
376+ if out is None :
377+ _out = out
378+ else :
379+ _out = out .view (np .uint8 )
380+ _out = ufunc_comp_dna (byte_arr .view (np .uint8 ), _out ) # type: ignore
381+ return _out .view ("S1" )
382+
383+
384+ DNA .complement_bytes = MethodType (complement_bytes , DNA )
0 commit comments