Skip to content

Commit 6658c8f

Browse files
[SYSTEMDS-3936] Refactor metadata generation and add additional dataloader
This patch introduces a new dataloader for the image modality. Additionally, it refines how the metadata for each modality is created.
1 parent 3779d50 commit 6658c8f

File tree

13 files changed

+163
-42
lines changed

13 files changed

+163
-42
lines changed

src/main/python/systemds/scuro/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# -------------------------------------------------------------
2121
from systemds.scuro.dataloader.base_loader import BaseLoader
2222
from systemds.scuro.dataloader.audio_loader import AudioLoader
23+
from systemds.scuro.dataloader.image_loader import ImageLoader
2324
from systemds.scuro.dataloader.video_loader import VideoLoader
2425
from systemds.scuro.dataloader.text_loader import TextLoader
2526
from systemds.scuro.dataloader.json_loader import JSONLoader
@@ -103,6 +104,7 @@
103104

104105
__all__ = [
105106
"BaseLoader",
107+
"ImageLoader",
106108
"AudioLoader",
107109
"VideoLoader",
108110
"TextLoader",

src/main/python/systemds/scuro/dataloader/audio_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
4747
if not self.load_data_from_file:
4848
import numpy as np
4949

50-
self.metadata[file] = self.modality_type.create_audio_metadata(
50+
self.metadata[file] = self.modality_type.create_metadata(
5151
1000, np.array([0])
5252
)
5353
else:
@@ -56,6 +56,6 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
5656
if self.normalize:
5757
audio = librosa.util.normalize(audio)
5858

59-
self.metadata[file] = self.modality_type.create_audio_metadata(sr, audio)
59+
self.metadata[file] = self.modality_type.create_metadata(sr, audio)
6060

6161
self.data.append(audio)

src/main/python/systemds/scuro/dataloader/base_loader.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
data_type: Union[np.dtype, str],
3535
chunk_size: Optional[int] = None,
3636
modality_type=None,
37+
ext=None,
3738
):
3839
"""
3940
Base class to load raw data for a given list of indices and stores them in the data object
@@ -53,6 +54,7 @@ def __init__(
5354
self._num_chunks = 1
5455
self._chunk_size = None
5556
self._data_type = data_type
57+
self._ext = ext
5658

5759
if chunk_size:
5860
self.chunk_size = chunk_size
@@ -136,9 +138,10 @@ def get_file_names(self, indices=None):
136138
is_dir = True if os.path.isdir(self.source_path) else False
137139
file_names = []
138140
if is_dir:
139-
_, ext = os.path.splitext(os.listdir(self.source_path)[0])
141+
if self._ext is None:
142+
_, self._ext = os.path.splitext(os.listdir(self.source_path)[0])
140143
for index in self.indices if indices is None else indices:
141-
file_names.append(self.source_path + index + ext)
144+
file_names.append(self.source_path + index + self._ext)
142145
return file_names
143146
else:
144147
return self.source_path
@@ -155,10 +158,10 @@ def file_sanity_check(file):
155158
try:
156159
file_size = os.path.getsize(file)
157160
except:
158-
raise (f"Error: File {0} not found!".format(file))
161+
raise ValueError(f"Error: File {0} not found!".format(file))
159162

160163
if file_size == 0:
161-
raise ("File {0} is empty".format(file))
164+
raise ValueError("File {0} is empty".format(file))
162165

163166
@staticmethod
164167
def resolve_data_type(data_type):
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# -------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
# -------------------------------------------------------------
21+
from typing import List, Optional, Union
22+
23+
import numpy as np
24+
25+
from systemds.scuro.dataloader.base_loader import BaseLoader
26+
import cv2
27+
from systemds.scuro.modality.type import ModalityType
28+
29+
30+
class ImageLoader(BaseLoader):
31+
def __init__(
32+
self,
33+
source_path: str,
34+
indices: List[str],
35+
data_type: Union[np.dtype, str] = np.float16,
36+
chunk_size: Optional[int] = None,
37+
load=True,
38+
ext=".jpg",
39+
):
40+
super().__init__(
41+
source_path, indices, data_type, chunk_size, ModalityType.IMAGE, ext
42+
)
43+
self.load_data_from_file = load
44+
45+
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
46+
self.file_sanity_check(file)
47+
48+
image = cv2.imread(file, cv2.IMREAD_COLOR)
49+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
50+
51+
if image.ndim == 2:
52+
height, width = image.shape
53+
channels = 1
54+
else:
55+
height, width, channels = image.shape
56+
57+
image = image.astype(np.float32) / 255.0
58+
59+
self.metadata[file] = self.modality_type.create_metadata(
60+
width, height, channels
61+
)
62+
63+
self.data.append(image)

src/main/python/systemds/scuro/dataloader/json_loader.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,28 @@ def __init__(
3232
self,
3333
source_path: str,
3434
indices: List[str],
35-
field: str,
35+
field: str, # TODO: make this a list so it is easier to get multiple fields from a json file. (e.g. Mustard: context + sentence)
3636
data_type: Union[np.dtype, str] = str,
3737
chunk_size: Optional[int] = None,
38+
ext: str = ".json",
3839
):
39-
super().__init__(source_path, indices, data_type, chunk_size, ModalityType.TEXT)
40+
super().__init__(
41+
source_path, indices, data_type, chunk_size, ModalityType.TEXT, ext
42+
)
4043
self.field = field
4144

4245
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
4346
self.file_sanity_check(file)
4447
with open(file) as f:
4548
json_file = json.load(f)
49+
50+
if isinstance(index, str):
51+
index = [index]
4652
for idx in index:
47-
sentence = json_file[idx][self.field]
48-
self.data.append(sentence)
49-
self.metadata[idx] = self.modality_type.create_text_metadata(
50-
len(sentence), sentence
51-
)
53+
try:
54+
text = json_file[idx][self.field]
55+
except:
56+
text = json_file[self.field]
57+
58+
self.data.append(text)
59+
self.metadata[idx] = self.modality_type.create_metadata(len(text), text)

src/main/python/systemds/scuro/dataloader/text_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
4343
if self.prefix:
4444
line = re.sub(self.prefix, "", line)
4545
line = line.replace("\n", "")
46-
self.metadata[file] = self.modality_type.create_text_metadata(
46+
self.metadata[file] = self.modality_type.create_metadata(
4747
len(line.split()), line
4848
)
4949
self.data.append(line)

src/main/python/systemds/scuro/dataloader/timeseries_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
7070
data = self._normalize_signals(data)
7171

7272
if file:
73-
self.metadata[index] = self.modality_type.create_ts_metadata(
73+
self.metadata[index] = self.modality_type.create_metadata(
7474
self.signal_names, data, self.sampling_rate
7575
)
7676
else:
7777
for i, index in enumerate(self.indices):
78-
self.metadata[str(index)] = self.modality_type.create_ts_metadata(
78+
self.metadata[str(index)] = self.modality_type.create_metadata(
7979
self.signal_names, data[i], self.sampling_rate
8080
)
8181
self.data.append(data)

src/main/python/systemds/scuro/dataloader/video_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(
4646
def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
4747
self.file_sanity_check(file)
4848
# if not self.load_data_from_file:
49-
# self.metadata[file] = self.modality_type.create_video_metadata(
49+
# self.metadata[file] = self.modality_type.create_metadata(
5050
# 30, 10, 100, 100, 3
5151
# )
5252
# else:
@@ -67,7 +67,7 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
6767
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
6868
num_channels = 3
6969

70-
self.metadata[file] = self.modality_type.create_video_metadata(
70+
self.metadata[file] = self.modality_type.create_metadata(
7171
self.fps, length, width, height, num_channels
7272
)
7373

src/main/python/systemds/scuro/modality/joined_transformed.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def __init__(self, left_modality, right_modality, transformation):
3636
:param transformation: Representation to be applied on the modality
3737
"""
3838
super().__init__(
39-
reduce(or_, [left_modality.modality_type], right_modality.modality_type),
39+
left_modality.modality_type,
40+
# reduce(or_, [left_modality.modality_type], right_modality.modality_type),
4041
data_type=left_modality.data_type,
4142
)
4243
self.transformation = transformation

src/main/python/systemds/scuro/modality/transformed.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def join(self, right, join_condition):
8787
right.extract_raw_data()
8888

8989
joined_modality = JoinedModality(
90-
reduce(or_, [right.modality_type], self.modality_type),
90+
self.modality_type,
91+
# reduce(or_, [right.modality_type], self.modality_type), #TODO
9192
self,
9293
right,
9394
join_condition,
@@ -136,8 +137,10 @@ def combine(self, other: Union[Modality, List[Modality]], fusion_method):
136137
fused_modality = TransformedModality(
137138
self, fusion_method, ModalityType.EMBEDDING
138139
)
140+
start_time = time.time()
139141
fused_modality.data = fusion_method.transform(self.create_modality_list(other))
140-
142+
end_time = time.time()
143+
fused_modality.transform_time = end_time - start_time
141144
return fused_modality
142145

143146
def combine_with_training(
@@ -147,7 +150,12 @@ def combine_with_training(
147150
self, fusion_method, ModalityType.EMBEDDING
148151
)
149152
modalities = self.create_modality_list(other)
153+
start_time = time.time()
150154
fused_modality.data = fusion_method.transform_with_training(modalities, task)
155+
end_time = time.time()
156+
fused_modality.transform_time = (
157+
end_time - start_time
158+
) # Note: this includes the training time
151159

152160
return fused_modality
153161

0 commit comments

Comments
 (0)