Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/main/python/systemds/scuro/dataloader/image_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
else:
height, width, channels = image.shape

image = image.astype(np.float32) / 255.0
image = image.astype(np.uint8, copy=False)

self.metadata[file] = self.modality_type.create_metadata(
width, height, channels
Expand Down
2 changes: 1 addition & 1 deletion src/main/python/systemds/scuro/dataloader/json_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,6 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
except:
text = json_file[self.field]

text = " ".join(text)
text = " ".join(text) if isinstance(text, list) else text
self.data.append(text)
self.metadata[idx] = self.modality_type.create_metadata(len(text), text)
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,11 @@ def visit_node(node_id):

if self.maximize_metric:
best_params, best_score = max(
all_results, key=lambda x: x[1].scores[self.scoring_metric]
all_results, key=lambda x: x[1].average_scores[self.scoring_metric]
)
else:
best_params, best_score = min(
all_results, key=lambda x: x[1].scores[self.scoring_metric]
all_results, key=lambda x: x[1].average_scores[self.scoring_metric]
)

tuning_time = time.time() - start_time
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,19 @@
import numpy as np
import cv2

from systemds.scuro.drsearch.operator_registry import register_representation
from systemds.scuro.modality.type import ModalityType
from systemds.scuro.representations.unimodal import UnimodalRepresentation
from systemds.scuro.modality.transformed import TransformedModality


@register_representation(ModalityType.IMAGE)
class ColorHistogram(UnimodalRepresentation):
def __init__(
self,
color_space="RGB",
bins=32,
normalize=True,
bins=64,
normalize=False,
aggregation="mean",
output_file=None,
):
Expand All @@ -48,7 +50,7 @@ def __init__(
def _get_parameters(self):
return {
"color_space": ["RGB", "HSV", "GRAY"],
"bins": [8, 16, 32, 64, 128, 256, (8, 8, 8), (16, 16, 16)],
"bins": [8, 16, 32, 64, 128, 256],
"normalize": [True, False],
"aggregation": ["mean", "max", "concat"],
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/python/systemds/scuro/representations/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def transform_with_training(self, modalities: List[Modality], task):
(len(modalities[0].data), transformed_train.shape[1])
)
transformed_data[task.train_indices] = transformed_train
transformed_data[task.val_indices] = transformed_val
transformed_data[task.test_indices] = transformed_other

return transformed_data

Expand Down
2 changes: 1 addition & 1 deletion src/main/python/systemds/scuro/representations/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _get_parameters(self, high_level=True):
return parameters

def transform(self, modality):
self.data_type = numpy_dtype_to_torch_dtype(modality.data_type)
self.data_type = torch.float32
if next(self.model.parameters()).dtype != self.data_type:
self.model = self.model.to(self.data_type)

Expand Down
2 changes: 1 addition & 1 deletion src/main/python/systemds/scuro/representations/vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _get_parameters(self):
return parameters

def transform(self, modality):
self.data_type = numpy_dtype_to_torch_dtype(modality.data_type)
self.data_type = torch.float32
if next(self.model.parameters()).dtype != self.data_type:
self.model = self.model.to(self.data_type)

Expand Down
1 change: 0 additions & 1 deletion src/main/python/systemds/scuro/utils/torch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def __getitem__(self, index) -> Dict[str, object]:

if isinstance(data, np.ndarray) and data.ndim == 3:
# image
data = torch.tensor(data).permute(2, 0, 1)
output = self.tf(data).to(self.device)
else:
for i, d in enumerate(data):
Expand Down
Loading