Skip to content

Commit 4e00aa1

Browse files
christinadionysiomboehm7
authored andcommitted
[SYSTEMDS-3701] Add test suite for Scuro
Closes #2143.
1 parent 08875cb commit 4e00aa1

File tree

10 files changed

+592
-84
lines changed

10 files changed

+592
-84
lines changed

src/main/python/systemds/scuro/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,12 @@
3030
from systemds.scuro.representations.bert import Bert
3131
from systemds.scuro.representations.unimodal import UnimodalRepresentation
3232
from systemds.scuro.representations.lstm import LSTM
33-
from systemds.scuro.representations.utils import NPY, Pickle, HDF5, JSON
33+
from systemds.scuro.representations.representation_dataloader import (
34+
NPY,
35+
Pickle,
36+
HDF5,
37+
JSON,
38+
)
3439
from systemds.scuro.models.model import Model
3540
from systemds.scuro.models.discrete_model import DiscreteModel
3641
from systemds.scuro.modality.aligned_modality import AlignedModality

src/main/python/systemds/scuro/aligner/task.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,19 @@
2121
from typing import List
2222

2323
from systemds.scuro.models.model import Model
24+
import numpy as np
25+
from sklearn.model_selection import KFold
2426

2527

2628
class Task:
2729
def __init__(
28-
self, name: str, model: Model, labels, train_indices: List, val_indices: List
30+
self,
31+
name: str,
32+
model: Model,
33+
labels,
34+
train_indices: List,
35+
val_indices: List,
36+
kfold=5,
2937
):
3038
"""
3139
Parent class for the prediction task that is performed on top of the aligned representation
@@ -34,12 +42,15 @@ def __init__(
3442
:param labels: Labels used for prediction
3543
:param train_indices: Indices to extract training data
3644
:param val_indices: Indices to extract validation data
45+
:param kfold: Number of crossvalidation runs
46+
3747
"""
3848
self.name = name
3949
self.model = model
4050
self.labels = labels
4151
self.train_indices = train_indices
4252
self.val_indices = val_indices
53+
self.kfold = kfold
4354

4455
def get_train_test_split(self, data):
4556
X_train = [data[i] for i in self.train_indices]
@@ -51,9 +62,27 @@ def get_train_test_split(self, data):
5162

5263
def run(self, data):
5364
"""
54-
The run method need to be implemented by every task class
55-
It handles the training and validation procedures for the specific task
56-
:param data: The aligned data used in the prediction process
57-
:return: the validation accuracy
65+
The run method needs to be implemented by every task class
66+
It handles the training and validation procedures for the specific task
67+
:param data: The aligned data used in the prediction process
68+
:return: the validation accuracy
5869
"""
59-
pass
70+
skf = KFold(n_splits=self.kfold, shuffle=True, random_state=11)
71+
train_scores = []
72+
test_scores = []
73+
fold = 0
74+
X, y, X_test, y_test = self.get_train_test_split(data)
75+
76+
for train, test in skf.split(X, y):
77+
train_X = np.array(X)[train]
78+
train_y = np.array(y)[train]
79+
80+
train_score = self.model.fit(train_X, train_y, X_test, y_test)
81+
train_scores.append(train_score)
82+
83+
test_score = self.model.test(X_test, y_test)
84+
test_scores.append(test_score)
85+
86+
fold += 1
87+
88+
return [np.mean(train_scores), np.mean(test_scores)]

src/main/python/systemds/scuro/representations/bert.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ def parse_all(self, filepath, indices, get_sequences=False):
5656
data = file.readlines()
5757

5858
model_name = "bert-base-uncased"
59-
tokenizer = BertTokenizer.from_pretrained(model_name)
59+
tokenizer = BertTokenizer.from_pretrained(
60+
model_name, clean_up_tokenization_spaces=True
61+
)
6062

6163
if self.avg_layers is not None:
6264
model = BertModel.from_pretrained(model_name, output_hidden_states=True)
@@ -89,7 +91,7 @@ def create_embeddings(self, data, model, tokenizer):
8991
cls_embedding = torch.mean(torch.stack(cls_embedding), dim=0)
9092
else:
9193
cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
92-
embeddings.append(cls_embedding)
94+
embeddings.append(cls_embedding.numpy())
9395

9496
embeddings = np.array(embeddings)
9597
return embeddings.reshape((embeddings.shape[0], embeddings.shape[-1]))

src/main/python/systemds/scuro/representations/fusion.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
# -------------------------------------------------------------
2121
from typing import List
2222

23-
from sklearn.preprocessing import StandardScaler
24-
2523
from systemds.scuro.modality.modality import Modality
2624
from systemds.scuro.representations.representation import Representation
2725

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# -------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
# -------------------------------------------------------------
21+
22+
23+
import json
24+
import pickle
25+
import numpy as np
26+
import h5py
27+
28+
from systemds.scuro.representations.unimodal import UnimodalRepresentation
29+
30+
31+
class NPY(UnimodalRepresentation):
32+
def __init__(self):
33+
super().__init__("NPY")
34+
35+
def parse_all(self, filepath, indices, get_sequences=False):
36+
data = np.load(filepath, allow_pickle=True)
37+
38+
if indices is not None:
39+
return np.array([data[index] for index in indices])
40+
else:
41+
return np.array([data[index] for index in data])
42+
43+
44+
class Pickle(UnimodalRepresentation):
45+
def __init__(self):
46+
super().__init__("Pickle")
47+
48+
def parse_all(self, file_path, indices, get_sequences=False):
49+
with open(file_path, "rb") as f:
50+
data = pickle.load(f)
51+
52+
embeddings = []
53+
for n, idx in enumerate(indices):
54+
embeddings.append(data[idx])
55+
56+
return np.array(embeddings)
57+
58+
59+
class HDF5(UnimodalRepresentation):
60+
def __init__(self):
61+
super().__init__("HDF5")
62+
63+
def parse_all(self, filepath, indices=None, get_sequences=False):
64+
data = h5py.File(filepath)
65+
66+
if get_sequences:
67+
max_emb = 0
68+
for index in indices:
69+
if max_emb < len(data[index][()]):
70+
max_emb = len(data[index][()])
71+
72+
emb = []
73+
if indices is not None:
74+
for index in indices:
75+
emb_i = data[index].tolist()
76+
for i in range(len(emb_i), max_emb):
77+
emb_i.append([0 for x in range(0, len(emb_i[0]))])
78+
emb.append(emb_i)
79+
80+
return np.array(emb)
81+
else:
82+
if indices is not None:
83+
return np.array([np.mean(data[index], axis=0) for index in indices])
84+
else:
85+
return np.array([np.mean(data[index][()], axis=0) for index in data])
86+
87+
88+
class JSON(UnimodalRepresentation):
89+
def __init__(self):
90+
super().__init__("JSON")
91+
92+
def parse_all(self, filepath, indices):
93+
with open(filepath) as file:
94+
return json.load(file)

src/main/python/systemds/scuro/representations/utils.py

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -19,81 +19,8 @@
1919
#
2020
# -------------------------------------------------------------
2121

22-
23-
import json
24-
import pickle
25-
26-
import h5py
2722
import numpy as np
2823

29-
from systemds.scuro.representations.unimodal import UnimodalRepresentation
30-
31-
32-
class NPY(UnimodalRepresentation):
33-
def __init__(self):
34-
super().__init__("NPY")
35-
36-
def parse_all(self, filepath, indices, get_sequences=False):
37-
data = np.load(filepath, allow_pickle=True)
38-
39-
if indices is not None:
40-
return np.array([data[index] for index in indices])
41-
else:
42-
return np.array([data[index] for index in data])
43-
44-
45-
class Pickle(UnimodalRepresentation):
46-
def __init__(self):
47-
super().__init__("Pickle")
48-
49-
def parse_all(self, file_path, indices, get_sequences=False):
50-
with open(file_path, "rb") as f:
51-
data = pickle.load(f)
52-
53-
embeddings = []
54-
for n, idx in enumerate(indices):
55-
embeddings.append(data[idx])
56-
57-
return np.array(embeddings)
58-
59-
60-
class HDF5(UnimodalRepresentation):
61-
def __init__(self):
62-
super().__init__("HDF5")
63-
64-
def parse_all(self, filepath, indices=None, get_sequences=False):
65-
data = h5py.File(filepath)
66-
67-
if get_sequences:
68-
max_emb = 0
69-
for index in indices:
70-
if max_emb < len(data[index][()]):
71-
max_emb = len(data[index][()])
72-
73-
emb = []
74-
if indices is not None:
75-
for index in indices:
76-
emb_i = data[index].tolist()
77-
for i in range(len(emb_i), max_emb):
78-
emb_i.append([0 for x in range(0, len(emb_i[0]))])
79-
emb.append(emb_i)
80-
81-
return np.array(emb)
82-
else:
83-
if indices is not None:
84-
return np.array([np.mean(data[index], axis=0) for index in indices])
85-
else:
86-
return np.array([np.mean(data[index][()], axis=0) for index in data])
87-
88-
89-
class JSON(UnimodalRepresentation):
90-
def __init__(self):
91-
super().__init__("JSON")
92-
93-
def parse_all(self, filepath, indices):
94-
with open(filepath) as file:
95-
return json.load(file)
96-
9724

9825
def pad_sequences(sequences, maxlen=None, dtype="float32", value=0):
9926
if maxlen is None:
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# -------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
# -------------------------------------------------------------

0 commit comments

Comments
 (0)