
Commit f885840

[SYSTEMDS-3913] Make learnable fusion representations work for multi-label tasks
This patch extends the learnable fusion methods so that they also work with multi-label tasks.
1 parent 5292b42 commit f885840

6 files changed: +144, -50 lines changed

src/main/python/systemds/scuro/dataloader/video_loader.py

Lines changed: 13 additions & 3 deletions
@@ -71,7 +71,13 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
             self.fps, length, width, height, num_channels
         )

-        frames = []
+        num_frames = (length + frame_interval - 1) // frame_interval
+
+        stacked_frames = np.zeros(
+            (num_frames, height, width, num_channels), dtype=self._data_type
+        )
+
+        frame_idx = 0
         idx = 0
         while cap.isOpened():
             ret, frame = cap.read()
@@ -81,7 +87,11 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
             if idx % frame_interval == 0:
                 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                 frame = frame.astype(self._data_type) / 255.0
-                frames.append(frame)
+                stacked_frames[frame_idx] = frame
+                frame_idx += 1
             idx += 1

-        self.data.append(np.stack(frames))
+        if frame_idx < num_frames:
+            stacked_frames = stacked_frames[:frame_idx]
+
+        self.data.append(stacked_frames)
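Preallocating the frame buffer avoids the per-frame list appends and the extra copy that np.stack made at the end. A minimal standalone sketch of the same ceiling-division bookkeeping (the lengths and shapes here are hypothetical stand-ins for the loader's actual state):

import numpy as np

length, frame_interval = 100, 7          # hypothetical: total frames, sampling stride
height, width, num_channels = 2, 2, 3

# Ceiling division: frames are taken at indices 0, 7, 14, ...,
# so ceil(100 / 7) = 15 slots are needed.
num_frames = (length + frame_interval - 1) // frame_interval

stacked_frames = np.zeros((num_frames, height, width, num_channels), dtype=np.float32)

frame_idx = 0
for idx in range(length):
    if idx % frame_interval == 0:
        # stand-in for a decoded, normalized frame
        stacked_frames[frame_idx] = np.ones((height, width, num_channels))
        frame_idx += 1

# Trim unused slots in case decoding stopped early.
if frame_idx < num_frames:
    stacked_frames = stacked_frames[:frame_idx]

assert stacked_frames.shape == (15, 2, 2, 3)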

src/main/python/systemds/scuro/modality/modality.py

Lines changed: 19 additions & 3 deletions
@@ -88,9 +88,8 @@ def update_metadata(self):
         ):
             return

-        md_copy = deepcopy(self.metadata)
-        self.metadata = {}
-        for i, (md_k, md_v) in enumerate(md_copy.items()):
+        for i, (md_k, md_v) in enumerate(self.metadata.items()):
+            md_v = selective_copy_metadata(md_v)
             updated_md = self.modality_type.update_metadata(md_v, self.data[i])
             self.metadata[md_k] = updated_md
             if i == 0:
@@ -183,3 +182,20 @@ def is_aligned(self, other_modality):
                 break

         return aligned
+
+
+def selective_copy_metadata(metadata):
+    if isinstance(metadata, dict):
+        new_md = {}
+        for k, v in metadata.items():
+            if k == "data_layout":
+                new_md[k] = v.copy() if isinstance(v, dict) else v
+            elif isinstance(v, np.ndarray):
+                new_md[k] = v
+            else:
+                new_md[k] = selective_copy_metadata(v)
+        return new_md
+    elif isinstance(metadata, (list, tuple)):
+        return type(metadata)(selective_copy_metadata(item) for item in metadata)
+    else:
+        return metadata
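The helper replaces the earlier deepcopy of the whole metadata dict: mutable structure such as the "data_layout" dict is copied, while large ndarrays stay shared. A quick illustration of the intended aliasing behaviour, with hypothetical metadata and the helper above in scope:

import numpy as np

md = {
    "data_layout": {"type": "list_of_arrays"},  # mutable layout info: gets copied
    "embedding": np.zeros((1000, 768)),         # large array: shared, not duplicated
}

md_copy = selective_copy_metadata(md)

assert md_copy["data_layout"] is not md["data_layout"]  # independent copy
assert md_copy["embedding"] is md["embedding"]          # same object, no extra memory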

src/main/python/systemds/scuro/modality/unimodal_modality.py

Lines changed: 0 additions & 19 deletions
@@ -146,8 +146,6 @@ def apply_representation(self, representation):
             else:
                 original_lengths.append(d.shape[0])

-        new_modality.data = self.l2_normalize_features(new_modality.data)
-
         if len(original_lengths) > 0 and min(original_lengths) < max(original_lengths):
             target_length = max(original_lengths)
             padded_embeddings = []
@@ -194,20 +192,3 @@ def apply_representation(self, representation):
         new_modality.transform_time = time.time() - start
         new_modality.self_contained = representation.self_contained
         return new_modality
-
-    def l2_normalize_features(self, feature_list):
-        normalized_features = []
-        for feature in feature_list:
-            original_shape = feature.shape
-            flattened = feature.flatten()
-
-            norm = np.linalg.norm(flattened)
-            if norm > 0:
-                normalized_flat = flattened / norm
-                normalized_feature = normalized_flat.reshape(original_shape)
-            else:
-                normalized_feature = feature
-
-            normalized_features.append(normalized_feature)
-
-        return normalized_features

src/main/python/systemds/scuro/representations/fusion.py

Lines changed: 9 additions & 3 deletions
@@ -68,19 +68,25 @@ def transform(self, modalities: List[Modality]):
         return self.execute(mods)

     def transform_with_training(self, modalities: List[Modality], task):
+        fusion_train_indices = task.fusion_train_indices
+
         train_modalities = []
         for modality in modalities:
             train_data = [
-                d for i, d in enumerate(modality.data) if i in task.train_indices
+                d for i, d in enumerate(modality.data) if i in fusion_train_indices
             ]
             train_modality = TransformedModality(modality, self)
             train_modality.data = copy.deepcopy(train_data)
             train_modalities.append(train_modality)

         transformed_train = self.execute(
-            train_modalities, task.labels[task.train_indices]
+            train_modalities, task.labels[fusion_train_indices]
         )
-        transformed_val = self.transform_data(modalities, task.val_indices)
+
+        all_other_indices = [
+            i for i in range(len(modalities[0].data)) if i not in fusion_train_indices
+        ]
+        transformed_other = self.transform_data(modalities, all_other_indices)

         transformed_data = np.zeros(
             (len(modalities[0].data), transformed_train.shape[1])
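With the switch from task.train_indices/task.val_indices to task.fusion_train_indices, every sample outside the fusion training split is transformed with the trained model and later reassembled in original order. A small sketch of the index split, assuming a hypothetical task with ten samples:

n_samples = 10
fusion_train_indices = [0, 2, 4, 6, 8]   # hypothetical task.fusion_train_indices

all_other_indices = [
    i for i in range(n_samples) if i not in fusion_train_indices
]
# -> [1, 3, 5, 7, 9]: validation and test samples alike,
#    no longer just task.val_indices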

src/main/python/systemds/scuro/representations/lstm.py

Lines changed: 39 additions & 10 deletions
@@ -42,15 +42,15 @@ def __init__(
         depth=1,
         dropout_rate=0.1,
         learning_rate=0.001,
-        epochs=50,
+        epochs=20,
         batch_size=32,
     ):
         parameters = {
             "width": [128, 256, 512],
             "depth": [1, 2, 3],
             "dropout_rate": [0.1, 0.2, 0.3, 0.4, 0.5],
             "learning_rate": [0.001, 0.0001, 0.01, 0.1],
-            "epochs": [50, 100, 200],
+            "epochs": [10, 20, 50, 100, 200],
             "batch_size": [8, 16, 32, 64, 128],
         }

@@ -70,6 +70,7 @@ def __init__(
         self.num_classes = None
         self.is_trained = False
         self.model_state = None
+        self.is_multilabel = False

         self._set_random_seeds()

@@ -166,18 +167,32 @@ def execute(self, modalities: List[Modality], labels: np.ndarray = None):
         X = self._prepare_data(modalities)
         y = np.array(labels)

+        if y.ndim == 2 and y.shape[1] > 1:
+            self.is_multilabel = True
+            self.num_classes = y.shape[1]
+        else:
+            self.is_multilabel = False
+            if y.ndim == 2:
+                y = y.ravel()
+            self.num_classes = len(np.unique(y))
+
         self.input_dim = X.shape[2]
-        self.num_classes = len(np.unique(y))

         self.model = self._build_model(self.input_dim, self.num_classes)
         device = get_device()
         self.model.to(device)

-        criterion = nn.CrossEntropyLoss()
+        if self.is_multilabel:
+            criterion = nn.BCEWithLogitsLoss()
+        else:
+            criterion = nn.CrossEntropyLoss()
         optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

         X_tensor = torch.FloatTensor(X).to(device)
-        y_tensor = torch.LongTensor(y).to(device)
+        if self.is_multilabel:
+            y_tensor = torch.FloatTensor(y).to(device)
+        else:
+            y_tensor = torch.LongTensor(y).to(device)

         dataset = TensorDataset(X_tensor, y_tensor)
         dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
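The label array's shape drives everything downstream: a (N, C) indicator matrix switches the representation into multi-label mode, which means BCEWithLogitsLoss and float targets instead of CrossEntropyLoss and integer class indices. A minimal sketch of the detection rule, using toy labels rather than the Scuro API:

import numpy as np

y_multi = np.array([[1, 0, 1], [0, 1, 1]])   # (N, C) indicator matrix
y_single = np.array([[0], [2], [1]])         # (N, 1) column of class ids

def describe(y):
    if y.ndim == 2 and y.shape[1] > 1:
        return "multilabel", y.shape[1]      # one output unit per class
    if y.ndim == 2:
        y = y.ravel()                        # flatten the (N, 1) column
    return "multiclass", len(np.unique(y))

assert describe(y_multi) == ("multilabel", 3)
assert describe(y_single) == ("multiclass", 3)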
@@ -202,15 +217,23 @@ def execute(self, modalities: List[Modality], labels: np.ndarray = None):
             "state_dict": self.model.state_dict(),
             "input_dim": self.input_dim,
             "num_classes": self.num_classes,
+            "is_multilabel": self.is_multilabel,
             "width": self.width,
             "depth": self.depth,
             "dropout_rate": self.dropout_rate,
         }

         self.model.eval()
+        all_features = []
         with torch.no_grad():
-            features, _ = self.model(X_tensor)
-        return features.cpu().numpy()
+            inference_dataloader = DataLoader(
+                TensorDataset(X_tensor), batch_size=self.batch_size, shuffle=False
+            )
+            for (batch_X,) in inference_dataloader:
+                features, _ = self.model(batch_X)
+                all_features.append(features.cpu())
+
+        return torch.cat(all_features, dim=0).numpy()

     def apply_representation(self, modalities: List[Modality]) -> np.ndarray:
         if not self.is_trained or self.model is None:
@@ -222,12 +245,17 @@ def apply_representation(self, modalities: List[Modality]) -> np.ndarray:
         self.model.to(device)

         X_tensor = torch.FloatTensor(X).to(device)
-
+        all_features = []
         self.model.eval()
         with torch.no_grad():
-            features, _ = self.model(X_tensor)
+            inference_dataloader = DataLoader(
+                TensorDataset(X_tensor), batch_size=self.batch_size, shuffle=False
+            )
+            for (batch_X,) in inference_dataloader:
+                features, _ = self.model(batch_X)
+                all_features.append(features.cpu())

-        return features.cpu().numpy()
+        return torch.cat(all_features, dim=0).numpy()

     def get_model_state(self) -> Dict[str, Any]:
         return self.model_state
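Both inference paths now feed the model through a DataLoader instead of a single full-batch forward pass, which bounds peak memory on large datasets. Note the single-tensor TensorDataset: each batch arrives as a 1-tuple, hence the `for (batch_X,) in ...` unpacking. A self-contained sketch, with nn.Identity standing in for the trained LSTM (which actually returns a (features, logits) pair):

import torch
from torch.utils.data import DataLoader, TensorDataset

X_tensor = torch.randn(100, 5)          # hypothetical features
model = torch.nn.Identity()             # stand-in for the trained model

all_features = []
with torch.no_grad():
    loader = DataLoader(TensorDataset(X_tensor), batch_size=32, shuffle=False)
    for (batch_X,) in loader:           # each batch is a 1-tuple
        all_features.append(model(batch_X).cpu())

features = torch.cat(all_features, dim=0).numpy()
assert features.shape == (100, 5)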
@@ -236,6 +264,7 @@ def set_model_state(self, state: Dict[str, Any]):
         self.model_state = state
         self.input_dim = state["input_dim"]
         self.num_classes = state["num_classes"]
+        self.is_multilabel = state.get("is_multilabel", False)

         self.model = self._build_model(self.input_dim, self.num_classes)
         self.model.load_state_dict(state["state_dict"])

src/main/python/systemds/scuro/representations/multimodal_attention_fusion.py

Lines changed: 64 additions & 12 deletions
@@ -40,15 +40,15 @@ def __init__(
         num_heads=8,
         dropout=0.1,
         batch_size=32,
-        num_epochs=50,
+        num_epochs=20,
         learning_rate=0.001,
     ):
         parameters = {
             "hidden_dim": [32, 128, 256, 384, 512, 768],
             "num_heads": [2, 4, 8, 12],
             "dropout": [0.0, 0.1, 0.2, 0.3, 0.4],
             "batch_size": [8, 16, 32, 64, 128],
-            "num_epochs": [50, 100, 150, 200],
+            "num_epochs": [10, 20, 50, 100, 150, 200],
             "learning_rate": [1e-5, 1e-4, 1e-3, 1e-2],
         }
         super().__init__("AttentionFusion", parameters)
@@ -69,6 +69,7 @@ def __init__(
         self.num_classes = None
         self.is_trained = False
         self.model_state = None
+        self.is_multilabel = False

         self._set_random_seeds()

@@ -122,9 +123,17 @@ def execute(self, modalities: List[Modality], labels: np.ndarray = None):
         inputs, input_dimensions, max_sequence_length = self._prepare_data(modalities)
         y = np.array(labels)

+        if y.ndim == 2 and y.shape[1] > 1:
+            self.is_multilabel = True
+            self.num_classes = y.shape[1]
+        else:
+            self.is_multilabel = False
+            if y.ndim == 2:
+                y = y.ravel()
+            self.num_classes = len(np.unique(y))
+
         self.input_dim = input_dimensions
         self.max_sequence_length = max_sequence_length
-        self.num_classes = len(np.unique(y))

         self.encoder = MultiModalAttentionFusion(
             self.input_dim,
@@ -142,7 +151,10 @@ def execute(self, modalities: List[Modality], labels: np.ndarray = None):
         self.encoder.to(device)
         self.classification_head.to(device)

-        criterion = nn.CrossEntropyLoss()
+        if self.is_multilabel:
+            criterion = nn.BCEWithLogitsLoss()
+        else:
+            criterion = nn.CrossEntropyLoss()
         optimizer = torch.optim.Adam(
             list(self.encoder.parameters())
             + list(self.classification_head.parameters()),
@@ -151,7 +163,11 @@ def execute(self, modalities: List[Modality], labels: np.ndarray = None):

         for modality_name in inputs:
             inputs[modality_name] = inputs[modality_name].to(device)
-        labels_tensor = torch.from_numpy(y).long().to(device)
+
+        if self.is_multilabel:
+            labels_tensor = torch.from_numpy(y).float().to(device)
+        else:
+            labels_tensor = torch.from_numpy(y).long().to(device)

         dataset_inputs = []
         for i in range(len(y)):
@@ -197,9 +213,17 @@ def execute(self, modalities: List[Modality], labels: np.ndarray = None):
197213
optimizer.step()
198214

199215
total_loss += loss.item()
200-
_, predicted = torch.max(logits.data, 1)
201-
total_correct += (predicted == batch_labels).sum().item()
202-
total_samples += batch_labels.size(0)
216+
217+
if self.is_multilabel:
218+
predicted = (torch.sigmoid(logits) > 0.5).float()
219+
correct = (predicted == batch_labels).float()
220+
hamming_acc = correct.mean()
221+
total_correct += hamming_acc.item() * batch_labels.size(0)
222+
total_samples += batch_labels.size(0)
223+
else:
224+
_, predicted = torch.max(logits.data, 1)
225+
total_correct += (predicted == batch_labels).sum().item()
226+
total_samples += batch_labels.size(0)
203227

204228
self.is_trained = True
205229
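In the multi-label branch, per-sample argmax accuracy is replaced by Hamming accuracy: the fraction of individual label bits predicted correctly, averaged over the batch. A worked toy example with hypothetical logits and labels:

import torch

logits = torch.tensor([[ 2.0, -1.0,  0.5],
                       [-0.5,  1.5, -2.0]])
batch_labels = torch.tensor([[1., 0., 0.],
                             [0., 1., 0.]])

predicted = (torch.sigmoid(logits) > 0.5).float()   # [[1, 0, 1], [0, 1, 0]]
correct = (predicted == batch_labels).float()       # [[1, 1, 0], [1, 1, 1]]
hamming_acc = correct.mean()                        # 5 of 6 bits correct

assert abs(hamming_acc.item() - 5 / 6) < 1e-6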

@@ -214,10 +238,24 @@ def execute(self, modalities: List[Modality], labels: np.ndarray = None):
             "dropout": self.dropout,
         }

+        all_features = []
+
         with torch.no_grad():
-            encoder_output = self.encoder(inputs)
+            for batch_start in range(
+                0, len(inputs[list(inputs.keys())[0]]), self.batch_size
+            ):
+                batch_end = min(
+                    batch_start + self.batch_size, len(inputs[list(inputs.keys())[0]])
+                )
+
+                batch_inputs = {}
+                for modality_name, tensor in inputs.items():
+                    batch_inputs[modality_name] = tensor[batch_start:batch_end]
+
+                encoder_output = self.encoder(batch_inputs)
+                all_features.append(encoder_output["fused"].cpu())

-        return encoder_output["fused"].cpu().numpy()
+        return torch.cat(all_features, dim=0).numpy()

     def apply_representation(self, modalities: List[Modality]) -> np.ndarray:
         if not self.is_trained or self.encoder is None:
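Because the encoder takes a dict of per-modality tensors rather than a single tensor, batching is done by slicing every modality tensor with the same range so samples stay aligned across modalities. A small sketch with hypothetical tensors:

import torch

inputs = {                               # hypothetical per-modality tensors
    "audio": torch.randn(10, 4),
    "video": torch.randn(10, 8),
}
batch_size = 4
n_samples = len(inputs[list(inputs.keys())[0]])   # 10, taken from any modality

for batch_start in range(0, n_samples, batch_size):
    batch_end = min(batch_start + batch_size, n_samples)
    batch_inputs = {
        name: tensor[batch_start:batch_end] for name, tensor in inputs.items()
    }
    # batch sizes 4, 4, 2: the same rows of every modality travel together
    assert batch_inputs["audio"].shape[0] == batch_inputs["video"].shape[0]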
@@ -232,10 +270,23 @@ def apply_representation(self, modalities: List[Modality]) -> np.ndarray:
             inputs[modality_name] = inputs[modality_name].to(device)

         self.encoder.eval()
+        all_features = []
+
         with torch.no_grad():
-            encoder_output = self.encoder(inputs)
+            batch_size = self.batch_size
+            n_samples = len(inputs[list(inputs.keys())[0]])
+
+            for batch_start in range(0, n_samples, batch_size):
+                batch_end = min(batch_start + batch_size, n_samples)
+
+                batch_inputs = {}
+                for modality_name, tensor in inputs.items():
+                    batch_inputs[modality_name] = tensor[batch_start:batch_end]
+
+                encoder_output = self.encoder(batch_inputs)
+                all_features.append(encoder_output["fused"].cpu())

-        return encoder_output["fused"].cpu().numpy()
+        return torch.cat(all_features, dim=0).numpy()

     def get_model_state(self) -> Dict[str, Any]:
         return self.model_state
@@ -245,6 +296,7 @@ def set_model_state(self, state: Dict[str, Any]):
         self.input_dim = state["input_dimensions"]
         self.max_sequence_length = state["max_sequence_length"]
         self.num_classes = state["num_classes"]
+        self.is_multilabel = state.get("is_multilabel", False)

         self.encoder = MultiModalAttentionFusion(
             self.input_dim,
