Skip to content

Commit 51aaecb

Browse files
authored
Merge branch 'master' into fix_remove_use_strand
2 parents 96efedc + 9648ea8 commit 51aaecb

File tree

4 files changed

+112
-31
lines changed

4 files changed

+112
-31
lines changed

.pep8speaks.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pycodestyle:
2+
max-line-length: 140 # Default is 79 in PEP8
3+
ignore: # Errors and warnings to ignore
4+
- E111
5+
- E731

kipoiseq/datasets/sequence.py

Lines changed: 73 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from kipoi.plugin import is_installed
99
from kipoi.data import Dataset, kipoi_dataloader
1010
from kipoi.specs import Author, Dependencies
11+
from kipoi.utils import default_kwargs
1112
from six import string_types
1213

1314

@@ -72,6 +73,7 @@ class BedDataset(object):
7273
incl_chromosomes: exclusive list of chromosome names to include in the final dataset.
7374
if not None, only these will be present in the dataset
7475
excl_chromosomes: list of chromosome names to omit from the dataset.
76+
ignore_targets: if True, target variables are ignored
7577
"""
7678

7779
# bed types accorging to
@@ -95,14 +97,16 @@ def __init__(self, tsv_file,
9597
num_chr=False,
9698
ambiguous_mask=None,
9799
incl_chromosomes=None,
98-
excl_chromosomes=None):
100+
excl_chromosomes=None,
101+
ignore_targets=False):
99102
self.tsv_file = tsv_file
100103
self.bed_columns = bed_columns
101104
self.num_chr = num_chr
102105
self.label_dtype = label_dtype
103106
self.ambiguous_mask = ambiguous_mask
104107
self.incl_chromosomes = incl_chromosomes
105108
self.excl_chromosomes = excl_chromosomes
109+
self.ignore_targets = ignore_targets
106110

107111
df_peek = pd.read_table(self.tsv_file,
108112
header=None,
@@ -141,7 +145,7 @@ def __getitem__(self, idx):
141145
row = self.df.iloc[idx]
142146
interval = pybedtools.create_interval_from_list([to_scalar(x) for x in row.iloc[:self.bed_columns]])
143147

144-
if self.n_tasks == 0:
148+
if self.ignore_targets or self.n_tasks == 0:
145149
labels = {}
146150
else:
147151
labels = row.iloc[self.bed_columns:].values.astype(self.label_dtype)
@@ -185,6 +189,8 @@ class SeqStringDataset(Dataset):
185189
# doc: reverse-complement fasta sequence if bed file defines negative strand
186190
force_upper:
187191
doc: Force uppercase output of sequences
192+
ignore_targets:
193+
doc: if True, don't return any target variables
188194
output_schema:
189195
inputs:
190196
name: seq
@@ -213,7 +219,8 @@ def __init__(self,
213219
auto_resize_len=None,
214220
# max_seq_len=None,
215221
# use_strand=False,
216-
force_upper=True):
222+
force_upper=True,
223+
ignore_targets=False):
217224

218225
self.num_chr_fasta = num_chr_fasta
219226
self.intervals_file = intervals_file
@@ -232,7 +239,8 @@ def __init__(self,
232239
self.bed = BedDataset(self.intervals_file,
233240
num_chr=self.num_chr_fasta,
234241
bed_columns=3,
235-
label_dtype=parse_dtype(label_dtype))
242+
label_dtype=parse_dtype(label_dtype),
243+
ignore_targets=ignore_targets)
236244
self.fasta_extractors = None
237245

238246
def __len__(self):
@@ -265,15 +273,12 @@ def __getitem__(self, idx):
265273
}
266274

267275
@classmethod
268-
def default_shape(cls):
269-
# correct the output schema - TODO - required?
270-
# self.output_schema_params = deepcopy(self.output_schema_params)
271-
# self.output_schema_params['inputs_shape'] = (1,)
272-
# if self.bed.n_tasks != 0:
273-
# self.output_schema_params['targets_shape'] = (self.bed.n_tasks,)
274-
275-
# self.output_schema = get_seq_dataset_output_schema(**self.output_schema_params)
276-
pass
276+
def get_output_schema(cls):
277+
kwargs = default_kwargs(cls)
278+
ignore_targets = kwargs['ignore_targets']
279+
if ignore_targets:
280+
cls.output_schema.targets = None
281+
return cls.output_schema
277282

278283

279284
# TODO - check lzamparo's dataloader:
@@ -320,7 +325,9 @@ class SeqDataset(Dataset):
320325
alphabet to use for the one-hot encoding. This defines the order of the one-hot encoding.
321326
Can either be a list or a string: 'DNA', 'RNA', 'AMINO_ACIDS'.
322327
dtype:
323-
doc: defines the numpy dtype of the returned array.
328+
doc: defines the numpy dtype of the returned array.
329+
ignore_targets:
330+
doc: if True, don't return any target variables
324331
325332
output_schema:
326333
inputs:
@@ -353,7 +360,9 @@ def __init__(self,
353360
alphabet_axis=1,
354361
dummy_axis=None,
355362
alphabet="ACGT",
363+
ignore_targets=False,
356364
dtype=None):
365+
# TODO - add disable target loading to manage the Basenji case
357366

358367
# make sure the alphabet axis and the dummy axis are valid:
359368
assert alphabet_axis >= 0 and (alphabet_axis < 2 or (alphabet_axis <= 2 and dummy_axis is not None))
@@ -369,13 +378,18 @@ def __init__(self,
369378
self.seq_string_dataset = SeqStringDataset(intervals_file, fasta_file, num_chr_fasta=num_chr_fasta,
370379
label_dtype=label_dtype, auto_resize_len=auto_resize_len,
371380
# use_strand=use_strand,
372-
force_upper=True)
381+
ignore_targets=ignore_targets)
382+
383+
if dummy_axis is not None and alphabet_axis == dummy_axis:
384+
raise ValueError("dummy_axis can't be the same as dummy_axis")
373385

374386
# set the transform parameters correctly
375-
existing_alphabet_axis = 1
376387
if dummy_axis is not None and dummy_axis < 2:
377388
# dummy axis is added somewhere in the middle, so the alphabet axis is at the end now
378389
existing_alphabet_axis = 2
390+
else:
391+
# alphabet axis stayed the same
392+
existing_alphabet_axis = 1
379393

380394
# check if no swapping needed
381395
if existing_alphabet_axis == self.alphabet_axis:
@@ -396,18 +410,47 @@ def __getitem__(self, idx):
396410
ret['inputs'] = self.input_tranform(str(ret["inputs"]))
397411
return ret
398412

399-
# TODO - compute the output shape based on the default value of parameters
400-
# - executed in kipoi_dataloader
401-
# TODO - how to specify the shape properly when using differnet default parameters?
402-
# - example: Basset dataloader
403413
@classmethod
404-
def default_shape(cls):
405-
# setup output schema
406-
# self.output_schema_params = deepcopy(self.output_schema_params)
407-
408-
# self.output_schema_params['inputs_shape'] = get_onehot_shape(self.alphabet_axis, self.dummy_axis,
409-
# self.auto_resize_len, self.alphabet)
410-
# if self.bed.n_tasks != 0:
411-
# self.output_schema_params['targets_shape'] = (self.bed.n_tasks,)
412-
# self.output_schema = get_seq_dataset_output_schema(**self.output_schema_params)
413-
pass
414+
def get_output_schema(cls):
415+
"""Get the output schema. Overrides the default `cls.output_schema`
416+
"""
417+
418+
# override the parent method
419+
kwargs = default_kwargs(cls)
420+
n_channels = len(kwargs['alphabet'])
421+
seqlen = kwargs['auto_resize_len']
422+
dummy_axis = kwargs['dummy_axis']
423+
alphabet_axis = kwargs['alphabet_axis']
424+
ignore_targets = kwargs['ignore_targets']
425+
426+
if ignore_targets:
427+
cls.output_schema.targets = None
428+
429+
if dummy_axis is not None and alphabet_axis == dummy_axis:
430+
raise ValueError("dummy_axis can't be the same as dummy_axis")
431+
432+
# default
433+
input_shape = (seqlen, n_channels)
434+
435+
if dummy_axis is not None and dummy_axis < 2:
436+
# dummy axis is added somewhere in the middle, so the alphabet axis is at the end now
437+
existing_alphabet_axis = 2
438+
else:
439+
existing_alphabet_axis = 1
440+
441+
if existing_alphabet_axis == alphabet_axis:
442+
alphabet_axis = None
443+
444+
# inject the dummy axis
445+
if dummy_axis is not None:
446+
input_shape = input_shape[:dummy_axis] + (1,) + input_shape[dummy_axis:]
447+
448+
# swap axes
449+
if alphabet_axis is not None:
450+
sh = list(input_shape)
451+
sh[alphabet_axis], sh[existing_alphabet_axis] = sh[existing_alphabet_axis], sh[alphabet_axis]
452+
input_shape = tuple(sh)
453+
454+
# now, modify the input schema
455+
cls.output_schema.inputs.shape = input_shape
456+
return cls.output_schema

kipoiseq/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
from pybedtools import Interval
32

43

54
# alphabets:

tests/datasets/test_sequence.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
22
import numpy as np
33
import pytest
4+
from copy import deepcopy
45
from pybedtools import Interval
6+
from kipoi.utils import override_default_kwargs
57
from kipoiseq.transforms.functional import one_hot_dna
68
from kipoiseq.datasets.sequence import SeqStringDataset, SeqDataset, parse_dtype, BedDataset
79

@@ -86,3 +88,35 @@ def test_examples_exist(cls):
8688
dl_entries += 1
8789
assert dl_entries == len(ex)
8890
assert len(ex) == bed_entries
91+
92+
93+
def test_output_schape():
94+
Dl = deepcopy(SeqDataset)
95+
assert Dl.get_output_schema().inputs.shape == (None, 4)
96+
override_default_kwargs(Dl, {"auto_resize_len": 100})
97+
assert Dl.get_output_schema().inputs.shape == (100, 4)
98+
99+
override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": 1, "alphabet_axis": 2})
100+
assert Dl.get_output_schema().inputs.shape == (100, 1, 4)
101+
override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": None, "alphabet_axis": 1}) # reset
102+
override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": 2})
103+
assert Dl.get_output_schema().inputs.shape == (100, 4, 1)
104+
override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": None, "alphabet_axis": 1}) # reset
105+
106+
override_default_kwargs(Dl, {"auto_resize_len": 100, "alphabet": "ACGTD"})
107+
assert Dl.get_output_schema().inputs.shape == (100, 5)
108+
override_default_kwargs(Dl, {"auto_resize_len": 100, "alphabet": "ACGT"}) # reset
109+
110+
override_default_kwargs(Dl, {"auto_resize_len": 160, "dummy_axis": 2, "alphabet_axis": 0})
111+
assert Dl.get_output_schema().inputs.shape == (4, 160, 1)
112+
113+
override_default_kwargs(Dl, {"auto_resize_len": 160, "dummy_axis": 2, "alphabet_axis": 1})
114+
assert Dl.get_output_schema().inputs.shape == (160, 4, 1)
115+
targets = Dl.get_output_schema().targets
116+
assert targets.shape == (None,)
117+
118+
override_default_kwargs(Dl, {"ignore_targets": True})
119+
assert Dl.get_output_schema().targets is None
120+
# reset back
121+
override_default_kwargs(Dl, {"ignore_targets": False})
122+
Dl.output_schema.targets = targets

0 commit comments

Comments
 (0)