diff --git a/src/main/python/systemds/scuro/representations/bert.py b/src/main/python/systemds/scuro/representations/bert.py
index 3478b84e672..4d486bff59d 100644
--- a/src/main/python/systemds/scuro/representations/bert.py
+++ b/src/main/python/systemds/scuro/representations/bert.py
@@ -26,12 +26,34 @@
 from systemds.scuro.representations.utils import save_embeddings
 from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.drsearch.operator_registry import register_representation
-
+from systemds.scuro.utils.static_variables import get_device
 import os
+from torch.utils.data import Dataset, DataLoader
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
+class TextDataset(Dataset):
+    def __init__(self, texts):
+
+        self.texts = []
+        for text in texts:
+            if text is None:
+                self.texts.append("")
+            elif isinstance(text, np.ndarray):
+                self.texts.append(str(text.item()) if text.size == 1 else str(text))
+            elif not isinstance(text, str):
+                self.texts.append(str(text))
+            else:
+                self.texts.append(text)
+
+    def __len__(self):
+        return len(self.texts)
+
+    def __getitem__(self, idx):
+        return self.texts[idx]
+
+
 @register_representation(ModalityType.TEXT)
 class Bert(UnimodalRepresentation):
     def __init__(self, model_name="bert", output_file=None, max_seq_length=512):
@@ -49,7 +71,7 @@ def transform(self, modality):
             model_name, clean_up_tokenization_spaces=True
         )
 
-        model = BertModel.from_pretrained(model_name)
+        model = BertModel.from_pretrained(model_name).to(get_device())
 
         embeddings = self.create_embeddings(modality, model, tokenizer)
 
@@ -57,32 +79,42 @@ def transform(self, modality):
             save_embeddings(embeddings, self.output_file)
 
         transformed_modality.data_type = np.float32
-        transformed_modality.data = embeddings
+        transformed_modality.data = np.array(embeddings)
         return transformed_modality
 
     def create_embeddings(self, modality, model, tokenizer):
-        inputs = tokenizer(
-            modality.data,
-            return_offsets_mapping=True,
-            return_tensors="pt",
-            padding="longest",
-            return_attention_mask=True,
-            truncation=True,
-        )
-        ModalityType.TEXT.add_field_for_instances(
-            modality.metadata,
-            "token_to_character_mapping",
-            inputs.data["offset_mapping"].tolist(),
-        )
+        dataset = TextDataset(modality.data)
+        dataloader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=None)
+        cls_embeddings = []
+        for batch in dataloader:
+            inputs = tokenizer(
+                batch,
+                return_offsets_mapping=True,
+                return_tensors="pt",
+                padding="max_length",
+                return_attention_mask=True,
+                truncation=True,
+                max_length=512,  # TODO: make this dynamic
+            )
 
-        ModalityType.TEXT.add_field_for_instances(
-            modality.metadata, "attention_masks", inputs.data["attention_mask"].tolist()
-        )
-        del inputs.data["offset_mapping"]
+            inputs.to(get_device())
+            ModalityType.TEXT.add_field_for_instances(
+                modality.metadata,
+                "token_to_character_mapping",
+                inputs.data["offset_mapping"].tolist(),
+            )
+
+            ModalityType.TEXT.add_field_for_instances(
+                modality.metadata,
+                "attention_masks",
+                inputs.data["attention_mask"].tolist(),
+            )
+            del inputs.data["offset_mapping"]
 
-        with torch.no_grad():
-            outputs = model(**inputs)
+            with torch.no_grad():
+                outputs = model(**inputs)
 
-        cls_embedding = outputs.last_hidden_state.detach().numpy()
+            cls_embedding = outputs.last_hidden_state.detach().cpu().numpy()
+            cls_embeddings.extend(cls_embedding)
 
-        return cls_embedding
+        return np.array(cls_embeddings)
diff --git a/src/main/python/systemds/scuro/representations/clip.py b/src/main/python/systemds/scuro/representations/clip.py
index 044d0f795a6..1d458aeb7d0 100644
--- a/src/main/python/systemds/scuro/representations/clip.py
+++ b/src/main/python/systemds/scuro/representations/clip.py
@@ -81,6 +81,7 @@ def create_visual_embeddings(self, modality):
                 frame_batch = frames[frame_ids_range]
 
                 inputs = self.processor(images=frame_batch, return_tensors="pt")
+                inputs.to(get_device())
 
                 with torch.no_grad():
                     output = self.model.get_image_features(**inputs)
@@ -125,9 +126,18 @@ def transform(self, modality):
     def create_text_embeddings(self, data, model):
         embeddings = []
         for d in data:
-            inputs = self.processor(text=d, return_tensors="pt", padding=True)
+            inputs = self.processor(
+                text=d,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=77,
+            )
+            inputs.to(get_device())
 
             with torch.no_grad():
                 text_embedding = model.get_text_features(**inputs)
-            embeddings.append(text_embedding.squeeze().numpy().reshape(1, -1))
+            embeddings.append(
+                text_embedding.squeeze().detach().cpu().numpy().reshape(1, -1)
+            )
         return embeddings
diff --git a/src/main/python/systemds/scuro/representations/concatenation.py b/src/main/python/systemds/scuro/representations/concatenation.py
index a4d4d53c43e..bf854a481fd 100644
--- a/src/main/python/systemds/scuro/representations/concatenation.py
+++ b/src/main/python/systemds/scuro/representations/concatenation.py
@@ -41,7 +41,12 @@ def __init__(self):
 
     def execute(self, modalities: List[Modality]):
         if len(modalities) == 1:
-            return np.array(modalities[0].data)
+            return np.asarray(
+                modalities[0].data,
+                dtype=modalities[0].metadata[list(modalities[0].metadata.keys())[0]][
+                    "data_layout"
+                ]["type"],
+            )
 
         max_emb_size = self.get_max_embedding_size(modalities)
         size = len(modalities[0].data)
@@ -52,6 +57,18 @@ def execute(self, modalities: List[Modality]):
 
         data = np.zeros((size, 0))
         for modality in modalities:
-            data = np.concatenate([data, copy.deepcopy(modality.data)], axis=-1)
+            other_modality = copy.deepcopy(modality.data)
+            data = np.concatenate(
+                [
+                    data,
+                    np.asarray(
+                        other_modality,
+                        dtype=modality.metadata[list(modality.metadata.keys())[0]][
+                            "data_layout"
+                        ]["type"],
+                    ),
+                ],
+                axis=-1,
+            )
 
         return np.array(data)
diff --git a/src/main/python/systemds/scuro/representations/sum.py b/src/main/python/systemds/scuro/representations/sum.py
index 5b3710b6e14..79655a4b665 100644
--- a/src/main/python/systemds/scuro/representations/sum.py
+++ b/src/main/python/systemds/scuro/representations/sum.py
@@ -21,7 +21,7 @@
 
 
 from typing import List
-
+import numpy as np
 from systemds.scuro.modality.modality import Modality
 from systemds.scuro.representations.utils import pad_sequences
 
@@ -40,9 +40,18 @@ def __init__(self):
         self.needs_alignment = True
 
     def execute(self, modalities: List[Modality]):
-        data = modalities[0].data
+        data = np.asarray(
+            modalities[0].data,
+            dtype=modalities[0].metadata[list(modalities[0].metadata.keys())[0]][
+                "data_layout"
+            ]["type"],
+        )
         for m in range(1, len(modalities)):
-            data += modalities[m].data
-
+            data += np.asarray(
+                modalities[m].data,
+                dtype=modalities[m].metadata[list(modalities[m].metadata.keys())[0]][
+                    "data_layout"
+                ]["type"],
+            )
         return data
 
diff --git a/src/main/python/systemds/scuro/representations/vgg.py b/src/main/python/systemds/scuro/representations/vgg.py
index 374586f2b9a..4d0212883c6 100644
--- a/src/main/python/systemds/scuro/representations/vgg.py
+++ b/src/main/python/systemds/scuro/representations/vgg.py
@@ -71,7 +71,6 @@ def transform(self, modality):
         dataset = CustomDataset(modality.data, self.data_type, get_device())
 
         embeddings = {}
-
        activations = {}
 
        def get_activation(name_):
@@ -88,7 +87,6 @@ def hook(
                get_activation(self.layer_name)
            )
        else:
-
            self.model.classifier[int(digit)].register_forward_hook(
                get_activation(self.layer_name)
            )
@@ -97,17 +95,25 @@ def hook(
            video_id = instance["id"][0]
            frames = instance["data"][0]
            embeddings[video_id] = []
-            batch_size = 32
-            for start_index in range(0, len(frames), batch_size):
-                end_index = min(start_index + batch_size, len(frames))
-                frame_ids_range = range(start_index, end_index)
-                frame_batch = frames[frame_ids_range]
+            if frames.dim() == 3:
+                # Single image: (3, 224, 224) -> (1, 3, 224, 224)
+                frames = frames.unsqueeze(0)
+                batch_size = 1
+            else:
+                # Video: (T, 3, 224, 224) - process in batches
+                batch_size = 32
+
+            for start_index in range(0, frames.shape[0], batch_size):
+                end_index = min(start_index + batch_size, frames.shape[0])
+                frame_batch = frames[start_index:end_index]
 
                _ = self.model(frame_batch)
 
                output = activations[self.layer_name]
+
                if len(output.shape) == 4:
                    output = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
+
                embeddings[video_id].extend(
                    torch.flatten(output, 1)
                    .detach()
@@ -122,7 +128,6 @@ def hook(
        transformed_modality = TransformedModality(
            modality, self, self.output_modality_type
        )
-
        transformed_modality.data = list(embeddings.values())
        return transformed_modality
 
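
Note: every hunk above moves the model or the tokenized/processed inputs to a shared device via `get_device()`, imported from `systemds.scuro.utils.static_variables`; that helper's implementation is not part of this diff. The sketch below is a hypothetical illustration of how such a helper is commonly written and how the patch uses it, not the actual SystemDS code.

# Hypothetical sketch only -- the real get_device() lives in
# systemds.scuro.utils.static_variables and may differ.
import torch


def get_device() -> torch.device:
    # Prefer CUDA, then Apple MPS, otherwise fall back to CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Usage mirroring the patch: the model is moved once, each batch per step.
# model = BertModel.from_pretrained(model_name).to(get_device())
# inputs = tokenizer(batch, return_tensors="pt").to(get_device())

Because the model and the inputs end up on the same device, the forward passes run unchanged; the `.detach().cpu().numpy()` calls added in bert.py and clip.py then bring the embeddings back to host memory before they are stored.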