88from kipoi .plugin import is_installed
99from kipoi .data import Dataset , kipoi_dataloader
1010from kipoi .specs import Author , Dependencies
11+ from kipoi .utils import default_kwargs
1112from six import string_types
1213
1314
@@ -72,6 +73,7 @@ class BedDataset(object):
7273 incl_chromosomes: exclusive list of chromosome names to include in the final dataset.
7374 if not None, only these will be present in the dataset
7475 excl_chromosomes: list of chromosome names to omit from the dataset.
76+ ignore_targets: if True, target variables are ignored
7577 """
7678
7779 # bed types according to
@@ -95,14 +97,16 @@ def __init__(self, tsv_file,
9597 num_chr = False ,
9698 ambiguous_mask = None ,
9799 incl_chromosomes = None ,
98- excl_chromosomes = None ):
100+ excl_chromosomes = None ,
101+ ignore_targets = False ):
99102 self .tsv_file = tsv_file
100103 self .bed_columns = bed_columns
101104 self .num_chr = num_chr
102105 self .label_dtype = label_dtype
103106 self .ambiguous_mask = ambiguous_mask
104107 self .incl_chromosomes = incl_chromosomes
105108 self .excl_chromosomes = excl_chromosomes
109+ self .ignore_targets = ignore_targets
106110
107111 df_peek = pd .read_table (self .tsv_file ,
108112 header = None ,
@@ -141,7 +145,7 @@ def __getitem__(self, idx):
141145 row = self .df .iloc [idx ]
142146 interval = pybedtools .create_interval_from_list ([to_scalar (x ) for x in row .iloc [:self .bed_columns ]])
143147
144- if self .n_tasks == 0 :
148+ if self .ignore_targets or self . n_tasks == 0 :
145149 labels = {}
146150 else :
147151 labels = row .iloc [self .bed_columns :].values .astype (self .label_dtype )
@@ -185,6 +189,8 @@ class SeqStringDataset(Dataset):
185189 # doc: reverse-complement fasta sequence if bed file defines negative strand
186190 force_upper:
187191 doc: Force uppercase output of sequences
192+ ignore_targets:
193+ doc: if True, don't return any target variables
188194 output_schema:
189195 inputs:
190196 name: seq
@@ -213,7 +219,8 @@ def __init__(self,
213219 auto_resize_len = None ,
214220 # max_seq_len=None,
215221 # use_strand=False,
216- force_upper = True ):
222+ force_upper = True ,
223+ ignore_targets = False ):
217224
218225 self .num_chr_fasta = num_chr_fasta
219226 self .intervals_file = intervals_file
@@ -232,7 +239,8 @@ def __init__(self,
232239 self .bed = BedDataset (self .intervals_file ,
233240 num_chr = self .num_chr_fasta ,
234241 bed_columns = 3 ,
235- label_dtype = parse_dtype (label_dtype ))
242+ label_dtype = parse_dtype (label_dtype ),
243+ ignore_targets = ignore_targets )
236244 self .fasta_extractors = None
237245
238246 def __len__ (self ):
@@ -265,15 +273,12 @@ def __getitem__(self, idx):
265273 }
266274
267275 @classmethod
268- def default_shape (cls ):
269- # correct the output schema - TODO - required?
270- # self.output_schema_params = deepcopy(self.output_schema_params)
271- # self.output_schema_params['inputs_shape'] = (1,)
272- # if self.bed.n_tasks != 0:
273- # self.output_schema_params['targets_shape'] = (self.bed.n_tasks,)
274-
275- # self.output_schema = get_seq_dataset_output_schema(**self.output_schema_params)
276- pass
def get_output_schema(cls):
    """Return the output schema matching the default constructor arguments.

    The default value of ``ignore_targets`` is looked up through
    ``default_kwargs(cls)`` (presumably the defaults of the constructor's
    keyword arguments -- confirm against ``kipoi.utils``). When it is
    truthy, the returned schema has ``targets`` set to ``None``.

    The schema is deep-copied before being modified so that the
    class-level ``cls.output_schema`` object is never mutated; otherwise
    every class sharing that schema object would silently lose its
    targets as well.

    Returns:
        A copy of ``cls.output_schema``, with ``targets`` cleared when
        targets are ignored by default.
    """
    # Local import keeps this method self-contained.
    from copy import deepcopy

    output_schema = deepcopy(cls.output_schema)
    if default_kwargs(cls)['ignore_targets']:
        output_schema.targets = None
    return output_schema
277282
278283
279284# TODO - check lzamparo's dataloader:
@@ -320,7 +325,9 @@ class SeqDataset(Dataset):
320325 alphabet to use for the one-hot encoding. This defines the order of the one-hot encoding.
321326 Can either be a list or a string: 'DNA', 'RNA', 'AMINO_ACIDS'.
322327 dtype:
323- doc: defines the numpy dtype of the returned array.
328+ doc: defines the numpy dtype of the returned array.
329+ ignore_targets:
330+ doc: if True, don't return any target variables
324331
325332 output_schema:
326333 inputs:
@@ -353,7 +360,9 @@ def __init__(self,
353360 alphabet_axis = 1 ,
354361 dummy_axis = None ,
355362 alphabet = "ACGT" ,
363+ ignore_targets = False ,
356364 dtype = None ):
365+ # TODO - add disable target loading to manage the Basenji case
357366
358367 # make sure the alphabet axis and the dummy axis are valid:
359368 assert alphabet_axis >= 0 and (alphabet_axis < 2 or (alphabet_axis <= 2 and dummy_axis is not None ))
@@ -369,13 +378,18 @@ def __init__(self,
369378 self .seq_string_dataset = SeqStringDataset (intervals_file , fasta_file , num_chr_fasta = num_chr_fasta ,
370379 label_dtype = label_dtype , auto_resize_len = auto_resize_len ,
371380 # use_strand=use_strand,
372- force_upper = True )
381+ ignore_targets = ignore_targets )
382+
383+ if dummy_axis is not None and alphabet_axis == dummy_axis :
384+ raise ValueError ("dummy_axis can't be the same as dummy_axis" )
373385
374386 # set the transform parameters correctly
375- existing_alphabet_axis = 1
376387 if dummy_axis is not None and dummy_axis < 2 :
377388 # dummy axis is added somewhere in the middle, so the alphabet axis is at the end now
378389 existing_alphabet_axis = 2
390+ else :
391+ # alphabet axis stayed the same
392+ existing_alphabet_axis = 1
379393
380394 # check if no swapping needed
381395 if existing_alphabet_axis == self .alphabet_axis :
@@ -396,18 +410,47 @@ def __getitem__(self, idx):
396410 ret ['inputs' ] = self .input_tranform (str (ret ["inputs" ]))
397411 return ret
398412
399- # TODO - compute the output shape based on the default value of parameters
400- # - executed in kipoi_dataloader
401- # TODO - how to specify the shape properly when using differnet default parameters?
402- # - example: Basset dataloader
403413 @classmethod
404- def default_shape (cls ):
405- # setup output schema
406- # self.output_schema_params = deepcopy(self.output_schema_params)
407-
408- # self.output_schema_params['inputs_shape'] = get_onehot_shape(self.alphabet_axis, self.dummy_axis,
409- # self.auto_resize_len, self.alphabet)
410- # if self.bed.n_tasks != 0:
411- # self.output_schema_params['targets_shape'] = (self.bed.n_tasks,)
412- # self.output_schema = get_seq_dataset_output_schema(**self.output_schema_params)
413- pass
def get_output_schema(cls):
    """Get the output schema. Overrides the default `cls.output_schema`.

    The input shape is derived from the default values returned by
    ``default_kwargs(cls)`` for ``alphabet``, ``auto_resize_len``,
    ``dummy_axis``, ``alphabet_axis`` and ``ignore_targets``:

    1. start from ``(auto_resize_len, len(alphabet))``;
    2. if ``dummy_axis`` is set, insert a length-1 axis at that position;
    3. if ``alphabet_axis`` differs from where the alphabet axis ended up
       after step 2, swap those two axes.

    The schema is deep-copied before being modified so that the
    class-level ``cls.output_schema`` object is never mutated.

    Returns:
        A copy of ``cls.output_schema`` with ``inputs.shape`` set to the
        computed shape and ``targets`` set to ``None`` when targets are
        ignored by default.

    Raises:
        ValueError: if ``dummy_axis`` is set and equals ``alphabet_axis``.
    """
    # Local import keeps this method self-contained.
    from copy import deepcopy

    kwargs = default_kwargs(cls)
    n_channels = len(kwargs['alphabet'])
    seqlen = kwargs['auto_resize_len']
    dummy_axis = kwargs['dummy_axis']
    alphabet_axis = kwargs['alphabet_axis']

    if dummy_axis is not None and alphabet_axis == dummy_axis:
        raise ValueError("alphabet_axis can't be the same as dummy_axis")

    output_schema = deepcopy(cls.output_schema)
    if kwargs['ignore_targets']:
        output_schema.targets = None

    # default layout: (sequence position, alphabet channel)
    input_shape = (seqlen, n_channels)

    if dummy_axis is not None and dummy_axis < 2:
        # dummy axis is added before the alphabet axis, pushing it to the end
        existing_alphabet_axis = 2
    else:
        # alphabet axis stayed where it was
        existing_alphabet_axis = 1

    # inject the dummy axis
    if dummy_axis is not None:
        input_shape = input_shape[:dummy_axis] + (1,) + input_shape[dummy_axis:]

    # move the alphabet axis to the requested position, if it is not there yet
    if alphabet_axis != existing_alphabet_axis:
        sh = list(input_shape)
        sh[alphabet_axis], sh[existing_alphabet_axis] = sh[existing_alphabet_axis], sh[alphabet_axis]
        input_shape = tuple(sh)

    output_schema.inputs.shape = input_shape
    return output_schema
0 commit comments