Skip to content

Commit c9a54fe

Browse files
[SYSTEMDS-3887] Refactor representation optimizers (#2308)
This patch adds an updated version of the unimodal and multimodal representation optimizers. It includes improved handling of optimization results and more readable debug output for better operator tracing. I added additional tests for the adapted optimizers and the fusion representations.
1 parent 9e23ad8 commit c9a54fe

39 files changed

+1161
-211
lines changed

src/main/python/systemds/scuro/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@
7373
from systemds.scuro.drsearch.unimodal_representation_optimizer import (
7474
UnimodalRepresentationOptimizer,
7575
)
76+
from systemds.scuro.drsearch.multimodal_optimizer import MultimodalOptimizer
77+
from systemds.scuro.drsearch.unimodal_optimizer import UnimodalOptimizer
7678

7779

7880
__all__ = [
@@ -127,4 +129,6 @@
127129
"OptimizationData",
128130
"RepresentationCache",
129131
"UnimodalRepresentationOptimizer",
132+
"UnimodalOptimizer",
133+
"MultimodalOptimizer",
130134
]

src/main/python/systemds/scuro/dataloader/audio_loader.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,18 @@ def __init__(
4545

4646
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
4747
self.file_sanity_check(file)
48-
# if not self.load_data_from_file:
49-
# import numpy as np
50-
#
51-
# self.metadata[file] = self.modality_type.create_audio_metadata(
52-
# 1000, np.array([0])
53-
# )
54-
# else:
55-
audio, sr = librosa.load(file, dtype=self._data_type)
48+
if not self.load_data_from_file:
49+
import numpy as np
5650

57-
if self.normalize:
58-
audio = librosa.util.normalize(audio)
51+
self.metadata[file] = self.modality_type.create_audio_metadata(
52+
1000, np.array([0])
53+
)
54+
else:
55+
audio, sr = librosa.load(file, dtype=self._data_type)
5956

60-
self.metadata[file] = self.modality_type.create_audio_metadata(sr, audio)
57+
if self.normalize:
58+
audio = librosa.util.normalize(audio)
6159

62-
self.data.append(audio)
60+
self.metadata[file] = self.modality_type.create_audio_metadata(sr, audio)
61+
62+
self.data.append(audio)

src/main/python/systemds/scuro/dataloader/video_loader.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@ def __init__(
3535
data_type: Union[np.dtype, str] = np.float16,
3636
chunk_size: Optional[int] = None,
3737
load=True,
38+
fps=None,
3839
):
3940
super().__init__(
4041
source_path, indices, data_type, chunk_size, ModalityType.VIDEO
4142
)
4243
self.load_data_from_file = load
44+
self.fps = fps
4345

4446
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
4547
self.file_sanity_check(file)
@@ -53,25 +55,33 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
5355
if not cap.isOpened():
5456
raise f"Could not read video at path: {file}"
5557

56-
fps = cap.get(cv2.CAP_PROP_FPS)
58+
orig_fps = cap.get(cv2.CAP_PROP_FPS)
59+
frame_interval = 1
60+
if self.fps is not None and self.fps < orig_fps:
61+
frame_interval = int(round(orig_fps / self.fps))
62+
else:
63+
self.fps = orig_fps
64+
5765
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
5866
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
5967
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
6068
num_channels = 3
6169

6270
self.metadata[file] = self.modality_type.create_video_metadata(
63-
fps, length, width, height, num_channels
71+
self.fps, length, width, height, num_channels
6472
)
6573

6674
frames = []
75+
idx = 0
6776
while cap.isOpened():
6877
ret, frame = cap.read()
6978

7079
if not ret:
7180
break
72-
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
73-
frame = frame.astype(self._data_type) / 255.0
74-
75-
frames.append(frame)
81+
if idx % frame_interval == 0:
82+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
83+
frame = frame.astype(self._data_type) / 255.0
84+
frames.append(frame)
85+
idx += 1
7686

7787
self.data.append(np.stack(frames))

src/main/python/systemds/scuro/drsearch/dr_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def set_best_params(
7676
"""
7777

7878
# check if modality name is already in dictionary
79-
if "_".join(modality_names) not in self.scores.keys():
79+
if "_".join(modality_names) not in list(self.scores.keys()):
8080
# if not add it to dictionary
8181
self.scores["_".join(modality_names)] = {}
8282

0 commit comments

Comments
 (0)