Skip to content

Commit dc1f363

Browse files
[SYSTEMDS-3835] Add Modality Data Type
This patch adds a data type for all modalities. Closes #2270.
1 parent a5b298c commit dc1f363

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+846
-297
lines changed

src/main/python/systemds/scuro/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from systemds.scuro.representations.max import RowMax
4040
from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
4141
from systemds.scuro.representations.mfcc import MFCC
42-
from systemds.scuro.representations.multiplication import Multiplication
42+
from systemds.scuro.representations.hadamard import Hadamard
4343
from systemds.scuro.representations.optical_flow import OpticalFlow
4444
from systemds.scuro.representations.representation import Representation
4545
from systemds.scuro.representations.representation_dataloader import NPY
@@ -52,7 +52,7 @@
5252
from systemds.scuro.representations.tfidf import TfIdf
5353
from systemds.scuro.representations.unimodal import UnimodalRepresentation
5454
from systemds.scuro.representations.wav2vec import Wav2Vec
55-
from systemds.scuro.representations.window import WindowAggregation
55+
from systemds.scuro.representations.window_aggregation import WindowAggregation
5656
from systemds.scuro.representations.word2vec import W2V
5757
from systemds.scuro.representations.x3d import X3D
5858
from systemds.scuro.models.model import Model
@@ -94,7 +94,7 @@
9494
"RowMax",
9595
"MelSpectrogram",
9696
"MFCC",
97-
"Multiplication",
97+
"Hadamard",
9898
"OpticalFlow",
9999
"Representation",
100100
"NPY",

src/main/python/systemds/scuro/dataloader/audio_loader.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from typing import List, Optional, Union
2222

2323
import librosa
24+
import numpy as np
25+
2426
from systemds.scuro.dataloader.base_loader import BaseLoader
2527
from systemds.scuro.modality.type import ModalityType
2628

@@ -30,15 +32,27 @@ def __init__(
3032
self,
3133
source_path: str,
3234
indices: List[str],
35+
data_type: Union[np.dtype, str] = np.float32,
3336
chunk_size: Optional[int] = None,
3437
normalize: bool = True,
38+
load=True,
3539
):
36-
super().__init__(source_path, indices, chunk_size, ModalityType.AUDIO)
40+
super().__init__(
41+
source_path, indices, data_type, chunk_size, ModalityType.AUDIO
42+
)
3743
self.normalize = normalize
44+
self.load_data_from_file = load
3845

3946
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
4047
self.file_sanity_check(file)
41-
audio, sr = librosa.load(file)
48+
# if not self.load_data_from_file:
49+
# import numpy as np
50+
#
51+
# self.metadata[file] = self.modality_type.create_audio_metadata(
52+
# 1000, np.array([0])
53+
# )
54+
# else:
55+
audio, sr = librosa.load(file, dtype=self._data_type)
4256

4357
if self.normalize:
4458
audio = librosa.util.normalize(audio)

src/main/python/systemds/scuro/dataloader/base_loader.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,17 @@
2121
import os
2222
from abc import ABC, abstractmethod
2323
from typing import List, Optional, Union
24+
import math
25+
26+
import numpy as np
2427

2528

2629
class BaseLoader(ABC):
2730
def __init__(
2831
self,
2932
source_path: str,
3033
indices: List[str],
34+
data_type: Union[np.dtype, str],
3135
chunk_size: Optional[int] = None,
3236
modality_type=None,
3337
):
@@ -48,6 +52,7 @@ def __init__(
4852
self._next_chunk = 0
4953
self._num_chunks = 1
5054
self._chunk_size = None
55+
self._data_type = data_type
5156

5257
if chunk_size:
5358
self.chunk_size = chunk_size
@@ -59,7 +64,7 @@ def chunk_size(self):
5964
@chunk_size.setter
6065
def chunk_size(self, value):
6166
self._chunk_size = value
62-
self._num_chunks = int(len(self.indices) / self._chunk_size)
67+
self._num_chunks = int(math.ceil(len(self.indices) / self._chunk_size))
6368

6469
@property
6570
def num_chunks(self):
@@ -69,6 +74,14 @@ def num_chunks(self):
6974
def next_chunk(self):
7075
return self._next_chunk
7176

77+
@property
78+
def data_type(self):
79+
return self._data_type
80+
81+
@data_type.setter
82+
def data_type(self, data_type):
83+
self._data_type = self.resolve_data_type(data_type)
84+
7285
def reset(self):
7386
self._next_chunk = 0
7487
self.data = []
@@ -110,16 +123,25 @@ def _load_next_chunk(self):
110123
return self._load(next_chunk_indices)
111124

112125
def _load(self, indices: List[str]):
113-
is_dir = True if os.path.isdir(self.source_path) else False
126+
file_names = self.get_file_names(indices)
127+
if isinstance(file_names, str):
128+
self.extract(file_names, indices)
129+
else:
130+
for file_name in file_names:
131+
self.extract(file_name)
132+
133+
return self.data, self.metadata
114134

135+
def get_file_names(self, indices=None):
136+
is_dir = True if os.path.isdir(self.source_path) else False
137+
file_names = []
115138
if is_dir:
116139
_, ext = os.path.splitext(os.listdir(self.source_path)[0])
117-
for index in indices:
118-
self.extract(self.source_path + index + ext)
140+
for index in self.indices if indices is None else indices:
141+
file_names.append(self.source_path + index + ext)
142+
return file_names
119143
else:
120-
self.extract(self.source_path, indices)
121-
122-
return self.data, self.metadata
144+
return self.source_path
123145

124146
@abstractmethod
125147
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
@@ -137,3 +159,30 @@ def file_sanity_check(file):
137159

138160
if file_size == 0:
139161
raise ("File {0} is empty".format(file))
162+
163+
@staticmethod
164+
def resolve_data_type(data_type):
165+
if isinstance(data_type, str):
166+
if data_type.lower() in [
167+
"float16",
168+
"float32",
169+
"float64",
170+
"int16",
171+
"int32",
172+
"int64",
173+
]:
174+
return np.dtype(data_type)
175+
else:
176+
raise ValueError(f"Unsupported data_type string: {data_type}")
177+
elif data_type in [
178+
np.float16,
179+
np.float32,
180+
np.float64,
181+
np.int16,
182+
np.int32,
183+
np.int64,
184+
str,
185+
]:
186+
return data_type
187+
else:
188+
raise ValueError(f"Unsupported data_type: {data_type}")

src/main/python/systemds/scuro/dataloader/json_loader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
# -------------------------------------------------------------
2121
import json
2222

23+
import numpy as np
24+
2325
from systemds.scuro.modality.type import ModalityType
2426
from systemds.scuro.dataloader.base_loader import BaseLoader
2527
from typing import Optional, List, Union
@@ -31,9 +33,10 @@ def __init__(
3133
source_path: str,
3234
indices: List[str],
3335
field: str,
36+
data_type: Union[np.dtype, str] = str,
3437
chunk_size: Optional[int] = None,
3538
):
36-
super().__init__(source_path, indices, chunk_size, ModalityType.TEXT)
39+
super().__init__(source_path, indices, data_type, chunk_size, ModalityType.TEXT)
3740
self.field = field
3841

3942
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):

src/main/python/systemds/scuro/dataloader/text_loader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ def __init__(
2929
self,
3030
source_path: str,
3131
indices: List[str],
32+
data_type: str = str,
3233
chunk_size: Optional[int] = None,
3334
prefix: Optional[Pattern[str]] = None,
3435
):
35-
super().__init__(source_path, indices, chunk_size, ModalityType.TEXT)
36+
super().__init__(source_path, indices, data_type, chunk_size, ModalityType.TEXT)
3637
self.prefix = prefix
3738

3839
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):

src/main/python/systemds/scuro/dataloader/video_loader.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,22 @@ def __init__(
3232
self,
3333
source_path: str,
3434
indices: List[str],
35+
data_type: Union[np.dtype, str] = np.float16,
3536
chunk_size: Optional[int] = None,
37+
load=True,
3638
):
37-
super().__init__(source_path, indices, chunk_size, ModalityType.VIDEO)
39+
super().__init__(
40+
source_path, indices, data_type, chunk_size, ModalityType.VIDEO
41+
)
42+
self.load_data_from_file = load
3843

3944
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
4045
self.file_sanity_check(file)
46+
# if not self.load_data_from_file:
47+
# self.metadata[file] = self.modality_type.create_video_metadata(
48+
# 30, 10, 100, 100, 3
49+
# )
50+
# else:
4151
cap = cv2.VideoCapture(file)
4252

4353
if not cap.isOpened():
@@ -60,8 +70,8 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
6070
if not ret:
6171
break
6272
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
63-
frame = frame.astype(np.float32) / 255.0
73+
frame = frame.astype(self._data_type) / 255.0
6474

6575
frames.append(frame)
6676

67-
self.data.append(frames)
77+
self.data.append(np.stack(frames))

src/main/python/systemds/scuro/drsearch/operator_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def get_representations(self, modality: ModalityType):
5858
return self._representations[modality]
5959

6060
def get_context_operators(self):
61+
# TODO: return modality specific context operations
6162
return self._context_operators
6263

6364
def get_fusion_operators(self):

src/main/python/systemds/scuro/drsearch/representation_cache.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ def load_from_cache(self, modality, operators):
112112
metadata = pickle.load(f)
113113

114114
transformed_modality = TransformedModality(
115-
modality.modality_type, op_names, modality.modality_id, metadata
115+
modality,
116+
op_names,
116117
)
117118
data = None
118119
with open(f"{filename}.pkl", "rb") as f:

src/main/python/systemds/scuro/modality/joined_transformed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from systemds.scuro.modality.modality import Modality
2727
from systemds.scuro.representations.utils import pad_sequences
28-
from systemds.scuro.representations.window import WindowAggregation
28+
from systemds.scuro.representations.window_aggregation import WindowAggregation
2929

3030

3131
class JoinedTransformedModality(Modality):
@@ -70,7 +70,7 @@ def combine(self, fusion_method):
7070
self.data = pad_sequences(self.data)
7171
return self
7272

73-
def window(self, window_size, aggregation):
73+
def window_aggregation(self, window_size, aggregation):
7474
w = WindowAggregation(window_size, aggregation)
7575
self.left_modality.data = w.execute(self.left_modality)
7676
self.right_modality.data = w.execute(self.right_modality)

src/main/python/systemds/scuro/modality/modality.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929

3030
class Modality:
3131

32-
def __init__(self, modalityType: ModalityType, modality_id=-1, metadata={}):
32+
def __init__(
33+
self, modalityType: ModalityType, modality_id=-1, metadata={}, data_type=None
34+
):
3335
"""
3436
Parent class of the different Modalities (unimodal & multimodal)
3537
:param modality_type: Type of the modality
@@ -38,7 +40,7 @@ def __init__(self, modalityType: ModalityType, modality_id=-1, metadata={}):
3840
self.schema = modalityType.get_schema()
3941
self.metadata = metadata
4042
self.data = []
41-
self.data_type = None
43+
self.data_type = data_type
4244
self.cost = None
4345
self.shape = None
4446
self.modality_id = modality_id
@@ -67,7 +69,9 @@ def copy_from_instance(self):
6769
"""
6870
Create a copy of the modality instance
6971
"""
70-
return type(self)(self.modality_type, self.metadata)
72+
return type(self)(
73+
self.modality_type, self.modality_id, self.metadata, self.data_type
74+
)
7175

7276
def update_metadata(self):
7377
"""

0 commit comments

Comments
 (0)