Skip to content

Commit 2bb23ea

Browse files
committed
fix parse_dtype
1 parent afdb1cf commit 2bb23ea

File tree

11 files changed

+302
-416
lines changed

11 files changed

+302
-416
lines changed

kipoiseq/dataloaders/sequence.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class BedDataset(object):
4040
bed_columns: number of columns corresponding to the bed file. All the columns
4141
after that will be parsed as targets
4242
num_chr: if specified, 'chr' in the chromosome name will be dropped
43-
label_dtype: specific data type for labels
43+
label_dtype: specific data type for labels, Example: `float` or `np.float32`
4444
ambiguous_mask: if specified, rows containing only ambiguous_mask values will be skipped
4545
incl_chromosomes: exclusive list of chromosome names to include in the final dataset.
4646
if not None, only these will be present in the dataset
@@ -153,7 +153,7 @@ class IntervalSeqStringDl(Dataset):
153153
num_chr_fasta:
154154
doc: True, the dataloader will make sure that the chromosomes don't start with chr.
155155
label_dtype:
156-
doc: None, datatype of the task labels taken from the intervals_file. Allowed - string', 'int', 'float', 'bool'
156+
doc: None, datatype of the task labels taken from the intervals_file. Example - str, int, float, np.float32
157157
auto_resize_len:
158158
doc: None, required sequence length.
159159
# max_seq_len:
@@ -280,7 +280,7 @@ class IntervalSeqDl(Dataset):
280280
num_chr_fasta:
281281
doc: True, the dataloader will make sure that the chromosomes don't start with chr.
282282
label_dtype:
283-
doc: None, datatype of the task labels taken from the intervals_file. Allowed - string', 'int', 'float', 'bool'
283+
doc: 'None, datatype of the task labels taken from the intervals_file. Example: str, int, float, np.float32'
284284
auto_resize_len:
285285
doc: None, required sequence length.
286286
# use_strand:
@@ -294,7 +294,7 @@ class IntervalSeqDl(Dataset):
294294
alphabet to use for the one-hot encoding. This defines the order of the one-hot encoding.
295295
Can either be a list or a string: 'ACGT' or ['A', 'C', 'G', 'T']. Default: 'ACGT'
296296
dtype:
297-
doc: defines the numpy dtype of the returned array.
297+
doc: 'defines the numpy dtype of the returned array. Example: int, np.int32, np.float32, float'
298298
ignore_targets:
299299
doc: if True, don't return any target variables
300300

kipoiseq/transforms/transforms.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66
from kipoiseq.transforms import functional as F
7-
from kipoiseq.utils import DNA, parse_alphabet, parse_type
7+
from kipoiseq.utils import DNA, parse_alphabet, parse_dtype
88

99

1010
class Compose(object):
@@ -146,19 +146,20 @@ def __init__(self,
146146
alphabet_axis=1,
147147
dummy_axis=None):
148148
# make sure the alphabet axis and the dummy axis are valid:
149+
if dummy_axis is not None:
150+
if alphabet_axis == dummy_axis:
151+
raise ValueError("dummy_axis can't be the same as dummy_axis")
152+
if not (dummy_axis >= 0 and dummy_axis <= 2):
153+
raise ValueError("dummy_axis can be either 0,1 or 2")
149154
assert alphabet_axis >= 0 and (alphabet_axis < 2 or (alphabet_axis <= 2 and dummy_axis is not None))
150-
assert dummy_axis is None or (dummy_axis >= 0 and dummy_axis <= 2 and alphabet_axis != dummy_axis)
151155

152156
self.alphabet_axis = alphabet_axis
153157
self.dummy_axis = dummy_axis
154158
self.alphabet = parse_alphabet(alphabet)
155-
self.dtype = parse_type(dtype)
159+
self.dtype = parse_dtype(dtype)
156160
self.neutral_alphabet = neutral_alphabet
157161
self.neutral_value = neutral_value
158162

159-
if dummy_axis is not None and alphabet_axis == dummy_axis:
160-
raise ValueError("dummy_axis can't be the same as dummy_axis")
161-
162163
# set the transform parameters correctly
163164
if dummy_axis is not None and dummy_axis < 2:
164165
# dummy axis is added somewhere in the middle, so the alphabet axis is at the end now

kipoiseq/utils.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,18 @@ def to_scalar(obj):
2222
return obj
2323

2424

25-
def parse_dtype(dtype):
26-
dtypes = {'int': int, 'string': str, 'float': float, 'bool': bool}
27-
if dtype is None:
28-
return None
29-
if dtype in list(dtypes.values()):
30-
return dtype
31-
if dtype not in dtypes:
32-
raise Exception("Datatype '{0}' not recognized. Allowed are: {1}".format(dtype, str(list(dtypes.keys()))))
33-
return dtypes[dtype]
34-
35-
3625
def parse_alphabet(alphabet):
3726
if isinstance(alphabet, str):
3827
return list(alphabet)
3928
else:
4029
return alphabet
4130

4231

43-
def parse_type(dtype):
32+
def parse_dtype(dtype):
4433
if isinstance(dtype, string_types):
45-
if dtype in dir(np):
46-
return getattr(np, dtype)
34+
try:
35+
return eval(dtype)
36+
except Exception as e:
37+
raise ValueError("Unable to parse dtype: {}. \nException: {}".format(dtype, e))
4738
else:
4839
return dtype

notebooks/getting-started.ipynb

Lines changed: 158 additions & 303 deletions
Large diffs are not rendered by default.
Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pybedtools import Interval
66
from kipoi.utils import override_default_kwargs
77
from kipoiseq.transforms.functional import one_hot_dna
8-
from kipoiseq.dataloaders.sequence import IntervalSeqStringDl, IntervalSeqDl, parse_dtype, BedDataset
8+
from kipoiseq.dataloaders.sequence import IntervalSeqStringDl, IntervalSeqDl, BedDataset
99

1010

1111
@pytest.fixture
@@ -32,15 +32,6 @@ def test_min_props():
3232
assert all([el in props for el in min_set_props])
3333

3434

35-
def test_parse_dtype():
36-
dtypes = {'int': int, 'string': str, 'float': float, 'bool': bool}
37-
assert all([parse_dtype(dt) == dtypes[dt] for dt in dtypes.keys()])
38-
assert all([parse_dtype(dt) == dt for dt in dtypes.values()])
39-
with pytest.raises(Exception):
40-
parse_dtype("int8")
41-
assert parse_dtype(None) is None
42-
43-
4435
def test_fasta_based_dataset(intervals_file, fasta_file):
4536
# just test the functionality
4637
dl = IntervalSeqStringDl(intervals_file, fasta_file)
@@ -52,7 +43,7 @@ def test_fasta_based_dataset(intervals_file, fasta_file):
5243
# with pytest.raises(Exception):
5344
# dl[0]
5445

55-
dl = IntervalSeqStringDl(intervals_file, fasta_file, label_dtype="string")
46+
dl = IntervalSeqStringDl(intervals_file, fasta_file, label_dtype="str")
5647
ret_val = dl[0]
5748
assert isinstance(ret_val['targets'][0], np.str_)
5849
dl = IntervalSeqStringDl(intervals_file, fasta_file, label_dtype="int")
@@ -74,6 +65,44 @@ def test_seq_dataset(intervals_file, fasta_file):
7465
assert ret_val["inputs"].shape == (2, 4)
7566

7667

68+
@pytest.fixture
69+
def example_kwargs():
70+
return IntervalSeqDl.example_kwargs
71+
72+
73+
@pytest.mark.parametrize("alphabet_axis", list(range(0, 4)))
74+
@pytest.mark.parametrize("dummy_axis", [None] + list(range(0, 4)))
75+
def test_seq_dataset_reshape(alphabet_axis, dummy_axis, example_kwargs):
76+
seq_len, alphabet_len = 3, 4
77+
78+
kwargs = example_kwargs
79+
kwargs['auto_resize_len'] = seq_len
80+
kwargs['alphabet_axis'] = alphabet_axis
81+
kwargs['dummy_axis'] = dummy_axis
82+
83+
dummy_axis_int = dummy_axis
84+
if dummy_axis is None:
85+
dummy_axis_int = -2
86+
87+
if (alphabet_axis == dummy_axis_int) or (alphabet_axis == -1) or (dummy_axis_int == -1) or \
88+
(alphabet_axis >= 3) or (dummy_axis_int >= 3) or ((alphabet_axis >= 2) and (dummy_axis is None)):
89+
with pytest.raises(Exception):
90+
seq_dataset = IntervalSeqDl(**kwargs)
91+
return None
92+
93+
seq_dataset = IntervalSeqDl(**kwargs)
94+
95+
# test the single sample works
96+
reshaped = seq_dataset[0]['inputs']
97+
for i in range(len(reshaped.shape)):
98+
if i == dummy_axis:
99+
assert reshaped.shape[i] == 1
100+
elif i == alphabet_axis:
101+
assert reshaped.shape[i] == alphabet_len
102+
else:
103+
assert reshaped.shape[i] == seq_len
104+
105+
77106
# download example files
78107
@pytest.mark.parametrize("cls", [IntervalSeqStringDl, IntervalSeqDl])
79108
def test_examples_exist(cls):
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,16 @@
11
import pytest
22
import numpy as np
33
import copy
4-
from kipoiseq.transforms.functional import resize_interval
5-
from kipoiseq.transforms.transforms import SplitSplicingSeq, ReorderedOneHot
4+
from kipoiseq.transforms.transforms import Compose, OneHot, SplitSplicingSeq, ReorderedOneHot
65
from kipoiseq.utils import DNA
76
from pybedtools import Interval
87

98

109
# --------------------------------------------
11-
12-
@pytest.mark.parametrize("anchor", ['start', 'end', 'center'])
13-
@pytest.mark.parametrize("ilen", [3, 4])
14-
def test_resize_interval(anchor, ilen):
15-
import pybedtools
16-
dummy_start, dummy_end = 10, 20
17-
dummy_centre = int((dummy_start + dummy_end) / 2)
18-
19-
dummy_inter = pybedtools.create_interval_from_list(['chr2', dummy_start, dummy_end, 'intname'])
20-
ret_inter = resize_interval(dummy_inter, ilen, anchor)
21-
22-
# the original interval was left intact
23-
assert dummy_inter.chrom == 'chr2'
24-
assert dummy_inter.start == dummy_start
25-
assert dummy_inter.end == dummy_end
26-
assert dummy_inter.name == 'intname'
27-
28-
# metadata kept
29-
assert ret_inter.chrom == dummy_inter.chrom
30-
assert ret_inter.name == 'intname'
31-
32-
# desired output width
33-
assert ret_inter.length == ilen
34-
35-
# correct anchor point
36-
if anchor == "start":
37-
assert ret_inter.start == dummy_start
38-
elif anchor == "end":
39-
assert ret_inter.end == dummy_end
40-
elif anchor == "centre":
41-
assert int((ret_inter.start + ret_inter.end) / 2) == dummy_centre
10+
def test_compose():
11+
c = Compose([OneHot()])
12+
print(str(c))
13+
assert c("ACGT").shape == (4, 4)
4214

4315

4416
def test_ReorderedOneHot():
@@ -60,10 +32,10 @@ def test_ReorderedOneHot():
6032
assert out.shape == tr.get_output_shape(seqlen)
6133
assert out.shape == result
6234

63-
with pytest.raises(Exception):
35+
with pytest.raises(ValueError):
6436
ReorderedOneHot(alphabet_axis=1, dummy_axis=1)
6537

66-
with pytest.raises(Exception):
38+
with pytest.raises(ValueError):
6739
ReorderedOneHot(dummy_axis=1)
6840

6941

@@ -86,3 +58,7 @@ def test_SplitSplicingSeq():
8658
assert splited['exon'] == 'GTAGTAGA'
8759
assert splited['donor'] == 'AGAGT'
8860
assert splited['intron3prime'] == 'CC'
61+
62+
63+
def test_ResizeInterval():
64+
pass

tests/test_0_transforms_functional.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
2-
from kipoiseq.transforms.functional import tokenize, token2one_hot, one_hot, one_hot_dna, pad, trim, fixed_len
2+
from kipoiseq.transforms.functional import resize_interval, tokenize, token2one_hot, one_hot, one_hot_dna, pad, trim, fixed_len
3+
from kipoiseq.transforms.transforms import ResizeInterval
34
from kipoiseq.utils import DNA
45
import numpy as np
56

@@ -81,3 +82,61 @@ def test_pad_sequences():
8182

8283
assert fixed_len(seq, 10, anchor="start", value="N") == seq
8384
assert fixed_len(seq, 10, anchor="end", value="N") == 'CTTACTCAGA'
85+
86+
87+
@pytest.mark.parametrize("anchor", ['start', 'end', 'center'])
88+
@pytest.mark.parametrize("ilen", [3, 4])
89+
def test_resize_interval(anchor, ilen):
90+
import pybedtools
91+
dummy_start, dummy_end = 10, 20
92+
dummy_center = int((dummy_start + dummy_end) / 2)
93+
94+
dummy_inter = pybedtools.create_interval_from_list(['chr2', dummy_start, dummy_end, 'intname'])
95+
ret_inter = resize_interval(dummy_inter, ilen, anchor)
96+
97+
# the original interval was left intact
98+
assert dummy_inter.chrom == 'chr2'
99+
assert dummy_inter.start == dummy_start
100+
assert dummy_inter.end == dummy_end
101+
assert dummy_inter.name == 'intname'
102+
103+
# metadata kept
104+
assert ret_inter.chrom == dummy_inter.chrom
105+
assert ret_inter.name == 'intname'
106+
107+
# desired output width
108+
assert ret_inter.length == ilen
109+
110+
# correct anchor point
111+
if anchor == "start":
112+
assert ret_inter.start == dummy_start
113+
elif anchor == "end":
114+
assert ret_inter.end == dummy_end
115+
elif anchor == "center":
116+
assert int((ret_inter.start + ret_inter.end) / 2) == dummy_center
117+
118+
119+
def test_ResizeInterval():
120+
"""Same test as before
121+
"""
122+
import pybedtools
123+
dummy_start, dummy_end = 10, 20
124+
dummy_center = int((dummy_start + dummy_end) / 2)
125+
ilen = 4
126+
dummy_inter = pybedtools.create_interval_from_list(['chr2', dummy_start, dummy_end, 'intname'])
127+
ri = ResizeInterval(ilen, 'center')
128+
ret_inter = ri(dummy_inter)
129+
assert int((ret_inter.start + ret_inter.end) / 2) == dummy_center
130+
131+
# the original interval was left intact
132+
assert dummy_inter.chrom == 'chr2'
133+
assert dummy_inter.start == dummy_start
134+
assert dummy_inter.end == dummy_end
135+
assert dummy_inter.name == 'intname'
136+
137+
# metadata kept
138+
assert ret_inter.chrom == dummy_inter.chrom
139+
assert ret_inter.name == 'intname'
140+
141+
# desired output width
142+
assert ret_inter.length == ilen

tests/test_2_datasets.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

0 commit comments

Comments
 (0)