Skip to content

Commit 496a22f

Browse files
[SYSTEMDS-3835] Scuro window aggregation operator
This patch adds a window aggregation operator that supports mean, min, max, and sum aggregation functions. The window aggregation is applied to the individual modalities and can handle multiple subtypes of modalities. This PR also adds tests to verify the correctness of the operator. Closes #2225
1 parent b49c5ff commit 496a22f

32 files changed

+825
-236
lines changed

.github/workflows/python.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,9 @@ jobs:
116116
h5py \
117117
gensim \
118118
black \
119-
opt-einsum
120-
119+
opt-einsum \
120+
nltk
121+
121122
- name: Build Python Package
122123
run: |
123124
cd src/main/python

src/main/python/systemds/scuro/dataloader/audio_loader.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,18 @@
2222

2323
import librosa
2424
from systemds.scuro.dataloader.base_loader import BaseLoader
25-
from systemds.scuro.utils.schema_helpers import create_timestamps
25+
from systemds.scuro.modality.type import ModalityType
2626

2727

2828
class AudioLoader(BaseLoader):
2929
def __init__(
30-
self,
31-
source_path: str,
32-
indices: List[str],
33-
chunk_size: Optional[int] = None,
30+
self, source_path: str, indices: List[str], chunk_size: Optional[int] = None
3431
):
35-
super().__init__(source_path, indices, chunk_size)
32+
super().__init__(source_path, indices, chunk_size, ModalityType.AUDIO)
3633

3734
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
3835
self.file_sanity_check(file)
3936
audio, sr = librosa.load(file)
40-
self.metadata[file] = {"sample_rate": sr, "length": audio.shape[0]}
41-
self.metadata[file]["timestamp"] = create_timestamps(
42-
self.metadata[file]["sample_rate"], self.metadata[file]["length"]
43-
)
37+
self.metadata[file] = self.modality_type.create_audio_metadata(sr, audio)
38+
4439
self.data.append(audio)

src/main/python/systemds/scuro/dataloader/base_loader.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525

2626
class BaseLoader(ABC):
2727
def __init__(
28-
self, source_path: str, indices: List[str], chunk_size: Optional[int] = None
28+
self,
29+
source_path: str,
30+
indices: List[str],
31+
chunk_size: Optional[int] = None,
32+
modality_type=None,
2933
):
3034
"""
3135
Base class to load raw data for a given list of indices and stores them in the data object
@@ -40,6 +44,7 @@ def __init__(
4044
) # TODO: check what the index should be for storing the metadata (file_name, counter, ...)
4145
self.source_path = source_path
4246
self.indices = indices
47+
self.modality_type = modality_type
4348
self._next_chunk = 0
4449
self._num_chunks = 1
4550
self._chunk_size = None
@@ -64,6 +69,11 @@ def num_chunks(self):
6469
def next_chunk(self):
6570
return self._next_chunk
6671

72+
def reset(self):
73+
self._next_chunk = 0
74+
self.data = []
75+
self.metadata = {}
76+
6777
def load(self):
6878
"""
6979
Takes care of loading the raw data either chunk wise (if chunk size is defined) or all at once

src/main/python/systemds/scuro/dataloader/json_loader.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# -------------------------------------------------------------
2121
import json
2222

23+
from systemds.scuro.modality.type import ModalityType
2324
from systemds.scuro.dataloader.base_loader import BaseLoader
2425
from typing import Optional, List, Union
2526

@@ -32,12 +33,16 @@ def __init__(
3233
field: str,
3334
chunk_size: Optional[int] = None,
3435
):
35-
super().__init__(source_path, indices, chunk_size)
36+
super().__init__(source_path, indices, chunk_size, ModalityType.TEXT)
3637
self.field = field
3738

3839
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
3940
self.file_sanity_check(file)
4041
with open(file) as f:
4142
json_file = json.load(f)
4243
for idx in index:
43-
self.data.append(json_file[idx][self.field])
44+
sentence = json_file[idx][self.field]
45+
self.data.append(sentence)
46+
self.metadata[idx] = self.modality_type.create_text_metadata(
47+
len(sentence), sentence
48+
)

src/main/python/systemds/scuro/dataloader/text_loader.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# -------------------------------------------------------------
2121
from systemds.scuro.dataloader.base_loader import BaseLoader
2222
from typing import Optional, Pattern, List, Union
23+
from systemds.scuro.modality.type import ModalityType
2324
import re
2425

2526

@@ -31,7 +32,7 @@ def __init__(
3132
chunk_size: Optional[int] = None,
3233
prefix: Optional[Pattern[str]] = None,
3334
):
34-
super().__init__(source_path, indices, chunk_size)
35+
super().__init__(source_path, indices, chunk_size, ModalityType.TEXT)
3536
self.prefix = prefix
3637

3738
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
@@ -41,5 +42,7 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
4142
if self.prefix:
4243
line = re.sub(self.prefix, "", line)
4344
line = line.replace("\n", "")
44-
self.metadata[file] = {"length": len(line.split())}
45+
self.metadata[file] = self.modality_type.create_text_metadata(
46+
len(line.split()), line
47+
)
4548
self.data.append(line)

src/main/python/systemds/scuro/dataloader/video_loader.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
import numpy as np
2424

2525
from systemds.scuro.dataloader.base_loader import BaseLoader
26-
from systemds.scuro.utils.schema_helpers import create_timestamps
2726
import cv2
27+
from systemds.scuro.modality.type import ModalityType
2828

2929

3030
class VideoLoader(BaseLoader):
@@ -34,7 +34,7 @@ def __init__(
3434
indices: List[str],
3535
chunk_size: Optional[int] = None,
3636
):
37-
super().__init__(source_path, indices, chunk_size)
37+
super().__init__(source_path, indices, chunk_size, ModalityType.VIDEO)
3838

3939
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
4040
self.file_sanity_check(file)
@@ -43,16 +43,14 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
4343
if not cap.isOpened():
4444
raise f"Could not read video at path: {file}"
4545

46-
self.metadata[file] = {
47-
"fps": cap.get(cv2.CAP_PROP_FPS),
48-
"length": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
49-
"width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
50-
"height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
51-
"num_channels": 3,
52-
}
46+
fps = cap.get(cv2.CAP_PROP_FPS)
47+
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
48+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
49+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
50+
num_channels = 3
5351

54-
self.metadata[file]["timestamp"] = create_timestamps(
55-
self.metadata[file]["fps"], self.metadata[file]["length"]
52+
self.metadata[file] = self.modality_type.create_video_metadata(
53+
fps, length, width, height, num_channels
5654
)
5755

5856
frames = []

src/main/python/systemds/scuro/modality/joined.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def execute(self, starting_idx=0):
104104
self.joined_right.data[i - starting_idx].append([])
105105
right = np.array([])
106106
if self.condition.join_type == "<":
107-
while c < len(idx_2) and idx_2[c] < nextIdx[j]:
107+
while c < len(idx_2) - 1 and idx_2[c] < nextIdx[j]:
108108
if right.size == 0:
109109
right = self.right_modality.data[i][c]
110110
if right.ndim == 1:
@@ -125,7 +125,7 @@ def execute(self, starting_idx=0):
125125
)
126126
c = c + 1
127127
else:
128-
while c < len(idx_2) and idx_2[c] <= idx_1[j]:
128+
while c < len(idx_2) - 1 and idx_2[c] <= idx_1[j]:
129129
if idx_2[c] == idx_1[j]:
130130
right.append(self.right_modality.data[i][c])
131131
c = c + 1
@@ -141,18 +141,17 @@ def execute(self, starting_idx=0):
141141

142142
self.joined_right.data[i - starting_idx][j] = right
143143

144-
def apply_representation(self, representation, aggregation):
144+
def apply_representation(self, representation, aggregation=None):
145145
self.aggregation = aggregation
146146
if self.chunked_execution:
147147
return self._handle_chunked_execution(representation)
148-
elif self.left_type.__name__.__contains__("Unimodal"):
149-
self.left_modality.extract_raw_data()
150-
if self.left_type == self.right_type:
151-
self.right_modality.extract_raw_data()
152-
elif self.right_type.__name__.__contains__("Unimodal"):
153-
self.right_modality.extract_raw_data()
148+
# elif self.left_type.__name__.__contains__("Unimodal"):
149+
# self.left_modality.extract_raw_data()
150+
# if self.left_type == self.right_type:
151+
# self.right_modality.extract_raw_data()
152+
# elif self.right_type.__name__.__contains__("Unimodal") and not self.right_modality.has_data():
153+
# self.right_modality.extract_raw_data()
154154

155-
self.execute()
156155
left_transformed = self._apply_representation(
157156
self.left_modality, representation
158157
)
@@ -263,12 +262,12 @@ def _apply_representation_chunked(
263262

264263
def _apply_representation(self, modality, representation):
265264
transformed = representation.transform(modality)
266-
if self.aggregation:
267-
aggregated_data_left = self.aggregation.window(transformed)
268-
transformed = Modality(
269-
transformed.modality_type,
270-
transformed.metadata,
271-
)
272-
transformed.data = aggregated_data_left
265+
# if self.aggregation:
266+
# aggregated_data_left = self.aggregation.execute(transformed)
267+
# transformed = Modality(
268+
# transformed.modality_type,
269+
# transformed.metadata,
270+
# )
271+
# transformed.data = aggregated_data_left
273272

274273
return transformed

src/main/python/systemds/scuro/modality/joined_transformed.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from systemds.scuro.modality.modality import Modality
2727
from systemds.scuro.representations.utils import pad_sequences
28+
from systemds.scuro.representations.window import WindowAggregation
2829

2930

3031
class JoinedTransformedModality(Modality):
@@ -68,3 +69,9 @@ def combine(self, fusion_method):
6869
self.data[i] = np.array(r)
6970
self.data = pad_sequences(self.data)
7071
return self
72+
73+
def window(self, window_size, aggregation):
74+
w = WindowAggregation(window_size, aggregation)
75+
self.left_modality.data = w.execute(self.left_modality)
76+
self.right_modality.data = w.execute(self.right_modality)
77+
return self

0 commit comments

Comments
 (0)