Skip to content

Commit 443cd1e

Browse files
authored
Merge pull request #17 from kipoi/reordered_shape
Rename dataloaders and reordered shape
2 parents 47a4d4e + 2bb23ea commit 443cd1e

20 files changed

+657
-765
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,4 @@ venv.bak/
106106
/tests/data/sample.fasta.fai
107107
**/downloaded/**
108108
/docs/mkdocs.yml
109+
/tests/data/sample.5kb.fa.fai

README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88

99
Standard set of data-loaders for training and making predictions for DNA sequence-based models.
1010

11-
All dataloaders in `kipoiseq.datasets` decorated with `@kipoi_dataloader` (SeqDataset and SeqStringDataset) are compatible Kipoi models and can be directly used when specifying a new model in `model.yaml`:
11+
All dataloaders in `kipoiseq.dataloaders` decorated with `@kipoi_dataloader` (IntervalSeqDl and IntervalSeqStringDl) are compatible with Kipoi models and can be directly used when specifying a new model in `model.yaml`:
1212
```yaml
1313
...
1414
default_dataloader:
15-
defined_as: kipoiseq.datasets.SeqDataset
15+
defined_as: kipoiseq.dataloaders.IntervalSeqDl
1616
default_args:
17-
auto_resize_len: 1000 # override default args in SeqDataset
17+
auto_resize_len: 1000 # override default args in IntervalSeqDl
1818

1919
dependencies:
2020
pip:
@@ -31,11 +31,11 @@ pip install kipoiseq
3131
## Getting started
3232

3333
```python
34-
from kipoiseq.datasets import SeqDataset
34+
from kipoiseq.dataloaders import IntervalSeqDl
3535

36-
dl = SeqDataset.init_example() # use the provided example files
36+
dl = IntervalSeqDl.init_example() # use the provided example files
3737
# your own files
38-
dl = SeqDataset("intervals.bed", "genome.fa")
38+
dl = IntervalSeqDl("intervals.bed", "genome.fa")
3939

4040
len(dl) # length of the dataset
4141

@@ -59,7 +59,7 @@ More info:
5959
- See [docs](https://kipoi.org/kipoiseq/)
6060

6161
## How to write your own data-loaders
62-
- Read the pytorch [Data Loading and Processing Tutorial](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html) to become more familiar with transforms and datasets
63-
- Read the code for `SeqDataset` in [kipoiseq/datasets/sequence.py](https://github.com/kipoi/kipoiseq/blob/master/kipoiseq/datasets/sequence.py)
62+
- Read the pytorch [Data Loading and Processing Tutorial](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html) to become more familiar with transforms and dataloaders
63+
- Read the code for `IntervalSeqDl` in [kipoiseq/dataloaders/sequence.py](https://github.com/kipoi/kipoiseq/blob/master/kipoiseq/dataloaders/sequence.py)
6464
- you can skip the `@kipoi_dataloader` and the long yaml doc-string. These are only required if you want to use dataloaders in Kipoi's model.yaml files.
6565
- Explore the available transforms ([functional](http://kipoi.org/kipoiseq/transforms/functional/), [class-based](http://kipoi.org/kipoiseq/transforms/transforms/)) or extractors ([kipoiseq](https://github.com/kipoi/kipoiseq/blob/master/kipoiseq/extractors.py), [genomelake](https://github.com/kundajelab/genomelake/blob/master/genomelake/extractors.py))

docs/pydocmd.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ generate:
2727
- kipoiseq.transforms.functional+
2828
- transforms/transforms.md:
2929
- kipoiseq.transforms.transforms+
30-
- datasets.md:
31-
- kipoiseq.datasets.sequence+
32-
- kipoiseq.datasets.splicing.SpliceDataset
30+
- dataloaders.md:
31+
- kipoiseq.dataloaders.sequence+
32+
- kipoiseq.dataloaders.splicing.MMSpliceDl
3333

3434
# - baz/cool-stuff.md:
3535
# - foobar.baz:
@@ -48,7 +48,7 @@ generate:
4848

4949
pages:
5050
- Home: index.md << ../README.md
51-
- Datsets: datasets.md
51+
- Dataloaders: dataloaders.md
5252
- Transforms:
5353
- Functional: transforms/functional.md
5454
- Class-based: transforms/transforms.md

kipoiseq/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
__email__ = '[email protected]'
55
__version__ = '0.1.1'
66

7-
from . import datasets
7+
from . import dataloaders
88
from . import extractors
99
from . import transforms
10-
11-
# from .datasets.sequence import SeqDataset, SeqStringDataset
Lines changed: 37 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,11 @@
99
from kipoi.data import Dataset, kipoi_dataloader
1010
from kipoi.specs import Author, Dependencies
1111
from kipoi.utils import default_kwargs
12-
from six import string_types
13-
1412

1513
from kipoiseq.extractors import FastaStringExtractor
16-
from kipoiseq.transforms import SwapAxes, DummyAxis, Compose, OneHot
17-
from kipoiseq.transforms.functional import one_hot, resize_interval
18-
from kipoiseq.utils import get_alphabet, get_onehot_shape, to_scalar
14+
from kipoiseq.transforms import SwapAxes, DummyAxis, Compose, OneHot, ReorderedOneHot
15+
from kipoiseq.transforms.functional import resize_interval
16+
from kipoiseq.utils import to_scalar, parse_dtype
1917

2018
import pybedtools
2119
from pybedtools import BedTool, Interval
@@ -28,33 +26,7 @@
2826
Author(name='Roman Kreuzhuber', github='krrome')]
2927

3028
# Object exported on import *
31-
__all__ = ['BedDataset', 'SeqDataset', 'SeqStringDataset']
32-
33-
34-
def parse_dtype(dtype):
35-
dtypes = {'int': int, 'string': str, 'float': float, 'bool': bool}
36-
if dtype is None:
37-
return None
38-
if dtype in list(dtypes.values()):
39-
return dtype
40-
if dtype not in dtypes:
41-
raise Exception("Datatype '{0}' not recognized. Allowed are: {1}".format(dtype, str(list(dtypes.keys()))))
42-
return dtypes[dtype]
43-
44-
45-
def parse_alphabet(alphabet):
46-
if isinstance(alphabet, str):
47-
return list(alphabet)
48-
else:
49-
return alphabet
50-
51-
52-
def parse_type(dtype):
53-
if isinstance(dtype, string_types):
54-
if dtype in dir(np):
55-
return getattr(np, dtype)
56-
else:
57-
return dtype
29+
__all__ = ['IntervalSeqDl', 'IntervalSeqStringDl', 'BedDataset']
5830

5931

6032
class BedDataset(object):
@@ -68,7 +40,7 @@ class BedDataset(object):
6840
bed_columns: number of columns corresponding to the bed file. All the columns
6941
after that will be parsed as targets
7042
num_chr: if specified, 'chr' in the chromosome name will be dropped
71-
label_dtype: specific data type for labels
43+
label_dtype: specific data type for labels, Example: `float` or `np.float32`
7244
ambiguous_mask: if specified, rows containing only ambiguous_mask values will be skipped
7345
incl_chromosomes: exclusive list of chromosome names to include in the final dataset.
7446
if not None, only these will be present in the dataset
@@ -99,6 +71,7 @@ def __init__(self, tsv_file,
9971
incl_chromosomes=None,
10072
excl_chromosomes=None,
10173
ignore_targets=False):
74+
# TODO - `chrom` column: use pd.Categorical for memory efficiency
10275
self.tsv_file = tsv_file
10376
self.bed_columns = bed_columns
10477
self.num_chr = num_chr
@@ -159,7 +132,7 @@ def get_targets(self):
159132

160133

161134
@kipoi_dataloader(override={"dependencies": deps, 'info.authors': package_authors})
162-
class SeqStringDataset(Dataset):
135+
class IntervalSeqStringDl(Dataset):
163136
"""
164137
info:
165138
doc: >
@@ -180,7 +153,7 @@ class SeqStringDataset(Dataset):
180153
num_chr_fasta:
181154
doc: True, the dataloader will make sure that the chromosomes don't start with chr.
182155
label_dtype:
183-
doc: None, datatype of the task labels taken from the intervals_file. Allowed - string', 'int', 'float', 'bool'
156+
doc: None, datatype of the task labels taken from the intervals_file. Example - str, int, float, np.float32
184157
auto_resize_len:
185158
doc: None, required sequence length.
186159
# max_seq_len:
@@ -281,15 +254,11 @@ def get_output_schema(cls):
281254
return cls.output_schema
282255

283256

284-
# TODO - check lzamparo's dataloader:
285-
# - https://github.com/kipoi/kipoiseq/issues/1#issuecomment-427412487
286-
# - https://raw.githubusercontent.com/lzamparo/bindspace_revisions/master/deepbind/src/dataloader.py
287-
288-
# TODO - properly deal with samples outside
257+
# TODO - properly deal with samples outside of the genome
289258

290259

291260
@kipoi_dataloader(override={"dependencies": deps, 'info.authors': package_authors})
292-
class SeqDataset(Dataset):
261+
class IntervalSeqDl(Dataset):
293262
"""
294263
info:
295264
doc: >
@@ -311,7 +280,7 @@ class SeqDataset(Dataset):
311280
num_chr_fasta:
312281
doc: True, the dataloader will make sure that the chromosomes don't start with chr.
313282
label_dtype:
314-
doc: None, datatype of the task labels taken from the intervals_file. Allowed - string', 'int', 'float', 'bool'
283+
doc: 'None, datatype of the task labels taken from the intervals_file. Example: str, int, float, np.float32'
315284
auto_resize_len:
316285
doc: None, required sequence length.
317286
# use_strand:
@@ -323,9 +292,9 @@ class SeqDataset(Dataset):
323292
alphabet:
324293
doc: >
325294
alphabet to use for the one-hot encoding. This defines the order of the one-hot encoding.
326-
Can either be a list or a string: 'DNA', 'RNA', 'AMINO_ACIDS'.
295+
Can either be a list or a string: 'ACGT' or ['A', 'C', 'G', 'T']. Default: 'ACGT'
327296
dtype:
328-
doc: defines the numpy dtype of the returned array.
297+
doc: 'defines the numpy dtype of the returned array. Example: int, np.int32, np.float32, float'
329298
ignore_targets:
330299
doc: if True, don't return any target variables
331300
@@ -362,95 +331,45 @@ def __init__(self,
362331
alphabet="ACGT",
363332
ignore_targets=False,
364333
dtype=None):
365-
# TODO - add disable target loading to manage the Basenji case
366-
367-
# make sure the alphabet axis and the dummy axis are valid:
368-
assert alphabet_axis >= 0 and (alphabet_axis < 2 or (alphabet_axis <= 2 and dummy_axis is not None))
369-
assert dummy_axis is None or (dummy_axis >= 0 and dummy_axis <= 2 and alphabet_axis != dummy_axis)
370-
371-
# transform parameters
372-
self.alphabet_axis = alphabet_axis
373-
self.dummy_axis = dummy_axis
374-
self.alphabet = parse_alphabet(alphabet)
375-
self.dtype = parse_type(dtype)
376-
377-
# core dataset
378-
self.seq_string_dataset = SeqStringDataset(intervals_file, fasta_file, num_chr_fasta=num_chr_fasta,
379-
label_dtype=label_dtype, auto_resize_len=auto_resize_len,
380-
# use_strand=use_strand,
381-
ignore_targets=ignore_targets)
382-
383-
if dummy_axis is not None and alphabet_axis == dummy_axis:
384-
raise ValueError("dummy_axis can't be the same as dummy_axis")
385-
386-
# set the transform parameters correctly
387-
if dummy_axis is not None and dummy_axis < 2:
388-
# dummy axis is added somewhere in the middle, so the alphabet axis is at the end now
389-
existing_alphabet_axis = 2
390-
else:
391-
# alphabet axis stayed the same
392-
existing_alphabet_axis = 1
334+
# core dataset, not using the one-hot encoding params
335+
self.seq_dl = IntervalSeqStringDl(intervals_file, fasta_file, num_chr_fasta=num_chr_fasta,
336+
label_dtype=label_dtype, auto_resize_len=auto_resize_len,
337+
# use_strand=use_strand,
338+
ignore_targets=ignore_targets)
393339

394-
# check if no swapping needed
395-
if existing_alphabet_axis == self.alphabet_axis:
396-
self.alphabet_axis = None
397-
398-
# how to transform the input
399-
self.input_tranform = Compose([
400-
OneHot(self.alphabet, dtype=self.dtype), # one-hot-encode
401-
DummyAxis(self.dummy_axis), # optionally inject the dummy axis
402-
SwapAxes(existing_alphabet_axis, self.alphabet_axis), # put the alphabet axis elsewhere
403-
])
340+
self.input_transform = ReorderedOneHot(alphabet=alphabet,
341+
dtype=dtype,
342+
alphabet_axis=alphabet_axis,
343+
dummy_axis=dummy_axis)
404344

405345
def __len__(self):
406-
return len(self.seq_string_dataset)
346+
return len(self.seq_dl)
407347

408348
def __getitem__(self, idx):
409-
ret = self.seq_string_dataset[idx]
410-
ret['inputs'] = self.input_tranform(str(ret["inputs"]))
349+
ret = self.seq_dl[idx]
350+
ret['inputs'] = self.input_transform(str(ret["inputs"]))
411351
return ret
412352

413353
@classmethod
414354
def get_output_schema(cls):
415355
"""Get the output schema. Overrides the default `cls.output_schema`
416356
"""
417357

418-
# override the parent method
358+
# get the default kwargs
419359
kwargs = default_kwargs(cls)
420-
n_channels = len(kwargs['alphabet'])
421-
seqlen = kwargs['auto_resize_len']
422-
dummy_axis = kwargs['dummy_axis']
423-
alphabet_axis = kwargs['alphabet_axis']
424-
ignore_targets = kwargs['ignore_targets']
425360

426-
if ignore_targets:
427-
cls.output_schema.targets = None
428-
429-
if dummy_axis is not None and alphabet_axis == dummy_axis:
430-
raise ValueError("dummy_axis can't be the same as dummy_axis")
431-
432-
# default
433-
input_shape = (seqlen, n_channels)
434-
435-
if dummy_axis is not None and dummy_axis < 2:
436-
# dummy axis is added somewhere in the middle, so the alphabet axis is at the end now
437-
existing_alphabet_axis = 2
438-
else:
439-
existing_alphabet_axis = 1
440-
441-
if existing_alphabet_axis == alphabet_axis:
442-
alphabet_axis = None
361+
# figure out the input shape
362+
mock_input_transform = ReorderedOneHot(alphabet=kwargs['alphabet'],
363+
dtype=kwargs['dtype'],
364+
alphabet_axis=kwargs['alphabet_axis'],
365+
dummy_axis=kwargs['dummy_axis'])
366+
input_shape = mock_input_transform.get_output_shape(kwargs['auto_resize_len'])
443367

444-
# inject the dummy axis
445-
if dummy_axis is not None:
446-
input_shape = input_shape[:dummy_axis] + (1,) + input_shape[dummy_axis:]
368+
# modify it
369+
cls.output_schema.inputs.shape = input_shape
447370

448-
# swap axes
449-
if alphabet_axis is not None:
450-
sh = list(input_shape)
451-
sh[alphabet_axis], sh[existing_alphabet_axis] = sh[existing_alphabet_axis], sh[alphabet_axis]
452-
input_shape = tuple(sh)
371+
# (optionally) get rid of the target shape
372+
if kwargs['ignore_targets']:
373+
cls.output_schema.targets = None
453374

454-
# now, modify the input schema
455-
cls.output_schema.inputs.shape = input_shape
456375
return cls.output_schema

0 commit comments

Comments
 (0)