Skip to content

Commit e680fc8

Browse files
testing
1 parent 9af4b89 commit e680fc8

File tree

1 file changed

+106
-107
lines changed

1 file changed

+106
-107
lines changed

src/main/python/tests/scuro/test_dr_search.py

Lines changed: 106 additions & 107 deletions
Original file line number | Diff line number | Diff line change
@@ -78,110 +78,109 @@ def scale_data(data, train_indizes):
7878

7979

8080
class TestDataLoaders(unittest.TestCase):
    """Exercise ``DRSearch`` on generated video, audio and text modalities.

    ``setUpClass`` builds the shared fixture once: generated data under
    ``test_file_path``, one representation per modality, a train/validation
    index split, and the candidate representations handed to ``DRSearch``.

    NOTE(review): the class name does not match the module under test
    (``test_dr_search.py``); it is kept unchanged so existing test selectors
    keep working.
    """

    # Fixture state shared by all tests; populated once in setUpClass.
    train_indizes = None
    val_indizes = None
    test_file_path = None
    mods = None
    text = None
    audio = None
    video = None
    bert = None
    mel_spe = None
    resnet = None
    data_generator = None
    num_instances = 0
    representations = None

    @classmethod
    def setUpClass(cls):
        """Generate the test data and derive modalities, split and candidates."""
        cls.test_file_path = "test_data_dr_search"
        cls.num_instances = 20
        modalities = [ModalityType.VIDEO, ModalityType.AUDIO, ModalityType.TEXT]

        cls.data_generator = setup_data(
            modalities, cls.num_instances, cls.test_file_path
        )
        # exist_ok=True: tearDownClass does not run when setUpClass raises, so
        # a directory left behind by an aborted run must not break the next one.
        os.makedirs(f"{cls.test_file_path}/embeddings", exist_ok=True)

        # TODO: adapt the representations so they return non-aggregated values
        # and apply a windowing operation instead.

        cls.bert = cls.data_generator.modalities_by_type[
            ModalityType.TEXT
        ].apply_representation(Bert())
        cls.mel_spe = (
            cls.data_generator.modalities_by_type[ModalityType.AUDIO]
            .apply_representation(MelSpectrogram())
            .flatten()
        )
        cls.resnet = (
            cls.data_generator.modalities_by_type[ModalityType.VIDEO]
            .apply_representation(ResNet())
            .window_aggregation(10, "mean")
            .flatten()
        )
        cls.mods = [cls.bert, cls.mel_spe, cls.resnet]

        split = train_test_split(
            cls.data_generator.indices,
            cls.data_generator.labels,
            test_size=0.2,
            random_state=42,
        )
        # The first two entries of the split are the train and held-out parts
        # of the indices; each entry is converted with int().
        cls.train_indizes = [int(i) for i in split[0]]
        cls.val_indizes = [int(i) for i in split[1]]

        # Pass every modality's data through scale_data together with the
        # training indices.
        for m in cls.mods:
            m.data = scale_data(m.data, cls.train_indizes)

        cls.representations = [
            Concatenation(),
            Average(),
            RowMax(100),
            Multiplication(),
            Sum(),
            LSTM(width=256, depth=3),
        ]

    @classmethod
    def tearDownClass(cls):
        """Remove the directory tree created for the generated test data."""
        print("Cleaning up test data")
        shutil.rmtree(cls.test_file_path)

    def _make_task(self):
        """Build the ``Task`` that both tests run the search against."""
        return Task(
            "TestTask",
            TestSVM(),
            self.data_generator.labels,
            self.train_indizes,
            self.val_indizes,
        )

    def test_enumerate_all(self):
        """No recorded score may exceed the best score of full enumeration."""
        dr_search = DRSearch(self.mods, self._make_task(), self.representations)
        _, best_score, _ = dr_search.fit_enumerate_all()

        for recorded in dr_search.scores.values():
            for scores in recorded.values():
                # NOTE(review): scores[1] is presumably the validation score;
                # confirm against DRSearch.
                self.assertLessEqual(scores[1], best_score)

    def test_enumerate_all_vs_random(self):
        """Random search (seed 42) must not beat the full enumeration."""
        dr_search = DRSearch(self.mods, self._make_task(), self.representations)
        _, best_score_enum, _ = dr_search.fit_enumerate_all()

        # Clear the stored best parameters before running the second strategy.
        dr_search.reset_best_params()

        _, best_score_rand, _ = dr_search.fit_random(seed=42)

        self.assertLessEqual(best_score_rand, best_score_enum)
183+
184+
185+
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)