Skip to content

Commit 758e060

Browse files
christinadionysiomboehm7
authored andcommitted
[SYSTEMDS-3701] Additional scuro data representations
Closes #2111.
1 parent c86aa0a commit 758e060

File tree

27 files changed

+858
-193
lines changed

27 files changed

+858
-193
lines changed

src/main/python/systemds/scuro/__init__.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,54 @@
1818
# under the License.
1919
#
2020
# -------------------------------------------------------------
21+
from systemds.scuro.representations.representation import Representation
22+
from systemds.scuro.representations.average import Average
23+
from systemds.scuro.representations.concatenation import Concatenation
24+
from systemds.scuro.representations.fusion import Fusion
25+
from systemds.scuro.representations.sum import Sum
26+
from systemds.scuro.representations.max import RowMax
27+
from systemds.scuro.representations.multiplication import Multiplication
28+
from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
29+
from systemds.scuro.representations.resnet import ResNet
30+
from systemds.scuro.representations.bert import Bert
31+
from systemds.scuro.representations.unimodal import UnimodalRepresentation
32+
from systemds.scuro.representations.lstm import LSTM
33+
from systemds.scuro.representations.utils import NPY, Pickle, HDF5, JSON
34+
from systemds.scuro.models.model import Model
35+
from systemds.scuro.models.discrete_model import DiscreteModel
36+
from systemds.scuro.modality.aligned_modality import AlignedModality
37+
from systemds.scuro.modality.audio_modality import AudioModality
38+
from systemds.scuro.modality.video_modality import VideoModality
39+
from systemds.scuro.modality.text_modality import TextModality
40+
from systemds.scuro.modality.modality import Modality
41+
from systemds.scuro.aligner.dr_search import DRSearch
42+
from systemds.scuro.aligner.task import Task
43+
44+
45+
__all__ = ["Representation",
46+
"Average",
47+
"Concatenation",
48+
"Fusion",
49+
"Sum",
50+
"RowMax",
51+
"Multiplication",
52+
"MelSpectrogram",
53+
"ResNet",
54+
"Bert",
55+
"UnimodalRepresentation",
56+
"LSTM",
57+
"NPY",
58+
"Pickle",
59+
"HDF5",
60+
"JSON",
61+
"Model",
62+
"DiscreteModel",
63+
"AlignedModality",
64+
"AudioModality",
65+
"VideoModality",
66+
"TextModality",
67+
"Modality",
68+
"DRSearch",
69+
"Task"
70+
]
71+

src/main/python/systemds/scuro/aligner/dr_search.py

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,24 @@
1919
#
2020
# -------------------------------------------------------------
2121
import itertools
22+
import random
2223
from typing import List
2324

24-
from aligner.task import Task
25-
from modality.aligned_modality import AlignedModality
26-
from modality.modality import Modality
27-
from representations.representation import Representation
25+
from systemds.scuro.aligner.task import Task
26+
from systemds.scuro.modality.aligned_modality import AlignedModality
27+
from systemds.scuro.modality.modality import Modality
28+
from systemds.scuro.representations.representation import Representation
29+
30+
import warnings
31+
32+
warnings.filterwarnings('ignore')
2833

2934

3035
def get_modalities_by_name(modalities, name):
3136
for modality in modalities:
3237
if modality.name == name:
3338
return modality
34-
39+
3540
raise 'Modality ' + name + 'not in modalities'
3641

3742

@@ -51,9 +56,9 @@ def __init__(self, modalities: List[Modality], task: Task, representations: List
5156
self.best_modalities = None
5257
self.best_representation = None
5358
self.best_score = -1
54-
59+
5560
def set_best_params(self, modality_name: str, representation: Representation,
56-
score: float, modality_names: List[str]):
61+
scores: List[float], modality_names: List[str]):
5762
"""
5863
Updates the best parameters for given modalities, representation, and score
5964
:param modality_name: The name of the aligned modality
@@ -62,43 +67,66 @@ def set_best_params(self, modality_name: str, representation: Representation,
6267
:param modality_names: List of modality names used in this setting
6368
:return:
6469
"""
65-
70+
6671
# check if modality name is already in dictionary
6772
if modality_name not in self.scores.keys():
6873
# if not add it to dictionary
6974
self.scores[modality_name] = {}
70-
75+
7176
# set score for representation
72-
self.scores[modality_name][representation] = score
73-
77+
self.scores[modality_name][representation] = scores
78+
7479
# compare current score with best score
75-
if score > self.best_score:
76-
self.best_score = score
80+
if scores[1] > self.best_score:
81+
self.best_score = scores[1]
7782
self.best_representation = representation
7883
self.best_modalities = modality_names
79-
80-
def fit(self):
84+
85+
def reset_best_params(self):
86+
self.best_score = -1
87+
self.best_modalities = None
88+
self.best_representation = None
89+
self.scores = {}
90+
91+
def fit_random(self, seed=-1):
92+
"""
93+
This method randomly selects a modality or combination of modalities and representation
94+
"""
95+
if seed != -1:
96+
random.seed(seed)
97+
98+
modalities = []
99+
for M in range(1, len(self.modalities) + 1):
100+
for combination in itertools.combinations(self.modalities, M):
101+
modalities.append(combination)
102+
103+
modality_combination = random.choice(modalities)
104+
representation = random.choice(self.representations)
105+
106+
modality = AlignedModality(representation, list(modality_combination)) # noqa
107+
modality.combine()
108+
109+
scores = self.task.run(modality.data)
110+
self.set_best_params(modality.name, representation, scores, modality.get_modality_names())
111+
112+
return self.best_representation, self.best_score, self.best_modalities
113+
114+
def fit_enumerate_all(self):
81115
"""
82116
This method finds the best representation out of a given List of uni-modal modalities and
83117
representations
84118
:return: The best parameters found in the search procedure
85119
"""
86-
120+
87121
for M in range(1, len(self.modalities) + 1):
88122
for combination in itertools.combinations(self.modalities, M):
89-
if len(combination) == 1:
90-
modality = combination[0]
91-
score = self.task.run(modality.representation.scale_data(modality.data, self.task.train_indices))
92-
self.set_best_params(modality.name, modality.representation.name, score, [modality.name])
93-
self.scores[modality] = score
94-
else:
95-
for representation in self.representations:
96-
modality = AlignedModality(representation, list(combination)) # noqa
97-
modality.combine(self.task.train_indices)
98-
99-
score = self.task.run(modality.data)
100-
self.set_best_params(modality.name, representation, score, modality.get_modality_names())
101-
123+
for representation in self.representations:
124+
modality = AlignedModality(representation, list(combination)) # noqa
125+
modality.combine()
126+
127+
scores = self.task.run(modality.data)
128+
self.set_best_params(modality.name, representation, scores, modality.get_modality_names())
129+
102130
return self.best_representation, self.best_score, self.best_modalities
103131

104132
def transform(self, modalities: List[Modality]):
@@ -108,17 +136,16 @@ def transform(self, modalities: List[Modality]):
108136
:param modalities: List of uni-modal modalities
109137
:return: aligned data
110138
"""
111-
139+
112140
if self.best_score == -1:
113141
raise 'Please fit representations first!'
114-
142+
115143
used_modalities = []
116-
144+
117145
for modality_name in self.best_modalities:
118146
used_modalities.append(get_modalities_by_name(modalities, modality_name))
119-
147+
120148
modality = AlignedModality(self.best_representation, used_modalities) # noqa
121149
modality.combine(self.task.train_indices)
122-
150+
123151
return modality.data
124-

src/main/python/systemds/scuro/aligner/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
# -------------------------------------------------------------
2121
from typing import List
2222

23-
from models.model import Model
23+
from systemds.scuro.models.model import Model
2424

2525

2626
class Task:

src/main/python/systemds/scuro/main.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@
2222
import json
2323
from datetime import datetime
2424

25-
from representations.average import Averaging
26-
from representations.concatenation import Concatenation
27-
from modality.aligned_modality import AlignedModality
28-
from modality.text_modality import TextModality
29-
from modality.video_modality import VideoModality
30-
from modality.audio_modality import AudioModality
31-
from representations.unimodal import Pickle, JSON, HDF5, NPY
32-
from models.discrete_model import DiscreteModel
33-
from aligner.task import Task
34-
from aligner.dr_search import DRSearch
25+
from systemds.scuro.representations.average import Average
26+
from systemds.scuro.representations.concatenation import Concatenation
27+
from systemds.scuro.modality.aligned_modality import AlignedModality
28+
from systemds.scuro.modality.text_modality import TextModality
29+
from systemds.scuro.modality.video_modality import VideoModality
30+
from systemds.scuro.modality.audio_modality import AudioModality
31+
from systemds.scuro.representations.unimodal import Pickle, JSON, HDF5, NPY
32+
from systemds.scuro.models.discrete_model import DiscreteModel
33+
from systemds.scuro.aligner.task import Task
34+
from systemds.scuro.aligner.dr_search import DRSearch
3535

3636

3737
class CustomTask(Task):
@@ -66,8 +66,8 @@ def run(self, data):
6666

6767
model = DiscreteModel()
6868
custom_task = CustomTask(model, labels, train_indices, val_indices)
69-
representations = [Concatenation(), Averaging()]
69+
representations = [Concatenation(), Average()]
7070

7171
dr_search = DRSearch(modalities, custom_task, representations)
72-
best_representation, best_score, best_modalities = dr_search.fit()
72+
best_representation, best_score, best_modalities = dr_search.fit_random()
7373
aligned_representation = dr_search.transform(modalities)

src/main/python/systemds/scuro/modality/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,3 @@
1818
# under the License.
1919
#
2020
# -------------------------------------------------------------
21-
from systemds.scuro.modality.aligned_modality import AlignedModality
22-
from systemds.scuro.modality.audio_modality import AudioModality
23-
from systemds.scuro.modality.video_modality import VideoModality
24-
from systemds.scuro.modality.test_modality import TextModality
25-
from systemds.scuro.modality.modality import Modality
26-
27-
28-
__all__ = ["AlignedModality", "AudioModality", "VideoModality", "TextModality", "Modality"]

src/main/python/systemds/scuro/modality/aligned_modality.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
# -------------------------------------------------------------
2121
from typing import List
2222

23-
from modality.modality import Modality
24-
from representations.fusion import Fusion
23+
from systemds.scuro.modality.modality import Modality
24+
from systemds.scuro.representations.fusion import Fusion
2525

2626

2727
class AlignedModality(Modality):
@@ -36,9 +36,16 @@ def __init__(self, representation: Fusion, modalities: List[Modality]):
3636
name += modality.name
3737
super().__init__(representation, modality_name=name)
3838
self.modalities = modalities
39-
39+
4040
def combine(self):
4141
"""
4242
Initiates the call to fuse the given modalities depending on the Fusion type
4343
"""
4444
self.data = self.representation.fuse(self.modalities) # noqa
45+
46+
def get_modality_names(self):
47+
names = []
48+
for modality in self.modalities:
49+
names.append(modality.name)
50+
51+
return names

src/main/python/systemds/scuro/modality/audio_modality.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
# -------------------------------------------------------------
2121
import os
2222

23-
from modality.modality import Modality
24-
from representations.unimodal import UnimodalRepresentation
23+
from systemds.scuro.modality.modality import Modality
24+
from systemds.scuro.representations.unimodal import UnimodalRepresentation
2525

2626

2727
class AudioModality(Modality):

src/main/python/systemds/scuro/modality/modality.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#
2020
# -------------------------------------------------------------
2121

22-
from representations.representation import Representation
22+
from systemds.scuro.representations.representation import Representation
2323

2424

2525
class Modality:

src/main/python/systemds/scuro/modality/text_modality.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
# -------------------------------------------------------------
2121
import os
2222

23-
from modality.modality import Modality
24-
from representations.unimodal import UnimodalRepresentation
23+
from systemds.scuro.modality.modality import Modality
24+
from systemds.scuro.representations.unimodal import UnimodalRepresentation
2525

2626

2727
class TextModality(Modality):

src/main/python/systemds/scuro/modality/video_modality.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
# -------------------------------------------------------------
2121
import os
2222

23-
from modality.modality import Modality
24-
from representations.unimodal import UnimodalRepresentation
23+
from systemds.scuro.modality.modality import Modality
24+
from systemds.scuro.representations.unimodal import UnimodalRepresentation
2525

2626

2727
class VideoModality(Modality):

0 commit comments

Comments
 (0)