Skip to content

Commit 4a28c91

Browse files
authored
Merge pull request #19 from kipoi/fix_output_schema
fix get_output_schema
2 parents 6cc8c9a + 4af5af3 commit 4a28c91

File tree

3 files changed

+36
-30
lines changed

3 files changed

+36
-30
lines changed

kipoiseq/dataloaders/sequence.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,11 +247,12 @@ def __getitem__(self, idx):
247247

248248
@classmethod
249249
def get_output_schema(cls):
250+
output_schema = deepcopy(cls.output_schema)
250251
kwargs = default_kwargs(cls)
251252
ignore_targets = kwargs['ignore_targets']
252253
if ignore_targets:
253-
cls.output_schema.targets = None
254-
return cls.output_schema
254+
output_schema.targets = None
255+
return output_schema
255256

256257

257258
# TODO - properly deal with samples outside of the genome
@@ -354,6 +355,7 @@ def __getitem__(self, idx):
354355
def get_output_schema(cls):
355356
"""Get the output schema. Overrides the default `cls.output_schema`
356357
"""
358+
output_schema = deepcopy(cls.output_schema)
357359

358360
# get the default kwargs
359361
kwargs = default_kwargs(cls)
@@ -366,10 +368,10 @@ def get_output_schema(cls):
366368
input_shape = mock_input_transform.get_output_shape(kwargs['auto_resize_len'])
367369

368370
# modify it
369-
cls.output_schema.inputs.shape = input_shape
371+
output_schema.inputs.shape = input_shape
370372

371373
# (optionally) get rid of the target shape
372374
if kwargs['ignore_targets']:
373-
cls.output_schema.targets = None
375+
output_schema.targets = None
374376

375-
return cls.output_schema
377+
return output_schema

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from setuptools import setup, find_packages
55

66
requirements = [
7-
"kipoi>=0.4.2",
7+
"kipoi>=0.5.5",
88
# "genomelake",
99
"pybedtools",
1010
"pyfaidx",

tests/dataloaders/test_sequence.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -122,30 +122,34 @@ def test_examples_exist(cls):
122122
def test_output_schape():
123123
Dl = deepcopy(SeqIntervalDl)
124124
assert Dl.get_output_schema().inputs.shape == (None, 4)
125-
override_default_kwargs(Dl, {"auto_resize_len": 100})
126-
assert Dl.get_output_schema().inputs.shape == (100, 4)
127-
128-
override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": 1, "alphabet_axis": 2})
129-
assert Dl.get_output_schema().inputs.shape == (100, 1, 4)
130-
override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": None, "alphabet_axis": 1}) # reset
131-
override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": 2})
132-
assert Dl.get_output_schema().inputs.shape == (100, 4, 1)
133-
override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": None, "alphabet_axis": 1}) # reset
134-
135-
override_default_kwargs(Dl, {"auto_resize_len": 100, "alphabet": "ACGTD"})
136-
assert Dl.get_output_schema().inputs.shape == (100, 5)
137-
override_default_kwargs(Dl, {"auto_resize_len": 100, "alphabet": "ACGT"}) # reset
138-
139-
override_default_kwargs(Dl, {"auto_resize_len": 160, "dummy_axis": 2, "alphabet_axis": 0})
140-
assert Dl.get_output_schema().inputs.shape == (4, 160, 1)
141-
142-
override_default_kwargs(Dl, {"auto_resize_len": 160, "dummy_axis": 2, "alphabet_axis": 1})
143-
assert Dl.get_output_schema().inputs.shape == (160, 4, 1)
144-
targets = Dl.get_output_schema().targets
125+
Dlc = override_default_kwargs(Dl, {"auto_resize_len": 100})
126+
assert Dlc.get_output_schema().inputs.shape == (100, 4)
127+
128+
# original left intact
129+
assert Dl.get_output_schema().inputs.shape == (None, 4)
130+
131+
Dlc = override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": 1, "alphabet_axis": 2})
132+
assert Dlc.get_output_schema().inputs.shape == (100, 1, 4)
133+
Dlc = override_default_kwargs(Dl, {"auto_resize_len": 100, "dummy_axis": 2})
134+
assert Dlc.get_output_schema().inputs.shape == (100, 4, 1)
135+
# original left intact
136+
assert Dl.get_output_schema().inputs.shape == (None, 4)
137+
138+
Dlc = override_default_kwargs(Dl, {"auto_resize_len": 100, "alphabet": "ACGTD"})
139+
assert Dlc.get_output_schema().inputs.shape == (100, 5)
140+
141+
Dlc = override_default_kwargs(Dl, {"auto_resize_len": 160, "dummy_axis": 2, "alphabet_axis": 0})
142+
assert Dlc.get_output_schema().inputs.shape == (4, 160, 1)
143+
144+
Dlc = override_default_kwargs(Dl, {"auto_resize_len": 160, "dummy_axis": 2, "alphabet_axis": 1})
145+
assert Dlc.get_output_schema().inputs.shape == (160, 4, 1)
146+
targets = Dlc.get_output_schema().targets
145147
assert targets.shape == (None,)
146148

147-
override_default_kwargs(Dl, {"ignore_targets": True})
148-
assert Dl.get_output_schema().targets is None
149+
Dlc = override_default_kwargs(Dl, {"ignore_targets": True})
150+
assert Dlc.get_output_schema().targets is None
149151
# reset back
150-
override_default_kwargs(Dl, {"ignore_targets": False})
151-
Dl.output_schema.targets = targets
152+
153+
# original left intact
154+
assert Dl.get_output_schema().inputs.shape == (None, 4)
155+
assert Dl.get_output_schema().targets.shape == (None, )

0 commit comments

Comments
 (0)