
Commit 3d714d3

[SYSTEMDS-3913] Make visual representations more robust
This patch makes the handling of visual representations in scuro more robust against dynamic data types and improves batched execution.
1 parent c40f95e commit 3d714d3

File tree: 5 files changed, +113 -40 lines changed

src/main/python/systemds/scuro/representations/bert.py

Lines changed: 56 additions & 24 deletions
@@ -26,12 +26,34 @@
 from systemds.scuro.representations.utils import save_embeddings
 from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.drsearch.operator_registry import register_representation
-
+from systemds.scuro.utils.static_variables import get_device
 import os
+from torch.utils.data import Dataset, DataLoader

 os.environ["TOKENIZERS_PARALLELISM"] = "false"


+class TextDataset(Dataset):
+    def __init__(self, texts):
+        self.texts = []
+        for text in texts:
+            if text is None:
+                self.texts.append("")
+            elif isinstance(text, np.ndarray):
+                self.texts.append(str(text.item()) if text.size == 1 else str(text))
+            elif not isinstance(text, str):
+                self.texts.append(str(text))
+            else:
+                self.texts.append(text)
+
+    def __len__(self):
+        return len(self.texts)
+
+    def __getitem__(self, idx):
+        return self.texts[idx]
+
+
 @register_representation(ModalityType.TEXT)
 class Bert(UnimodalRepresentation):
     def __init__(self, model_name="bert", output_file=None, max_seq_length=512):
@@ -49,40 +71,50 @@ def transform(self, modality):
             model_name, clean_up_tokenization_spaces=True
         )

-        model = BertModel.from_pretrained(model_name)
+        model = BertModel.from_pretrained(model_name).to(get_device())

         embeddings = self.create_embeddings(modality, model, tokenizer)

         if self.output_file is not None:
             save_embeddings(embeddings, self.output_file)

         transformed_modality.data_type = np.float32
-        transformed_modality.data = embeddings
+        transformed_modality.data = np.array(embeddings)
         return transformed_modality

     def create_embeddings(self, modality, model, tokenizer):
-        inputs = tokenizer(
-            modality.data,
-            return_offsets_mapping=True,
-            return_tensors="pt",
-            padding="longest",
-            return_attention_mask=True,
-            truncation=True,
-        )
-        ModalityType.TEXT.add_field_for_instances(
-            modality.metadata,
-            "token_to_character_mapping",
-            inputs.data["offset_mapping"].tolist(),
-        )
+        dataset = TextDataset(modality.data)
+        dataloader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=None)
+        cls_embeddings = []
+        for batch in dataloader:
+            inputs = tokenizer(
+                batch,
+                return_offsets_mapping=True,
+                return_tensors="pt",
+                padding="max_length",
+                return_attention_mask=True,
+                truncation=True,
+                max_length=512,  # TODO: make this dynamic
+            )

-        ModalityType.TEXT.add_field_for_instances(
-            modality.metadata, "attention_masks", inputs.data["attention_mask"].tolist()
-        )
-        del inputs.data["offset_mapping"]
+            inputs.to(get_device())
+            ModalityType.TEXT.add_field_for_instances(
+                modality.metadata,
+                "token_to_character_mapping",
+                inputs.data["offset_mapping"].tolist(),
+            )
+
+            ModalityType.TEXT.add_field_for_instances(
+                modality.metadata,
+                "attention_masks",
+                inputs.data["attention_mask"].tolist(),
+            )
+            del inputs.data["offset_mapping"]

-        with torch.no_grad():
-            outputs = model(**inputs)
+            with torch.no_grad():
+                outputs = model(**inputs)

-        cls_embedding = outputs.last_hidden_state.detach().numpy()
+            cls_embedding = outputs.last_hidden_state.detach().cpu().numpy()
+            cls_embeddings.extend(cls_embedding)

-        return cls_embedding
+        return np.array(cls_embeddings)
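
The batching pattern above can be exercised standalone. The following is a minimal sketch, not code from the patch: it assumes the bert-base-uncased checkpoint and an illustrative embed_texts helper, but mirrors the TextDataset normalization and the per-batch embedding extraction. Note that padding="max_length" keeps every batch at the same width, which is what lets the per-batch arrays stack into a single np.array at the end; padding="longest" (the old behavior) produces smaller batches but inconsistent widths across batches.

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizerFast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TextDataset(Dataset):
    # Normalize heterogeneous inputs (None, numpy scalars/arrays, non-strings)
    # into plain strings so the default collate function can batch them.
    def __init__(self, texts):
        self.texts = []
        for text in texts:
            if text is None:
                self.texts.append("")
            elif isinstance(text, np.ndarray):
                self.texts.append(str(text.item()) if text.size == 1 else str(text))
            else:
                self.texts.append(str(text))

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]


def embed_texts(texts, batch_size=32, max_length=512):  # illustrative helper
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased").to(device).eval()
    embeddings = []
    for batch in DataLoader(TextDataset(texts), batch_size=batch_size, shuffle=False):
        inputs = tokenizer(
            list(batch),              # default collate yields a list of strings
            return_tensors="pt",
            padding="max_length",     # fixed width so batches stack uniformly
            truncation=True,
            max_length=max_length,
        ).to(device)
        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state  # (B, max_length, 768)
        embeddings.extend(hidden.cpu().numpy())
    return np.array(embeddings)       # (N, max_length, 768)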

src/main/python/systemds/scuro/representations/clip.py

Lines changed: 12 additions & 2 deletions
@@ -81,6 +81,7 @@ def create_visual_embeddings(self, modality):
             frame_batch = frames[frame_ids_range]

             inputs = self.processor(images=frame_batch, return_tensors="pt")
+            inputs.to(get_device())
             with torch.no_grad():
                 output = self.model.get_image_features(**inputs)

@@ -125,9 +126,18 @@ def transform(self, modality):
     def create_text_embeddings(self, data, model):
         embeddings = []
         for d in data:
-            inputs = self.processor(text=d, return_tensors="pt", padding=True)
+            inputs = self.processor(
+                text=d,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=77,
+            )
+            inputs.to(get_device())
             with torch.no_grad():
                 text_embedding = model.get_text_features(**inputs)
-            embeddings.append(text_embedding.squeeze().numpy().reshape(1, -1))
+            embeddings.append(
+                text_embedding.squeeze().detach().cpu().numpy().reshape(1, -1)
+            )

         return embeddings
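
Two details here are easy to miss. CLIP's text encoder has a fixed context window of 77 tokens, so truncation=True with max_length=77 guards against runtime errors on long captions, and CUDA tensors must be detached and moved to the CPU before .numpy(). A minimal sketch of the pattern, assuming the openai/clip-vit-base-patch32 checkpoint (the patch does not name one):

import torch
from transformers import CLIPModel, CLIPProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


def clip_text_embedding(text):  # illustrative helper
    inputs = processor(
        text=text,
        return_tensors="pt",
        padding=True,
        truncation=True,   # clip to the encoder's 77-token context
        max_length=77,
    ).to(device)           # BatchEncoding.to() moves all contained tensors
    with torch.no_grad():
        features = model.get_text_features(**inputs)
    # detach and move to CPU before converting; .numpy() fails on CUDA tensors
    return features.squeeze().detach().cpu().numpy().reshape(1, -1)

Calling inputs.to(get_device()) without reassigning, as the diff does, works because BatchEncoding.to moves its tensors and returns the same object.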

src/main/python/systemds/scuro/representations/concatenation.py

Lines changed: 19 additions & 2 deletions
@@ -41,7 +41,12 @@ def __init__(self):

     def execute(self, modalities: List[Modality]):
         if len(modalities) == 1:
-            return np.array(modalities[0].data)
+            return np.asarray(
+                modalities[0].data,
+                dtype=modalities[0].metadata[list(modalities[0].metadata.keys())[0]][
+                    "data_layout"
+                ]["type"],
+            )

         max_emb_size = self.get_max_embedding_size(modalities)
         size = len(modalities[0].data)
@@ -52,6 +57,18 @@ def execute(self, modalities: List[Modality]):
         data = np.zeros((size, 0))

         for modality in modalities:
-            data = np.concatenate([data, copy.deepcopy(modality.data)], axis=-1)
+            other_modality = copy.deepcopy(modality.data)
+            data = np.concatenate(
+                [
+                    data,
+                    np.asarray(
+                        other_modality,
+                        dtype=modality.metadata[list(modality.metadata.keys())[0]][
+                            "data_layout"
+                        ]["type"],
+                    ),
+                ],
+                axis=-1,
+            )

         return np.array(data)
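
The dtype now comes from the modality's metadata instead of NumPy's inference, which can silently produce object arrays for dynamically typed data. The repeated metadata[list(metadata.keys())[0]]["data_layout"]["type"] lookup could be factored into a small helper; the sketch below is hypothetical and not part of the patch:

import numpy as np


def modality_dtype(modality):  # hypothetical helper, not in the patch
    # metadata is keyed per instance; each entry carries the same
    # data_layout, so the first entry is representative
    first_key = next(iter(modality.metadata))
    return modality.metadata[first_key]["data_layout"]["type"]


# usage sketch: cast each modality before fusing along the feature axis
# arrays = [np.asarray(m.data, dtype=modality_dtype(m)) for m in modalities]
# fused = np.concatenate(arrays, axis=-1)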

src/main/python/systemds/scuro/representations/sum.py

Lines changed: 13 additions & 4 deletions
@@ -21,7 +21,7 @@

 from typing import List

-
+import numpy as np
 from systemds.scuro.modality.modality import Modality
 from systemds.scuro.representations.utils import pad_sequences

@@ -40,9 +40,18 @@ def __init__(self):
         self.needs_alignment = True

     def execute(self, modalities: List[Modality]):
-        data = modalities[0].data
+        data = np.asarray(
+            modalities[0].data,
+            dtype=modalities[0].metadata[list(modalities[0].metadata.keys())[0]][
+                "data_layout"
+            ]["type"],
+        )

         for m in range(1, len(modalities)):
-            data += modalities[m].data
-
+            data += np.asarray(
+                modalities[m].data,
+                dtype=modalities[m].metadata[list(modalities[m].metadata.keys())[0]][
+                    "data_layout"
+                ]["type"],
+            )
         return data
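
The same metadata-driven cast matters here because execute relies on element-wise +=. If modality.data were still a plain Python list, += would concatenate rather than add. A short illustration of the failure mode and the fix:

import numpy as np

a = [[1.0, 2.0], [3.0, 4.0]]
b = [[0.5, 0.5], [0.5, 0.5]]

# plain lists: += concatenates, yielding 4 rows instead of a sum
wrong = list(a)
wrong += b

# explicit cast: += is element-wise over arrays of a known dtype
data = np.asarray(a, dtype=np.float32)
data += np.asarray(b, dtype=np.float32)  # [[1.5, 2.5], [3.5, 4.5]]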

src/main/python/systemds/scuro/representations/vgg.py

Lines changed: 13 additions & 8 deletions
@@ -71,7 +71,6 @@ def transform(self, modality):

         dataset = CustomDataset(modality.data, self.data_type, get_device())
         embeddings = {}
-
         activations = {}

         def get_activation(name_):
@@ -88,7 +87,6 @@ def hook(
                 get_activation(self.layer_name)
             )
         else:
-
             self.model.classifier[int(digit)].register_forward_hook(
                 get_activation(self.layer_name)
             )
@@ -97,17 +95,25 @@ def hook(
             video_id = instance["id"][0]
             frames = instance["data"][0]
             embeddings[video_id] = []
-            batch_size = 32

-            for start_index in range(0, len(frames), batch_size):
-                end_index = min(start_index + batch_size, len(frames))
-                frame_ids_range = range(start_index, end_index)
-                frame_batch = frames[frame_ids_range]
+            if frames.dim() == 3:
+                # Single image: (3, 224, 224) -> (1, 3, 224, 224)
+                frames = frames.unsqueeze(0)
+                batch_size = 1
+            else:
+                # Video: (T, 3, 224, 224) - process in batches
+                batch_size = 32
+
+            for start_index in range(0, frames.shape[0], batch_size):
+                end_index = min(start_index + batch_size, frames.shape[0])
+                frame_batch = frames[start_index:end_index]

                 _ = self.model(frame_batch)
                 output = activations[self.layer_name]
+
                 if len(output.shape) == 4:
                     output = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
+
                 embeddings[video_id].extend(
                     torch.flatten(output, 1)
                     .detach()
@@ -122,7 +128,6 @@ def hook(
         transformed_modality = TransformedModality(
             modality, self, self.output_modality_type
         )
-
         transformed_modality.data = list(embeddings.values())

         return transformed_modality
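
The VGG representation captures intermediate activations with a forward hook and average-pools 4-D feature maps before flattening; the new dim() check lets a single image of shape (3, 224, 224) flow through the same loop as a video of shape (T, 3, 224, 224). A standalone sketch of the hook-and-pool pattern, with torchvision's vgg19 and the layer index chosen purely as an example:

import torch
import torchvision

activations = {}


def get_activation(name):
    def hook(module, inputs, output):
        activations[name] = output
    return hook


model = torchvision.models.vgg19(weights="DEFAULT").eval()
model.features[35].register_forward_hook(get_activation("features.35"))

frames = torch.randn(8, 3, 224, 224)  # stand-in for a batch of frames
if frames.dim() == 3:                 # single image -> pseudo-batch of one
    frames = frames.unsqueeze(0)

with torch.no_grad():
    _ = model(frames)                 # hook captures the layer's activation

out = activations["features.35"]
if out.dim() == 4:                    # conv map (B, C, H, W) -> (B, C, 1, 1)
    out = torch.nn.functional.adaptive_avg_pool2d(out, (1, 1))
embeddings = torch.flatten(out, 1)    # (B, C) per-frame embeddings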
