Skip to content

Commit abf179a

Browse files
[SYSTEMDS-3913] Combined unimodal representation
This patch adds a combined unimodal representation feature to the unimodal representation optimizer. For example, multiple audio representations can be combined into a single representation using Concatenation, the Hadamard product, or Addition. The unimodal optimization test was adapted to verify that this feature works as intended.
1 parent feafd0e commit abf179a

26 files changed

+919
-163
lines changed

src/main/python/systemds/scuro/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
from systemds.scuro.representations.lstm import LSTM
3939
from systemds.scuro.representations.max import RowMax
4040
from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
41+
from systemds.scuro.representations.multimodal_attention_fusion import (
42+
AttentionFusion,
43+
)
4144
from systemds.scuro.representations.mfcc import MFCC
4245
from systemds.scuro.representations.hadamard import Hadamard
4346
from systemds.scuro.representations.optical_flow import OpticalFlow
@@ -73,6 +76,12 @@
7376
from systemds.scuro.drsearch.unimodal_representation_optimizer import (
7477
UnimodalRepresentationOptimizer,
7578
)
79+
from systemds.scuro.representations.covarep_audio_features import (
80+
RMSE,
81+
Spectral,
82+
ZeroCrossing,
83+
Pitch,
84+
)
7685
from systemds.scuro.drsearch.multimodal_optimizer import MultimodalOptimizer
7786
from systemds.scuro.drsearch.unimodal_optimizer import UnimodalOptimizer
7887

@@ -131,4 +140,9 @@
131140
"UnimodalRepresentationOptimizer",
132141
"UnimodalOptimizer",
133142
"MultimodalOptimizer",
143+
"ZeroCrossing",
144+
"Pitch",
145+
"RMSE",
146+
"Spectral",
147+
"AttentionFusion",
134148
]

src/main/python/systemds/scuro/dataloader/audio_loader.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#
2020
# -------------------------------------------------------------
2121
from typing import List, Optional, Union
22-
2322
import librosa
2423
import numpy as np
2524

src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ def optimize_intramodal_representations(self, task):
9898
],
9999
)
100100

101+
# TODO: check if order matters for reused reps - only compute once - check in cache
102+
# TODO: parallelize - whenever an item of len 0 comes along give it to a new thread - merge results
103+
# TODO: change the algorithm so that one representation is used until there are no more representations to add - saves a lot of memory
101104
def optimize_intermodal_representations(self, task):
102105
modality_combos = []
103106
n = len(self.k_best_cache[task.model.name])
@@ -122,8 +125,7 @@ def generate_extensions(current_combo, remaining_indices):
122125
reuse_fused_representations = False
123126
for i, modality_combo in enumerate(modality_combos):
124127
# clear reuse cache
125-
if i % 5 == 0:
126-
reuse_cache = self.prune_cache(modality_combos[i:], reuse_cache)
128+
reuse_cache = self.prune_cache(modality_combos[i:], reuse_cache)
127129

128130
if i != 0:
129131
reuse_fused_representations = self.is_prefix_match(

src/main/python/systemds/scuro/drsearch/operator_registry.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ def __new__(cls):
4343
cls._representations[m_type] = []
4444
return cls._instance
4545

46+
def set_fusion_operators(self, fusion_operators):
    """Replace the registered fusion operators.

    :param fusion_operators: a list of fusion operators, or a single
        operator, which is wrapped in a one-element list
    """
    if isinstance(fusion_operators, list):
        # Both branches must write to _fusion_operators; the list branch
        # used to overwrite _context_operators by mistake.
        self._fusion_operators = fusion_operators
    else:
        self._fusion_operators = [fusion_operators]
51+
4652
def add_representation(
4753
self, representation: Representation, modality: ModalityType
4854
):
@@ -57,6 +63,13 @@ def add_fusion_operator(self, fusion_operator):
5763
def get_representations(self, modality: ModalityType):
5864
return self._representations[modality]
5965

66+
def get_not_self_contained_representations(self, modality: ModalityType):
    """Return the representations of ``modality`` that are not self-contained.

    :param modality: the modality type whose registered representations are filtered
    :return: list of the registered entries whose instance has a falsy
        ``self_contained`` attribute, in registration order
    """
    # Each entry is called without arguments to obtain an instance, because
    # self_contained is read from the instance rather than from the entry.
    return [
        candidate
        for candidate in self.get_representations(modality)
        if not candidate().self_contained
    ]
72+
6073
def get_context_operators(self):
6174
# TODO: return modality specific context operations
6275
return self._context_operators

src/main/python/systemds/scuro/drsearch/task.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
#
2020
# -------------------------------------------------------------
2121
import time
22-
from typing import List
22+
from typing import List, Union
2323

24+
from systemds.scuro.modality.modality import Modality
25+
from systemds.scuro.representations.representation import Representation
2426
from systemds.scuro.models.model import Model
2527
import numpy as np
2628
from sklearn.model_selection import KFold
@@ -57,6 +59,8 @@ def __init__(
5759
self.inference_time = []
5860
self.training_time = []
5961
self.expected_dim = 1
62+
self.train_scores = []
63+
self.val_scores = []
6064

6165
def get_train_test_split(self, data):
6266
X_train = [data[i] for i in self.train_indices]
@@ -73,28 +77,69 @@ def run(self, data):
7377
:param data: The aligned data used in the prediction process
7478
:return: the validation accuracy
7579
"""
80+
self._reset_params()
81+
skf = KFold(n_splits=self.kfold, shuffle=True, random_state=11)
82+
83+
fold = 0
84+
X, y, _, _ = self.get_train_test_split(data)
85+
86+
for train, test in skf.split(X, y):
87+
train_X = np.array(X)[train]
88+
train_y = np.array(y)[train]
89+
test_X = np.array(X)[test]
90+
test_y = np.array(y)[test]
91+
self._run_fold(train_X, train_y, test_X, test_y)
92+
fold += 1
93+
94+
if self.measure_performance:
95+
self.inference_time = np.mean(self.inference_time)
96+
self.training_time = np.mean(self.training_time)
97+
98+
return [np.mean(self.train_scores), np.mean(self.val_scores)]
99+
100+
def _reset_params(self):
76101
self.inference_time = []
77102
self.training_time = []
103+
self.train_scores = []
104+
self.val_scores = []
105+
106+
def _run_fold(self, train_X, train_y, test_X, test_y):
107+
train_start = time.time()
108+
train_score = self.model.fit(train_X, train_y, test_X, test_y)
109+
train_end = time.time()
110+
self.training_time.append(train_end - train_start)
111+
self.train_scores.append(train_score)
112+
test_start = time.time()
113+
test_score = self.model.test(np.array(test_X), test_y)
114+
test_end = time.time()
115+
self.inference_time.append(test_end - test_start)
116+
self.val_scores.append(test_score)
117+
118+
def create_representation_and_run(
119+
self,
120+
representation: Representation,
121+
modalities: Union[List[Modality], Modality],
122+
):
123+
self._reset_params()
78124
skf = KFold(n_splits=self.kfold, shuffle=True, random_state=11)
79-
train_scores = []
80-
test_scores = []
125+
81126
fold = 0
82-
X, y, X_test, y_test = self.get_train_test_split(data)
127+
X, y, _, _ = self.get_train_test_split(data)
83128

84129
for train, test in skf.split(X, y):
85130
train_X = np.array(X)[train]
86131
train_y = np.array(y)[train]
87-
train_start = time.time()
88-
train_score = self.model.fit(train_X, train_y, X_test, y_test)
89-
train_end = time.time()
90-
self.training_time.append(train_end - train_start)
91-
train_scores.append(train_score)
92-
test_start = time.time()
93-
test_score = self.model.test(np.array(X_test), y_test)
94-
test_end = time.time()
95-
self.inference_time.append(test_end - test_start)
96-
test_scores.append(test_score)
132+
test_X = s.transform(np.array(X)[test])
133+
test_y = np.array(y)[test]
134+
135+
if isinstance(modalities, Modality):
136+
rep = modality.apply_representation(representation())
137+
else:
138+
representation().transform(
139+
train_X, train_y
140+
) # TODO: think about a way how to handle masks
97141

142+
self._run_fold(train_X, train_y, test_X, test_y)
98143
fold += 1
99144

100145
if self.measure_performance:

src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py

Lines changed: 68 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# -------------------------------------------------------------
2121
import pickle
2222
import time
23+
import copy
2324
from concurrent.futures import ProcessPoolExecutor, as_completed
2425
from dataclasses import dataclass, field, asdict
2526

@@ -28,11 +29,17 @@
2829

2930
import numpy as np
3031
from systemds.scuro.representations.window_aggregation import WindowAggregation
32+
from systemds.scuro.representations.concatenation import Concatenation
33+
from systemds.scuro.representations.hadamard import Hadamard
34+
from systemds.scuro.representations.sum import Sum
3135

3236
from systemds.scuro.representations.aggregated_representation import (
3337
AggregatedRepresentation,
3438
)
35-
from systemds.scuro import ModalityType, Aggregation
39+
from systemds.scuro.modality.type import ModalityType
40+
from systemds.scuro.modality.modality import Modality
41+
from systemds.scuro.modality.transformed import TransformedModality
42+
from systemds.scuro.representations.aggregate import Aggregation
3643
from systemds.scuro.drsearch.operator_registry import Registry
3744
from systemds.scuro.utils.schema_helpers import get_shape
3845

@@ -84,7 +91,6 @@ def optimize_parallel(self, n_workers=None):
8491
def optimize(self):
    """Process every modality sequentially (non-parallel mode)."""
    for modality in self.modalities:
        # The return value is intentionally not bound: nothing in this
        # method consumes it, so the former local was dead code.
        self._process_modality(modality, False)
87-
# self._merge_results(local_result)
8894

8995
def _process_modality(self, modality, parallel):
9096
if parallel:
@@ -95,43 +101,59 @@ def _process_modality(self, modality, parallel):
95101
local_results = self.operator_performance
96102

97103
context_operators = self.operator_registry.get_context_operators()
98-
99-
for context_operator in context_operators:
100-
context_representation = None
101-
if (
102-
modality.modality_type != ModalityType.TEXT
103-
and modality.modality_type != ModalityType.VIDEO
104-
):
105-
con_op = context_operator()
106-
context_representation = modality.context(con_op)
107-
self._evaluate_local(context_representation, [con_op], local_results)
108-
109-
modality_specific_operators = self.operator_registry.get_representations(
104+
not_self_contained_reps = (
105+
self.operator_registry.get_not_self_contained_representations(
110106
modality.modality_type
111107
)
112-
for modality_specific_operator in modality_specific_operators:
113-
mod_context = None
114-
mod_op = modality_specific_operator()
115-
if context_representation is not None:
116-
mod_context = context_representation.apply_representation(mod_op)
117-
self._evaluate_local(mod_context, [con_op, mod_op], local_results)
118-
119-
mod = modality.apply_representation(mod_op)
120-
self._evaluate_local(mod, [mod_op], local_results)
121-
122-
for context_operator_after in context_operators:
123-
con_op_after = context_operator_after()
124-
if mod_context is not None:
125-
mod_context = mod_context.context(con_op_after)
126-
self._evaluate_local(
127-
mod_context, [con_op, mod_op, con_op_after], local_results
128-
)
129-
130-
mod = mod.context(con_op_after)
131-
self._evaluate_local(mod, [mod_op, con_op_after], local_results)
108+
)
109+
modality_specific_operators = self.operator_registry.get_representations(
110+
modality.modality_type
111+
)
112+
for modality_specific_operator in modality_specific_operators:
113+
mod_op = modality_specific_operator()
114+
115+
mod = modality.apply_representation(mod_op)
116+
self._evaluate_local(mod, [mod_op], local_results)
117+
118+
if not mod_op.self_contained:
119+
self._combine_non_self_contained_representations(
120+
modality, mod, not_self_contained_reps, local_results
121+
)
122+
123+
for context_operator_after in context_operators:
124+
con_op_after = context_operator_after()
125+
mod = mod.context(con_op_after)
126+
self._evaluate_local(mod, [mod_op, con_op_after], local_results)
132127

133128
return local_results
134129

130+
def _combine_non_self_contained_representations(
    self,
    modality: Modality,
    representation: TransformedModality,
    other_representations,
    local_results,
):
    """Combine a representation with other representations and evaluate each result.

    For every entry of ``other_representations`` and every combination
    operator (Concatenation, Hadamard, Sum), the running ``combined`` modality
    is combined with that entry applied to ``modality``. The result is
    evaluated, then evaluated again after applying each context operator.

    :param modality: the modality the other representations are applied to
    :param representation: the already transformed modality used as the start
    :param other_representations: callables that create a representation
        when called without arguments
    :param local_results: result container forwarded to ``_evaluate_local``
    """
    combined = representation
    context_operators = self.operator_registry.get_context_operators()
    # NOTE(review): this binds the same object as representation.transformation
    # (no copy), so the append below also changes the caller's object if it is
    # a list - confirm this is intended.
    used_representations = representation.transformation
    for other_representation in other_representations:
        used_representations.append(other_representation())
        for combination in [Concatenation(), Hadamard(), Sum()]:
            # combined is reassigned and never reset, so each combination is
            # applied on top of the result of all previous ones.
            # NOTE(review): confirm the cumulative behaviour is intended.
            combined = combined.combine(
                modality.apply_representation(other_representation()), combination
            )
            self._evaluate_local(
                combined, used_representations, local_results, combination
            )

            for context_op in context_operators:
                con_op = context_op()
                mod = combined.context(con_op)
                # Deep copy so the context operator is only appended to the
                # list passed to this evaluation, not to used_representations.
                c_t = copy.deepcopy(used_representations)
                c_t.append(con_op)
                self._evaluate_local(mod, c_t, local_results, combination)
156+
135157
def _merge_results(self, local_results):
136158
"""Merge local results into the main results"""
137159
for modality_id in local_results.results:
@@ -145,7 +167,9 @@ def _merge_results(self, local_results):
145167
for key, value in local_results.cache[modality][task_name].items():
146168
self.operator_performance.cache[modality][task_name][key] = value
147169

148-
def _evaluate_local(self, modality, representations, local_results):
170+
def _evaluate_local(
171+
self, modality, representations, local_results, combination=None
172+
):
149173
if self._tasks_require_same_dims:
150174
if self.expected_dimensions == 1 and get_shape(modality.metadata) > 1:
151175
# for aggregation in Aggregation().get_aggregation_functions():
@@ -165,6 +189,7 @@ def _evaluate_local(self, modality, representations, local_results):
165189
modality,
166190
task.model.name,
167191
end - start,
192+
combination,
168193
)
169194
else:
170195
modality.pad()
@@ -178,6 +203,7 @@ def _evaluate_local(self, modality, representations, local_results):
178203
modality,
179204
task.model.name,
180205
end - start,
206+
combination,
181207
)
182208
else:
183209
for task in self.tasks:
@@ -198,6 +224,7 @@ def _evaluate_local(self, modality, representations, local_results):
198224
modality,
199225
task.model.name,
200226
end - start,
227+
combination,
201228
)
202229
else:
203230
# modality.pad()
@@ -210,6 +237,7 @@ def _evaluate_local(self, modality, representations, local_results):
210237
modality,
211238
task.model.name,
212239
end - start,
240+
combination,
213241
)
214242

215243

@@ -228,7 +256,9 @@ def __init__(self, modalities, tasks, debug=False):
228256
self.cache[modality][task_name] = {}
229257
self.results[modality][task_name] = []
230258

231-
def add_result(self, scores, representations, modality, task_name, task_time):
259+
def add_result(
260+
self, scores, representations, modality, task_name, task_time, combination
261+
):
232262
parameters = []
233263
representation_names = []
234264

@@ -256,6 +286,7 @@ def add_result(self, scores, representations, modality, task_name, task_time):
256286
val_score=scores[1],
257287
representation_time=modality.transform_time,
258288
task_time=task_time,
289+
combination=combination.name if combination else "",
259290
)
260291
self.results[modality.modality_id][task_name].append(entry)
261292
self.cache[modality.modality_id][task_name][
@@ -302,3 +333,4 @@ class ResultEntry:
302333
train_score: float
303334
representation_time: float
304335
task_time: float
336+
combination: str

src/main/python/systemds/scuro/modality/modality.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,6 @@ def is_aligned(self, other_modality):
168168
!= list(other_modality.metadata.values())[i]["data_layout"]["shape"]
169169
):
170170
aligned = False
171-
continue
171+
break
172172

173173
return aligned

0 commit comments

Comments
 (0)