Skip to content

Commit afdb1cf

Browse files
committed
use genomelake's fast one_hot_encode if it's installed
1 parent d813a45 commit afdb1cf

File tree

1 file changed

+15
-20
lines changed

1 file changed

+15
-20
lines changed

kipoiseq/transforms/functional.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,19 @@
22
from __future__ import absolute_import
33
from __future__ import print_function
44

5-
# from genomelake.util import one_hot_encode_sequence # TODO include once the pip install works for genomelake
65
from kipoiseq.utils import DNA
76
from copy import deepcopy
87
import numpy as np
98
from six import string_types
109

1110

11+
try:
12+
# use genomelake's fast one_hot_encode_sequence if it's installed
13+
from genomelake.util import one_hot_encode_sequence
14+
except ImportError:
15+
one_hot_encode_sequence = None
16+
17+
1218
# sequence -> array
1319

1420
def _get_alphabet_dict(alphabet):
@@ -83,11 +89,14 @@ def one_hot_dna(seq, dtype=None):
8389
"""
8490
if not isinstance(seq, str):
8591
raise ValueError("seq needs to be a string")
86-
# TODO - include one you use genomelake again
87-
# out = np.zeros((len(seq), 4), dtype=np.float32)
88-
# one_hot_encode_sequence(seq, out)
89-
# return out
90-
return one_hot(seq, alphabet=DNA, neutral_alphabet=['N'], neutral_value=.25, dtype=dtype)
92+
93+
if one_hot_encode_sequence is not None:
94+
# genomelake's one_hot_encode_sequence was imported successfully
95+
out = np.zeros((len(seq), 4), dtype=np.float32)
96+
one_hot_encode_sequence(seq, out)
97+
return out.astype(dtype)
98+
else:
99+
return one_hot(seq, alphabet=DNA, neutral_alphabet=['N'], neutral_value=.25, dtype=dtype)
91100

92101
# sequence trimming
93102

@@ -187,10 +196,6 @@ def fixed_len(seq, length, anchor="center", value="N"):
187196
return seq
188197

189198

190-
# --------------------------------------------
191-
# TODO - lookup what is this used for
192-
193-
194199
def resize_interval(interval, width, anchor='center'):
195200
"""Resize the Interval. Returns new Interval instance with correct length.
196201
@@ -214,13 +219,3 @@ def resize_interval(interval, width, anchor='center'):
214219
raise Exception("Interval resizing anchor point can only be 'start', 'end' or 'center'")
215220

216221
return interval
217-
218-
219-
# TODO - put all these into classes
220-
# define the keras-type interface with get() - use default kwargs when using it
221-
# def get_string_transforms(trafo):
222-
# if trafo is not None and isinstance(trafo, string_types):
223-
# if trafo in TRANSFORMS:
224-
# trafo = TRANSFORMS[trafo]
225-
# return trafo
226-
# TRANSFORMS = {"onehot_trafo": onehot_transform}

0 commit comments

Comments
 (0)