Skip to content

Commit e9c449d

Browse files
Add params and output modality type
1 parent edbbb10 commit e9c449d

File tree

10 files changed

+79
-17
lines changed

10 files changed

+79
-17
lines changed

src/main/python/systemds/scuro/representations/bert.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@
2929

3030

3131
class Bert(UnimodalRepresentation):
32-
def __init__(self, output_file=None):
33-
super().__init__("Bert")
32+
def __init__(self, model_name="bert", output_file=None):
33+
parameters = {"model_name": "bert"}
34+
self.model_name = model_name
35+
super().__init__("Bert", ModalityType.EMBEDDING, parameters)
3436

3537
self.output_file = output_file
3638

src/main/python/systemds/scuro/representations/bow.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@
2727

2828

2929
class BoW(UnimodalRepresentation):
30-
def __init__(self, ngram_range, min_df, output_file=None):
31-
super().__init__("BoW")
30+
def __init__(self, ngram_range=2, min_df=2, output_file=None):
31+
parameters = {"ngram_range": [ngram_range], "min_df": [min_df]}
32+
super().__init__("BoW", ModalityType.EMBEDDING, parameters)
3233
self.ngram_range = ngram_range
3334
self.min_df = min_df
3435
self.output_file = output_file

src/main/python/systemds/scuro/representations/fusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@
2727

2828

2929
class Fusion(Representation):
30-
def __init__(self, name):
30+
def __init__(self, name, parameters=None):
3131
"""
3232
Parent class for different multimodal fusion types
3333
:param name: Name of the fusion type
3434
"""
35-
super().__init__(name)
35+
super().__init__(name, parameters)
3636

3737
def transform(self, modalities: List[Modality]):
3838
"""

src/main/python/systemds/scuro/representations/glove.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def load_glove_embeddings(file_path):
3939

4040
class GloVe(UnimodalRepresentation):
4141
def __init__(self, glove_path, output_file=None):
42-
super().__init__("GloVe")
42+
super().__init__("GloVe", ModalityType.TEXT)
4343
self.glove_path = glove_path
4444
self.output_file = output_file
4545

src/main/python/systemds/scuro/representations/mel_spectrogram.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,16 @@
2828

2929

3030
class MelSpectrogram(UnimodalRepresentation):
31-
def __init__(self):
32-
super().__init__("MelSpectrogram")
31+
def __init__(self, n_mels=128, hop_length=512, n_fft=2048):
32+
parameters = {
33+
"n_mels": [20, 32, 64, 128],
34+
"hop_length": [256, 512, 1024, 2048],
35+
"n_fft": [1024, 2048, 4096],
36+
}
37+
super().__init__("MelSpectrogram", ModalityType.TIMESERIES, parameters)
38+
self.n_mels = n_mels
39+
self.hop_length = hop_length
40+
self.n_fft = n_fft
3341

3442
def transform(self, modality):
3543
transformed_modality = TransformedModality(

src/main/python/systemds/scuro/representations/representation.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,24 @@
1818
# under the License.
1919
#
2020
# -------------------------------------------------------------
21+
from abc import abstractmethod
2122

2223

2324
class Representation:
24-
def __init__(self, name):
25+
def __init__(self, name, parameters):
2526
self.name = name
27+
self._parameters = parameters
28+
29+
@property
30+
def parameters(self):
31+
return self._parameters
32+
33+
def get_current_parameters(self):
34+
current_params = {}
35+
for parameter in self.parameters.keys():
36+
current_params[parameter] = getattr(self, parameter)
37+
return current_params
38+
39+
def set_parameters(self, parameters):
40+
for parameter in parameters:
41+
setattr(self, parameter, parameters[parameter])

src/main/python/systemds/scuro/representations/resnet.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@
3838

3939
class ResNet(UnimodalRepresentation):
4040
def __init__(self, layer="avgpool", model_name="ResNet18", output_file=None):
41-
super().__init__("ResNet")
41+
self.model_name = model_name
42+
parameters = self._get_parameters()
43+
super().__init__("ResNet", ModalityType.TIMESERIES, parameters) # TODO: TIMESERIES only for videos - images would be handled as EMBEDDIGN
4244

4345
self.output_file = output_file
4446
self.layer_name = layer
@@ -82,6 +84,25 @@ def model(self, model):
8284
else:
8385
raise NotImplementedError
8486

87+
def _get_parameters(self, high_level=True):
88+
parameters = {"model_name": [], "layer_name": []}
89+
for m in ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]:
90+
parameters["model_name"].append(m)
91+
92+
if high_level:
93+
parameters["layer_name"] = [
94+
"conv1",
95+
"layer1",
96+
"layer2",
97+
"layer3",
98+
"layer4",
99+
"avgpool",
100+
]
101+
else:
102+
for name, layer in self.model.named_modules():
103+
parameters["layer_name"].append(name)
104+
return parameters
105+
85106
def transform(self, modality):
86107

87108
t = transforms.Compose(

src/main/python/systemds/scuro/representations/tfidf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@
2727

2828

2929
class TfIdf(UnimodalRepresentation):
30-
def __init__(self, min_df, output_file=None):
31-
super().__init__("TF-IDF")
30+
def __init__(self, min_df=2, output_file=None):
31+
parameters = {"min_df": [min_df]}
32+
super().__init__("TF-IDF", ModalityType.EMBEDDING, parameters)
3233
self.min_df = min_df
3334
self.output_file = output_file
3435

src/main/python/systemds/scuro/representations/unimodal.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,25 @@
1818
# under the License.
1919
#
2020
# -------------------------------------------------------------
21+
import abc
22+
2123
from systemds.scuro.representations.representation import Representation
2224

2325

2426
class UnimodalRepresentation(Representation):
25-
def __init__(self, name):
27+
def __init__(self, name: str, output_modality_type, parameters=None):
2628
"""
2729
Parent class for all unimodal representation types
2830
:param name: name of the representation
31+
:param parameters: parameters of the representation; name of the parameter and
32+
possible parameter values
2933
"""
30-
super().__init__(name)
34+
super().__init__(name, parameters)
35+
self.output_modality_type = output_modality_type
36+
if parameters is None:
37+
parameters = {}
3138

39+
@abc.abstractmethod
3240
def transform(self, data):
3341
raise f"Not implemented for {self.name}"
3442

src/main/python/systemds/scuro/representations/word2vec.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,13 @@ def get_embedding(sentence, model):
4040

4141

4242
class W2V(UnimodalRepresentation):
43-
def __init__(self, vector_size, min_count, window, output_file=None):
44-
super().__init__("Word2Vec")
43+
def __init__(self, vector_size=3, min_count=2, window=2, output_file=None):
44+
parameters = {
45+
"vector_size": [vector_size],
46+
"min_count": [min_count],
47+
"window": [window],
48+
} # TODO
49+
super().__init__("Word2Vec", ModalityType.EMBEDDING, parameters)
4550
self.vector_size = vector_size
4651
self.min_count = min_count
4752
self.window = window

0 commit comments

Comments
 (0)