Skip to content

Commit 5292b42

Browse files
[SYSTEMDS-3913] Adapt visual representations to image modality
This patch fixes several errors in the visual representations so that they work for both video and image modalities.
1 parent 608ddcd commit 5292b42

File tree

8 files changed

+12
-11
lines changed

8 files changed

+12
-11
lines changed

src/main/python/systemds/scuro/dataloader/image_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
5454
else:
5555
height, width, channels = image.shape
5656

57-
image = image.astype(np.float32) / 255.0
57+
image = image.astype(np.uint8, copy=False)
5858

5959
self.metadata[file] = self.modality_type.create_metadata(
6060
width, height, channels

src/main/python/systemds/scuro/dataloader/json_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,6 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
5555
except:
5656
text = json_file[self.field]
5757

58-
text = " ".join(text)
58+
text = " ".join(text) if isinstance(text, list) else text
5959
self.data.append(text)
6060
self.metadata[idx] = self.modality_type.create_metadata(len(text), text)

src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,11 @@ def visit_node(node_id):
177177

178178
if self.maximize_metric:
179179
best_params, best_score = max(
180-
all_results, key=lambda x: x[1].scores[self.scoring_metric]
180+
all_results, key=lambda x: x[1].average_scores[self.scoring_metric]
181181
)
182182
else:
183183
best_params, best_score = min(
184-
all_results, key=lambda x: x[1].scores[self.scoring_metric]
184+
all_results, key=lambda x: x[1].average_scores[self.scoring_metric]
185185
)
186186

187187
tuning_time = time.time() - start_time

src/main/python/systemds/scuro/representations/color_histogram.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,19 @@
2222
import numpy as np
2323
import cv2
2424

25+
from systemds.scuro.drsearch.operator_registry import register_representation
2526
from systemds.scuro.modality.type import ModalityType
2627
from systemds.scuro.representations.unimodal import UnimodalRepresentation
2728
from systemds.scuro.modality.transformed import TransformedModality
2829

2930

31+
@register_representation(ModalityType.IMAGE)
3032
class ColorHistogram(UnimodalRepresentation):
3133
def __init__(
3234
self,
3335
color_space="RGB",
34-
bins=32,
35-
normalize=True,
36+
bins=64,
37+
normalize=False,
3638
aggregation="mean",
3739
output_file=None,
3840
):
@@ -48,7 +50,7 @@ def __init__(
4850
def _get_parameters(self):
4951
return {
5052
"color_space": ["RGB", "HSV", "GRAY"],
51-
"bins": [8, 16, 32, 64, 128, 256, (8, 8, 8), (16, 16, 16)],
53+
"bins": [8, 16, 32, 64, 128, 256],
5254
"normalize": [True, False],
5355
"aggregation": ["mean", "max", "concat"],
5456
}

src/main/python/systemds/scuro/representations/fusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def transform_with_training(self, modalities: List[Modality], task):
8686
(len(modalities[0].data), transformed_train.shape[1])
8787
)
8888
transformed_data[task.train_indices] = transformed_train
89-
transformed_data[task.val_indices] = transformed_val
89+
transformed_data[task.test_indices] = transformed_other
9090

9191
return transformed_data
9292

src/main/python/systemds/scuro/representations/resnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _get_parameters(self, high_level=True):
114114
return parameters
115115

116116
def transform(self, modality):
117-
self.data_type = numpy_dtype_to_torch_dtype(modality.data_type)
117+
self.data_type = torch.float32
118118
if next(self.model.parameters()).dtype != self.data_type:
119119
self.model = self.model.to(self.data_type)
120120

src/main/python/systemds/scuro/representations/vgg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _get_parameters(self):
6565
return parameters
6666

6767
def transform(self, modality):
68-
self.data_type = numpy_dtype_to_torch_dtype(modality.data_type)
68+
self.data_type = torch.float32
6969
if next(self.model.parameters()).dtype != self.data_type:
7070
self.model = self.model.to(self.data_type)
7171

src/main/python/systemds/scuro/utils/torch_dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def __getitem__(self, index) -> Dict[str, object]:
6262

6363
if isinstance(data, np.ndarray) and data.ndim == 3:
6464
# image
65-
data = torch.tensor(data).permute(2, 0, 1)
6665
output = self.tf(data).to(self.device)
6766
else:
6867
for i, d in enumerate(data):

0 commit comments

Comments
 (0)