Skip to content

Commit 5dfa26f

Browse files
[SYSTEMDS-3835] Add additional context operators
This patch adds two additional context operators to Scuro. The first one is a StaticWindow operator that, given a number of desired windows, defines the suitable window size and aggregates a sequence into num_window features. The second context operator is a DynamicWindow where a sequence is also aggregated into num_window features with the difference that the window size for more recent data points is smaller than the window size for more historic data points in a timeseries.
1 parent abf179a commit 5dfa26f

File tree

7 files changed

+159
-29
lines changed

7 files changed

+159
-29
lines changed

src/main/python/systemds/scuro/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@
5555
from systemds.scuro.representations.tfidf import TfIdf
5656
from systemds.scuro.representations.unimodal import UnimodalRepresentation
5757
from systemds.scuro.representations.wav2vec import Wav2Vec
58-
from systemds.scuro.representations.window_aggregation import WindowAggregation
58+
from systemds.scuro.representations.window_aggregation import (
59+
WindowAggregation,
60+
DynamicWindow,
61+
StaticWindow,
62+
)
5963
from systemds.scuro.representations.word2vec import W2V
6064
from systemds.scuro.representations.x3d import X3D
6165
from systemds.scuro.models.model import Model
@@ -145,4 +149,6 @@
145149
"RMSE",
146150
"Spectral",
147151
"AttentionFusion",
152+
"DynamicWindow",
153+
"StaticWindow",
148154
]

src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def _process_modality(self, modality, parallel):
122122

123123
for context_operator_after in context_operators:
124124
con_op_after = context_operator_after()
125-
mod = mod.context(con_op_after)
126-
self._evaluate_local(mod, [mod_op, con_op_after], local_results)
125+
mod_con = mod.context(con_op_after)
126+
self._evaluate_local(mod_con, [mod_op, con_op_after], local_results)
127127

128128
return local_results
129129

src/main/python/systemds/scuro/representations/fusion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def get_max_embedding_size(self, modalities: List[Modality]):
105105
curr_shape = modalities[idx].data[0].shape
106106
if len(modalities[idx - 1].data) != len(modalities[idx].data):
107107
raise f"Modality sizes don't match!"
108+
elif len(curr_shape) == 1:
109+
continue
108110
elif curr_shape[1] > max_size:
109111
max_size = curr_shape[1]
110112

src/main/python/systemds/scuro/representations/window_aggregation.py

Lines changed: 90 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
# under the License.
1919
#
2020
# -------------------------------------------------------------
21+
import copy
22+
2123
import numpy as np
2224
import math
2325

@@ -28,17 +30,13 @@
2830
from systemds.scuro.representations.context import Context
2931

3032

31-
@register_context_operator()
32-
class WindowAggregation(Context):
33-
def __init__(self, window_size=10, aggregation_function="mean", pad=True):
33+
class Window(Context):
34+
def __init__(self, name, aggregation_function):
3435
parameters = {
35-
"window_size": [window_size],
3636
"aggregation_function": list(Aggregation().get_aggregation_functions()),
37-
} # TODO: window_size should be dynamic and adapted to the shape of the data
38-
super().__init__("WindowAggregation", parameters)
39-
self.window_size = window_size
37+
}
38+
super().__init__(name, parameters)
4039
self.aggregation_function = aggregation_function
41-
self.pad = pad
4240

4341
@property
4442
def aggregation_function(self):
@@ -48,6 +46,15 @@ def aggregation_function(self):
4846
def aggregation_function(self, value):
4947
self._aggregation_function = Aggregation(value)
5048

49+
50+
@register_context_operator()
51+
class WindowAggregation(Window):
52+
def __init__(self, window_size=10, aggregation_function="mean", pad=False):
53+
super().__init__("WindowAggregation", aggregation_function)
54+
self.parameters["window_size"] = [window_size]
55+
self.window_size = window_size
56+
self.pad = pad
57+
5158
def execute(self, modality):
5259
windowed_data = []
5360
original_lengths = []
@@ -107,24 +114,90 @@ def execute(self, modality):
107114
def window_aggregate_single_level(self, instance, new_length):
108115
if isinstance(instance, str):
109116
return instance
110-
instance = np.array(instance)
111-
num_cols = instance.shape[1] if instance.ndim > 1 else 1
112-
result = np.empty((new_length, num_cols))
117+
instance = np.array(copy.deepcopy(instance))
118+
119+
result = []
113120
for i in range(0, new_length):
114-
result[i] = self.aggregation_function.aggregate_instance(
115-
instance[i * self.window_size : i * self.window_size + self.window_size]
121+
result.append(
122+
self.aggregation_function.aggregate_instance(
123+
instance[
124+
i * self.window_size : i * self.window_size + self.window_size
125+
]
126+
)
116127
)
117128

118-
if num_cols == 1:
119-
result = result.reshape(-1)
120-
return result
129+
return np.array(result)
121130

122131
def window_aggregate_nested_level(self, instance, new_length):
123132
result = [[] for _ in range(0, new_length)]
124-
data = np.stack(instance)
133+
data = np.stack(copy.deepcopy(instance))
125134
for i in range(0, new_length):
126135
result[i] = self.aggregation_function.aggregate_instance(
127136
data[i * self.window_size : i * self.window_size + self.window_size]
128137
)
129138

130139
return np.array(result)
140+
141+
142+
@register_context_operator()
143+
class StaticWindow(Window):
144+
def __init__(self, num_windows=100, aggregation_function="mean"):
145+
super().__init__("StaticWindow", aggregation_function)
146+
self.parameters["num_windows"] = [num_windows]
147+
self.num_windows = num_windows
148+
149+
def execute(self, modality):
150+
windowed_data = []
151+
152+
for instance in modality.data:
153+
window_size = len(instance) // self.num_windows
154+
remainder = len(instance) % self.num_windows
155+
output = []
156+
start = 0
157+
for i in range(0, self.num_windows):
158+
extra = 1 if i < remainder else 0
159+
end = start + window_size + extra
160+
window = copy.deepcopy(instance[start:end])
161+
val = (
162+
self.aggregation_function.aggregate_instance(window)
163+
if len(window) > 0
164+
else np.zeros_like(output[i - 1])
165+
)
166+
output.append(val)
167+
start = end
168+
169+
windowed_data.append(output)
170+
return np.array(windowed_data)
171+
172+
173+
@register_context_operator()
174+
class DynamicWindow(Window):
175+
def __init__(self, num_windows=100, aggregation_function="mean"):
176+
super().__init__("DynamicWindow", aggregation_function)
177+
self.parameters["num_windows"] = [num_windows]
178+
self.num_windows = num_windows
179+
180+
def execute(self, modality):
181+
windowed_data = []
182+
183+
for instance in modality.data:
184+
N = len(instance)
185+
weights = np.geomspace(4, 256, num=self.num_windows)
186+
weights = weights / np.sum(weights)
187+
window_sizes = (weights * N).astype(int)
188+
window_sizes[-1] += N - np.sum(window_sizes)
189+
indices = np.cumsum(window_sizes)
190+
output = []
191+
start = 0
192+
for end in indices:
193+
window = copy.deepcopy(instance[start:end])
194+
val = (
195+
self.aggregation_function.aggregate_instance(window)
196+
if len(window) > 0
197+
else np.zeros_like(instance[0])
198+
)
199+
output.append(val)
200+
start = end
201+
windowed_data.append(output)
202+
203+
return np.array(windowed_data)

src/main/python/tests/scuro/test_operator_registry.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@
3030
from systemds.scuro.representations.mfcc import MFCC
3131
from systemds.scuro.representations.swin_video_transformer import SwinVideoTransformer
3232
from systemds.scuro.representations.wav2vec import Wav2Vec
33-
from systemds.scuro.representations.window_aggregation import WindowAggregation
33+
from systemds.scuro.representations.window_aggregation import (
34+
WindowAggregation,
35+
StaticWindow,
36+
DynamicWindow,
37+
)
3438
from systemds.scuro.representations.bow import BoW
3539
from systemds.scuro.representations.word2vec import W2V
3640
from systemds.scuro.representations.tfidf import TfIdf
@@ -83,7 +87,11 @@ def test_text_representations_in_registry(self):
8387

8488
def test_context_operator_in_registry(self):
8589
registry = Registry()
86-
assert registry.get_context_operators() == [WindowAggregation]
90+
assert registry.get_context_operators() == [
91+
WindowAggregation,
92+
StaticWindow,
93+
DynamicWindow,
94+
]
8795

8896
# def test_fusion_operator_in_registry(self):
8997
# registry = Registry()

src/main/python/tests/scuro/test_unimodal_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def setUpClass(cls):
141141

142142
def test_unimodal_optimizer_for_audio_modality(self):
143143
audio_data, audio_md = ModalityRandomDataGenerator().create_audio_data(
144-
self.num_instances, 100
144+
self.num_instances, 3000
145145
)
146146
audio = UnimodalModality(
147147
TestDataLoader(

src/main/python/tests/scuro/test_window_operations.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,13 @@
2424

2525
import numpy as np
2626

27-
from tests.scuro.data_generator import ModalityRandomDataGenerator
27+
from tests.scuro.data_generator import ModalityRandomDataGenerator, TestDataLoader
2828
from systemds.scuro.modality.type import ModalityType
29+
from systemds.scuro.modality.unimodal_modality import UnimodalModality
30+
from systemds.scuro.representations.window_aggregation import (
31+
StaticWindow,
32+
DynamicWindow,
33+
)
2934

3035

3136
class TestWindowOperations(unittest.TestCase):
@@ -35,20 +40,56 @@ def setUpClass(cls):
3540
cls.data_generator = ModalityRandomDataGenerator()
3641
cls.aggregations = ["mean", "sum", "max", "min"]
3742

38-
def test_window_operations_on_audio_representations(self):
43+
def test_static_window(self):
44+
num_windows = 5
45+
data, md = self.data_generator.create_visual_modality(self.num_instances, 50)
46+
modality = UnimodalModality(
47+
TestDataLoader(
48+
[i for i in range(0, self.num_instances)],
49+
None,
50+
ModalityType.VIDEO,
51+
data,
52+
np.float32,
53+
md,
54+
)
55+
)
56+
aggregated_window = modality.context(StaticWindow(num_windows))
57+
58+
for i in range(0, self.num_instances):
59+
assert len(aggregated_window.data[i]) == num_windows
60+
61+
def test_dynamic_window(self):
62+
num_windows = 5
63+
data, md = self.data_generator.create_visual_modality(self.num_instances, 50)
64+
modality = UnimodalModality(
65+
TestDataLoader(
66+
[i for i in range(0, self.num_instances)],
67+
None,
68+
ModalityType.VIDEO,
69+
data,
70+
np.float32,
71+
md,
72+
)
73+
)
74+
aggregated_window = modality.context(DynamicWindow(num_windows))
75+
76+
for i in range(0, self.num_instances):
77+
assert len(aggregated_window.data[i]) == num_windows
78+
79+
def test_window_aggregation_on_audio_representations(self):
3980
window_size = 10
40-
self.run_window_operations_for_modality(ModalityType.AUDIO, window_size)
81+
self.run_window_aggregation_for_modality(ModalityType.AUDIO, window_size)
4182

4283
def test_window_operations_on_video_representations(self):
4384
window_size = 10
44-
self.run_window_operations_for_modality(ModalityType.VIDEO, window_size)
85+
self.run_window_aggregation_for_modality(ModalityType.VIDEO, window_size)
4586

4687
def test_window_operations_on_text_representations(self):
4788
window_size = 10
4889

49-
self.run_window_operations_for_modality(ModalityType.TEXT, window_size)
90+
self.run_window_aggregation_for_modality(ModalityType.TEXT, window_size)
5091

51-
def run_window_operations_for_modality(self, modality_type, window_size):
92+
def run_window_aggregation_for_modality(self, modality_type, window_size):
5293
r = self.data_generator.create1DModality(40, 100, modality_type)
5394
for aggregation in self.aggregations:
5495
windowed_modality = r.window_aggregation(window_size, aggregation)

0 commit comments

Comments
 (0)