|
2 | 2 | from __future__ import absolute_import |
3 | 3 | from __future__ import print_function |
4 | 4 |
|
| 5 | +from typing import Any |
5 | 6 | from kipoiseq.utils import DNA |
6 | 7 | from copy import deepcopy |
7 | 8 | import numpy as np |
8 | 9 | from six import string_types |
9 | 10 |
|
10 | | -try: |
11 | | - # use the fast genomelake's one-hot-encode if it's installed |
12 | | - from genomelake.util import one_hot_encode_sequence |
13 | | -except ImportError: |
14 | | - one_hot_encode_sequence = None |
15 | | - |
16 | 11 |
|
17 | 12 | # sequence -> array |
18 | 13 |
|
@@ -119,21 +114,23 @@ def one_hot(seq, alphabet=DNA, neutral_alphabet=['N'], neutral_value=.25, dtype= |
119 | 114 | raise ValueError("seq needs to be a string") |
120 | 115 | return token2one_hot(tokenize(seq, alphabet, neutral_alphabet), len(alphabet), neutral_value, dtype=dtype) |
121 | 116 |
|
122 | | - |
# Reference: https://github.com/deepmind/deepmind-research/blob/fa8c9be4bb0cfd0b8492203eb2a9f31ef995633c/enformer/enformer.py#L306-L318
def one_hot_dna(seq: str,
                alphabet: list = DNA,
                neutral_alphabet: str = 'N',
                neutral_value: Any = 0.25,
                dtype=np.float32) -> np.ndarray:
    """One-hot encode a sequence using a byte-indexed lookup table.

    Every character of `seq` is converted to its ASCII code and used as a row
    index into a table that holds one encoding row per possible byte value.

    Args:
      seq: sequence to encode. Matching is case-sensitive because the lookup
        is done on raw character codes.
      alphabet: single-character symbols; the i-th symbol is encoded as the
        i-th row of the identity matrix.
      neutral_alphabet: single-character symbols (a string or a list of
        characters) whose whole row is filled with `neutral_value`.
      neutral_value: value written into every column for neutral symbols.
      dtype: dtype of the returned array.

    Returns:
      Array of shape `(len(seq), len(alphabet))`. Characters that are in
      neither `alphabet` nor `neutral_alphabet` are encoded as all-zero rows.

    Raises:
      ValueError: if `seq` is not a string, or if `alphabet` contains an
        entry that is not exactly one character long. Non-ASCII characters
        raise UnicodeEncodeError, which is itself a ValueError subclass.
    """
    if not isinstance(seq, str):
        raise ValueError("sequence needs to be a string")

    def to_uint8(string):
        # View the ASCII bytes of the string as an array of integer codes.
        return np.frombuffer(string.encode('ascii'), dtype=np.uint8)

    alphabet_codes = to_uint8(''.join(alphabet))
    if len(alphabet_codes) != len(alphabet):
        # Without this check the identity assignment below fails with an
        # opaque shape-mismatch error.
        raise ValueError("alphabet symbols need to be single characters")

    # One row per possible uint8 value: 0..255 inclusive, i.e. max + 1 rows.
    n_codes = np.iinfo(np.uint8).max + 1
    lookup = np.zeros((n_codes, len(alphabet)), dtype=dtype)
    lookup[alphabet_codes] = np.eye(len(alphabet), dtype=dtype)
    # Assigned after the alphabet rows, so a symbol listed in both alphabets
    # ends up with the neutral encoding.
    lookup[to_uint8(''.join(neutral_alphabet))] = neutral_value
    return lookup[to_uint8(seq)]
137 | 134 |
|
138 | 135 | # sequence trimming |
139 | 136 |
|
|
0 commit comments