Skip to content

Commit 4dd5ab3

Browse files
christinadionysio authored and Baunsgaard committed
[SYSTEMDS-3830] Add join operator to Scuro
This patch adds a new join operator to Scuro. The join operation takes two modalities as well as a join condition as input and joins the two modalities on their common dimension (temporal for now). This includes two new modalities and the ability to apply new representations on top of a joined modality. In the future the join operator will also serve as a simple alignment operator by joining two modalities by a given offset. Closes #2220
1 parent 9090d40 commit 4dd5ab3

30 files changed

+1362
-224
lines changed

.github/workflows/python.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,10 @@ jobs:
114114
torch \
115115
librosa \
116116
h5py \
117-
nltk \
118117
gensim \
119-
black
120-
118+
black \
119+
opt-einsum
120+
121121
- name: Build Python Package
122122
run: |
123123
cd src/main/python

src/main/python/systemds/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@
2626
__all__ = ["context", "operator", "examples"]
2727

2828
required_packages = [
29-
("torch", "2.4.1"),
30-
("torchvision", "0.19.1"),
29+
("torch", "2.5.1"),
30+
("torchvision", "0.20.1"),
3131
("librosa", "0.10.2"),
3232
("opencv-python", "4.10.0.84"),
3333
("opt-einsum", "3.3.0"),
3434
("h5py", "3.11.0"),
3535
("transformers", "4.46.3"),
36+
("nltk", "3.9.1"),
3637
("gensim", "4.3.3"),
3738
]
3839

src/main/python/systemds/scuro/dataloader/audio_loader.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818
# under the License.
1919
#
2020
# -------------------------------------------------------------
21-
from typing import List, Optional
21+
from typing import List, Optional, Union
2222

2323
import librosa
2424
from systemds.scuro.dataloader.base_loader import BaseLoader
25+
from systemds.scuro.utils.schema_helpers import create_timestamps
2526

2627

2728
class AudioLoader(BaseLoader):
@@ -33,7 +34,11 @@ def __init__(
3334
):
3435
super().__init__(source_path, indices, chunk_size)
3536

36-
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
    """Decode one audio file, record its metadata, and append the samples.

    :param file: path of the audio file to load via librosa
    :param index: unused here; kept for the shared BaseLoader.extract interface
    """
    self.file_sanity_check(file)
    audio, sr = librosa.load(file)
    # Per-file metadata: sampling rate, sample count, and one timestamp
    # per sample derived from both.
    meta = {"sample_rate": sr, "length": audio.shape[0]}
    meta["timestamp"] = create_timestamps(meta["sample_rate"], meta["length"])
    self.metadata[file] = meta
    self.data.append(audio)

src/main/python/systemds/scuro/dataloader/base_loader.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,32 +35,68 @@ def __init__(
3535
(otherwise please provide your own Dataloader that knows about the file name convention)
3636
"""
3737
self.data = []
38+
self.metadata = (
39+
{}
40+
) # TODO: check what the index should be for storing the metadata (file_name, counter, ...)
3841
self.source_path = source_path
3942
self.indices = indices
40-
self.chunk_size = chunk_size
41-
self.next_chunk = 0
43+
self._next_chunk = 0
44+
self._num_chunks = 1
45+
self._chunk_size = None
4246

43-
if self.chunk_size:
44-
self.num_chunks = int(len(self.indices) / self.chunk_size)
47+
if chunk_size:
48+
self.chunk_size = chunk_size
49+
50+
@property
def chunk_size(self):
    """Number of indices loaded per chunk; None when chunked loading is off."""
    return self._chunk_size
53+
54+
@chunk_size.setter
def chunk_size(self, value):
    """Set the chunk size and recompute how many chunks the indices span.

    Uses ceiling division so that a trailing partial chunk is counted;
    the previous int(len/size) truncation silently dropped the remainder,
    so indices past the last full chunk were never loaded.

    :param value: number of indices per chunk (must be a positive int)
    """
    self._chunk_size = value
    # -(-a // b) is ceiling division for positive ints, no math import needed.
    self._num_chunks = -(-len(self.indices) // self._chunk_size)
58+
59+
@property
def num_chunks(self):
    """Total number of chunks the index list is split into (read-only)."""
    return self._num_chunks
62+
63+
@property
def next_chunk(self):
    """Cursor of the next chunk to be loaded (read-only)."""
    return self._next_chunk
4566

4667
def load(self):
    """Load the raw data.

    Loads chunk-wise when a chunk size is configured, otherwise loads
    everything referenced by self.indices in one pass.
    """
    if not self._chunk_size:
        return self._load(self.indices)
    return self._load_next_chunk()
5475

76+
def update_chunk_sizes(self, other):
    """Harmonize the chunk sizes of this loader and *other*.

    After the call both loaders use the smaller of the two chunk sizes;
    if only one of them is chunked, its chunk size is copied to the
    other. No-op when neither loader is chunked.

    :param other: another loader exposing a ``chunk_size`` property
    """
    if not self._chunk_size and not other.chunk_size:
        return

    if not self._chunk_size:
        # Only the other loader is chunked: adopt its chunk size. The
        # original fell through to `self._chunk_size < other.chunk_size`
        # here, which raises TypeError (None < int) in Python 3.
        self.chunk_size = other.chunk_size
    elif not other.chunk_size or self._chunk_size < other.chunk_size:
        other.chunk_size = self.chunk_size
    else:
        self.chunk_size = other.chunk_size
88+
5589
def _load_next_chunk(self):
    """Load the next chunk of data and advance the chunk cursor."""
    self.data = []
    start = self._next_chunk * self._chunk_size
    stop = start + self._chunk_size
    self._next_chunk += 1
    return self._load(self.indices[start:stop])
65101

66102
def _load(self, indices: List[str]):
@@ -73,13 +109,14 @@ def _load(self, indices: List[str]):
73109
else:
74110
self.extract(self.source_path, indices)
75111

76-
return self.data
112+
return self.data, self.metadata
77113

78114
@abstractmethod
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
    """Extract raw data (and metadata) from *file* into self.data/self.metadata.

    :param file: path of the file (or source) to read
    :param index: optional index or list of indices restricting extraction
    """
    pass
81117

82-
def file_sanity_check(self, file):
118+
@staticmethod
119+
def file_sanity_check(file):
83120
"""
84121
Checks that the file can be found and is not empty
85122
"""

src/main/python/systemds/scuro/dataloader/json_loader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import json
2222

2323
from systemds.scuro.dataloader.base_loader import BaseLoader
24-
from typing import Optional, List
24+
from typing import Optional, List, Union
2525

2626

2727
class JSONLoader(BaseLoader):
@@ -35,9 +35,9 @@ def __init__(
3535
super().__init__(source_path, indices, chunk_size)
3636
self.field = field
3737

38-
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
    """Read a JSON file and collect self.field of every requested entry.

    :param file: path of the JSON file to parse
    :param index: keys of the entries whose ``self.field`` value is collected
    """
    self.file_sanity_check(file)
    with open(file) as handle:
        records = json.load(handle)
        for key in index:
            self.data.append(records[key][self.field])

src/main/python/systemds/scuro/dataloader/text_loader.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#
2020
# -------------------------------------------------------------
2121
from systemds.scuro.dataloader.base_loader import BaseLoader
22-
from typing import Optional, Pattern, List
22+
from typing import Optional, Pattern, List, Union
2323
import re
2424

2525

@@ -34,11 +34,12 @@ def __init__(
3434
super().__init__(source_path, indices, chunk_size)
3535
self.prefix = prefix
3636

37-
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
    """Read a text file line by line, strip the configured prefix, store lines.

    :param file: path of the text file to read
    :param index: unused here; kept for the shared BaseLoader.extract interface
    """
    self.file_sanity_check(file)
    with open(file) as text_file:
        for line_no, raw in enumerate(text_file):
            if self.prefix:
                raw = re.sub(self.prefix, "", raw)
            cleaned = raw.replace("\n", "")
            # NOTE(review): metadata is keyed by file, so for a multi-line
            # file only the word count of the last line survives — confirm
            # this is intended (inputs may be one line per file).
            self.metadata[file] = {"length": len(cleaned.split())}
            self.data.append(cleaned)

src/main/python/systemds/scuro/dataloader/video_loader.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818
# under the License.
1919
#
2020
# -------------------------------------------------------------
21-
from typing import List, Optional
21+
from typing import List, Optional, Union
2222

2323
import numpy as np
2424

2525
from systemds.scuro.dataloader.base_loader import BaseLoader
26+
from systemds.scuro.utils.schema_helpers import create_timestamps
2627
import cv2
2728

2829

@@ -35,9 +36,25 @@ def __init__(
3536
):
3637
super().__init__(source_path, indices, chunk_size)
3738

38-
def extract(self, file: str):
39+
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
3940
self.file_sanity_check(file)
4041
cap = cv2.VideoCapture(file)
42+
43+
if not cap.isOpened():
44+
raise f"Could not read video at path: {file}"
45+
46+
self.metadata[file] = {
47+
"fps": cap.get(cv2.CAP_PROP_FPS),
48+
"length": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
49+
"width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
50+
"height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
51+
"num_channels": 3,
52+
}
53+
54+
self.metadata[file]["timestamp"] = create_timestamps(
55+
self.metadata[file]["fps"], self.metadata[file]["length"]
56+
)
57+
4158
frames = []
4259
while cap.isOpened():
4360
ret, frame = cap.read()

0 commit comments

Comments
 (0)