Skip to content

Commit 7da96fd

Browse files
Add possibility to use task-specific transformations.
1 parent d1db57e commit 7da96fd

File tree

4 files changed

+35
-15
lines changed

4 files changed

+35
-15
lines changed

continuum/scenarios/base.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@ class _BaseScenario(abc.ABC):
1616
1717
:param cl_dataset: A Continuum dataset.
1818
:param nb_tasks: The number of tasks to do.
19-
:param transformations: The PyTorch transformations.
20-
:param train: Boolean flag whether to use the train or test subset.
19+
:param transformations: A list of transformations applied to all tasks. If
20+
it's a list of list, then the transformation will be
21+
different per task.
2122
"""
2223

2324
def __init__(
2425
self,
2526
cl_dataset: _ContinuumDataset,
2627
nb_tasks: int,
27-
transformations: List[Callable] = None
28+
transformations: Union[List[Callable], List[List[Callable]]] = None
2829
) -> None:
2930

3031
self.cl_dataset = cl_dataset
@@ -33,9 +34,15 @@ def __init__(
3334
if transformations is None:
3435
transformations = self.cl_dataset.transformations
3536
if self.cl_dataset.data_type == "segmentation":
36-
self.trsf = SegmentationCompose(transformations)
37+
composer = SegmentationCompose
3738
else:
38-
self.trsf = transforms.Compose(transformations)
39+
composer = transforms.Compose
40+
if transformations is not None and isinstance(transformations[0], list):
41+
# We have list of list of callable, where each sublist is dedicated to
42+
# a task.
43+
self.trsf = [composer(trsf) for trsf in transformations]
44+
else:
45+
self.trsf = composer(transformations)
3946

4047
@abc.abstractmethod
4148
def _setup(self, nb_tasks: int) -> int:
@@ -87,10 +94,17 @@ def __getitem__(self, task_index: Union[int, slice]):
8794
even slices.
8895
:return: A train PyTorch's Datasets.
8996
"""
97+
if isinstance(task_index, slice) and isinstance(self.trsf, list):
98+
raise ValueError(
99+
f"You cannot select multiple task ({task_index}) when you have a "
100+
"different set of transformations per task"
101+
)
102+
90103
x, y, t, _ = self._select_data_by_task(task_index)
104+
91105
return TaskSet(
92106
x, y, t,
93-
trsf=self.trsf,
107+
trsf=self.trsf[task_index] if isinstance(self.trsf, list) else self.trsf,
94108
data_type=self.cl_dataset.data_type,
95109
bounding_boxes=self.cl_dataset.bounding_boxes
96110
)

continuum/scenarios/class_incremental.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ class ClassIncremental(_BaseScenario):
2020
(e.g. increment=[5,1,1,1,1]).
2121
:param initial_increment: A different task size applied only for the first task.
2222
Desactivated if `increment` is a list.
23-
:param transformations: A list of transformations applied to all tasks.
23+
:param transformations: A list of transformations applied to all tasks. If
24+
it's a list of list, then the transformation will be
25+
different per task.
2426
:param class_order: An optional custom class order, used for NC.
2527
e.g. [0,1,2,3,4,5,6,7,8,9] or [5,2,4,1,8,6,7,9,0,3]
2628
"""
@@ -31,7 +33,7 @@ def __init__(
3133
nb_tasks: int = 0,
3234
increment: Union[List[int], int] = 0,
3335
initial_increment: int = 0,
34-
transformations: List[Callable] = None,
36+
transformations: Union[List[Callable], List[List[Callable]]] = None,
3537
class_order: Union[List[int], None]=None
3638
) -> None:
3739

continuum/scenarios/continual_scenario.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,18 @@ class ContinualScenario(_BaseScenario):
1414
Scenario: the scenario is entirely defined by the task label vector in the cl_dataset
1515
1616
:param cl_dataset: A continual dataset.
17-
:param transformations: A list of transformations applied to all tasks.
17+
:param transformations: A list of transformations applied to all tasks. If
18+
it's a list of list, then the transformation will be
19+
different per task.
1820
"""
1921

2022
def __init__(
2123
self,
2224
cl_dataset: _ContinuumDataset,
23-
transformations: List[Callable] = None,
25+
transformations: Union[List[Callable], List[List[Callable]]] = None,
2426
) -> None:
2527
self.check_data(cl_dataset)
26-
super().__init__(cl_dataset=cl_dataset, nb_tasks=self.nb_tasks,transformations=transformations)
28+
super().__init__(cl_dataset=cl_dataset, nb_tasks=self.nb_tasks, transformations=transformations)
2729

2830
def check_data(self, cl_dataset: _ContinuumDataset):
2931
x, y, t = cl_dataset.get_data()
@@ -47,4 +49,4 @@ def check_data(self, cl_dataset: _ContinuumDataset):
4749

4850
#nothing to do in the setup function
4951
def _setup(self, nb_tasks: int) -> int:
50-
return nb_tasks
52+
return nb_tasks

continuum/scenarios/instance_incremental.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Callable, List, Optional
2+
from typing import Callable, List, Optional, Union
33

44
import numpy as np
55

@@ -14,15 +14,17 @@ class InstanceIncremental(_BaseScenario):
1414
1515
:param cl_dataset: A continual dataset.
1616
:param nb_tasks: The scenario number of tasks.
17-
:param transformations: List of transformations to apply to all tasks.
17+
:param transformations: A list of transformations applied to all tasks. If
18+
it's a list of list, then the transformation will be
19+
different per task.
1820
:param random_seed: A random seed to init random processes.
1921
"""
2022

2123
def __init__(
2224
self,
2325
cl_dataset: _ContinuumDataset,
2426
nb_tasks: Optional[int] = None,
25-
transformations: List[Callable] = None,
27+
transformations: Union[List[Callable], List[List[Callable]]] = None,
2628
random_seed: int = 1
2729
):
2830
super().__init__(cl_dataset=cl_dataset, nb_tasks=nb_tasks, transformations=transformations)

0 commit comments

Comments
 (0)