Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 56 additions & 24 deletions src/main/python/systemds/scuro/representations/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,34 @@
from systemds.scuro.representations.utils import save_embeddings
from systemds.scuro.modality.type import ModalityType
from systemds.scuro.drsearch.operator_registry import register_representation

from systemds.scuro.utils.static_variables import get_device
import os
from torch.utils.data import Dataset, DataLoader

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class TextDataset(Dataset):
    """Torch ``Dataset`` wrapping raw text samples, normalized to ``str``.

    Normalization rules (applied once, at construction time):
      * ``None``            -> ``""``
      * scalar ``ndarray``  -> ``str(arr.item())``
      * other ``ndarray``   -> ``str(arr)``
      * any non-``str``     -> ``str(value)``
      * ``str``             -> unchanged
    """

    def __init__(self, texts):
        self.texts = [self._as_text(t) for t in texts]

    @staticmethod
    def _as_text(text):
        """Coerce one sample to ``str`` (empty string for ``None``)."""
        if text is None:
            return ""
        if isinstance(text, np.ndarray):
            # unwrap single-element arrays so "array(3)" becomes "3"
            return str(text.item()) if text.size == 1 else str(text)
        if isinstance(text, str):
            return text
        return str(text)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]


@register_representation(ModalityType.TEXT)
class Bert(UnimodalRepresentation):
def __init__(self, model_name="bert", output_file=None, max_seq_length=512):
Expand All @@ -49,40 +71,50 @@ def transform(self, modality):
model_name, clean_up_tokenization_spaces=True
)

model = BertModel.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name).to(get_device())

embeddings = self.create_embeddings(modality, model, tokenizer)

if self.output_file is not None:
save_embeddings(embeddings, self.output_file)

transformed_modality.data_type = np.float32
transformed_modality.data = embeddings
transformed_modality.data = np.array(embeddings)
return transformed_modality

def create_embeddings(self, modality, model, tokenizer):
    """Compute BERT last-hidden-state embeddings for every text in *modality*.

    Texts are normalized via ``TextDataset``, iterated in fixed batches of 32
    (``shuffle=False`` keeps the output aligned with ``modality.data``), and
    tokenized with fixed-length padding/truncation to 512 tokens so all
    batches produce equally shaped tensors. Per batch, the token-to-character
    offsets and attention masks are recorded in the modality metadata before
    the forward pass.

    Assumes *model* has already been moved to ``get_device()`` by the caller
    (see ``transform``). Returns a numpy array with one embedding per text.
    """
    dataset = TextDataset(modality.data)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=None)

    cls_embeddings = []
    for batch in dataloader:
        inputs = tokenizer(
            batch,
            return_offsets_mapping=True,
            return_tensors="pt",
            padding="max_length",
            return_attention_mask=True,
            truncation=True,
            max_length=512,  # TODO: make this dynamic
        )
        inputs.to(get_device())

        ModalityType.TEXT.add_field_for_instances(
            modality.metadata,
            "token_to_character_mapping",
            inputs.data["offset_mapping"].tolist(),
        )
        ModalityType.TEXT.add_field_for_instances(
            modality.metadata,
            "attention_masks",
            inputs.data["attention_mask"].tolist(),
        )
        # the model's forward() does not accept offset_mapping as an input
        del inputs.data["offset_mapping"]

        with torch.no_grad():
            outputs = model(**inputs)

        # detach and move to CPU before the numpy conversion (model may be on GPU)
        cls_embedding = outputs.last_hidden_state.detach().cpu().numpy()
        cls_embeddings.extend(cls_embedding)

    return np.array(cls_embeddings)
14 changes: 12 additions & 2 deletions src/main/python/systemds/scuro/representations/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def create_visual_embeddings(self, modality):
frame_batch = frames[frame_ids_range]

inputs = self.processor(images=frame_batch, return_tensors="pt")
inputs.to(get_device())
with torch.no_grad():
output = self.model.get_image_features(**inputs)

Expand Down Expand Up @@ -125,9 +126,18 @@ def transform(self, modality):
def create_text_embeddings(self, data, model):
    """Embed each text in *data* with the CLIP text encoder.

    Each input is tokenized with truncation to CLIP's 77-token context
    window, moved to the active device, and encoded without gradients.

    Returns a list of ``(1, dim)`` numpy arrays, one per input text.
    """
    embeddings = []
    for d in data:
        inputs = self.processor(
            text=d,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,  # CLIP's maximum text context length
        )
        inputs.to(get_device())
        with torch.no_grad():
            text_embedding = model.get_text_features(**inputs)
        # detach + cpu before numpy conversion; reshape keeps a 2D row vector
        embeddings.append(
            text_embedding.squeeze().detach().cpu().numpy().reshape(1, -1)
        )
    return embeddings
21 changes: 19 additions & 2 deletions src/main/python/systemds/scuro/representations/concatenation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def __init__(self):

def execute(self, modalities: List[Modality]):
if len(modalities) == 1:
return np.array(modalities[0].data)
return np.asarray(
modalities[0].data,
dtype=modalities[0].metadata[list(modalities[0].metadata.keys())[0]][
"data_layout"
]["type"],
)

max_emb_size = self.get_max_embedding_size(modalities)
size = len(modalities[0].data)
Expand All @@ -52,6 +57,18 @@ def execute(self, modalities: List[Modality]):
data = np.zeros((size, 0))

for modality in modalities:
data = np.concatenate([data, copy.deepcopy(modality.data)], axis=-1)
other_modality = copy.deepcopy(modality.data)
data = np.concatenate(
[
data,
np.asarray(
other_modality,
dtype=modality.metadata[list(modality.metadata.keys())[0]][
"data_layout"
]["type"],
),
],
axis=-1,
)

return np.array(data)
17 changes: 13 additions & 4 deletions src/main/python/systemds/scuro/representations/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from typing import List


import numpy as np
from systemds.scuro.modality.modality import Modality
from systemds.scuro.representations.utils import pad_sequences

Expand All @@ -40,9 +40,18 @@ def __init__(self):
self.needs_alignment = True

def execute(self, modalities: List[Modality]):
    """Element-wise sum of the data of all *modalities*.

    Each modality's data is first cast to the dtype recorded in its own
    metadata (the "data_layout" -> "type" entry of its first metadata key),
    so accumulation happens on properly typed numpy arrays rather than on
    whatever container type ``modality.data`` happens to be.

    Returns the accumulated numpy array.
    """

    def _typed(modality):
        # dtype is taken from the first metadata entry's declared data layout
        first_key = next(iter(modality.metadata))
        return np.asarray(
            modality.data,
            dtype=modality.metadata[first_key]["data_layout"]["type"],
        )

    data = _typed(modalities[0])
    for m in range(1, len(modalities)):
        data += _typed(modalities[m])
    return data
21 changes: 13 additions & 8 deletions src/main/python/systemds/scuro/representations/vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def transform(self, modality):

dataset = CustomDataset(modality.data, self.data_type, get_device())
embeddings = {}

activations = {}

def get_activation(name_):
Expand All @@ -88,7 +87,6 @@ def hook(
get_activation(self.layer_name)
)
else:

self.model.classifier[int(digit)].register_forward_hook(
get_activation(self.layer_name)
)
Expand All @@ -97,17 +95,25 @@ def hook(
video_id = instance["id"][0]
frames = instance["data"][0]
embeddings[video_id] = []
batch_size = 32

for start_index in range(0, len(frames), batch_size):
end_index = min(start_index + batch_size, len(frames))
frame_ids_range = range(start_index, end_index)
frame_batch = frames[frame_ids_range]
if frames.dim() == 3:
# Single image: (3, 224, 224) -> (1, 3, 224, 224)
frames = frames.unsqueeze(0)
batch_size = 1
else:
# Video: (T, 3, 224, 224) - process in batches
batch_size = 32

for start_index in range(0, frames.shape[0], batch_size):
end_index = min(start_index + batch_size, frames.shape[0])
frame_batch = frames[start_index:end_index]

_ = self.model(frame_batch)
output = activations[self.layer_name]

if len(output.shape) == 4:
output = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))

embeddings[video_id].extend(
torch.flatten(output, 1)
.detach()
Expand All @@ -122,7 +128,6 @@ def hook(
transformed_modality = TransformedModality(
modality, self, self.output_modality_type
)

transformed_modality.data = list(embeddings.values())

return transformed_modality
Loading