
Commit 5b3e87b

add additional text representations
1 parent a6a1509 commit 5b3e87b

7 files changed: +279 -30 lines changed

src/main/python/systemds/scuro/representations/bert.py

Lines changed: 10 additions & 29 deletions

@@ -19,20 +19,12 @@
 #
 # -------------------------------------------------------------
 
-import pickle
-
 import numpy as np
 
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
 import torch
 from transformers import BertTokenizer, BertModel
-import os
-
-
-def read_text_file(file_path):
-    with open(file_path, "r", encoding="utf-8") as file:
-        text = file.read()
-    return text
+from systemds.scuro.representations.utils import read_data_from_file, save_embeddings
 
 
 class Bert(UnimodalRepresentation):
@@ -42,18 +34,8 @@ def __init__(self, avg_layers=None, output_file=None):
         self.avg_layers = avg_layers
         self.output_file = output_file
 
-    def parse_all(self, filepath, indices, get_sequences=False):
-        # Assumes text is stored in .txt files
-        data = []
-        if os.path.isdir(filepath):
-            for filename in os.listdir(filepath):
-                f = os.path.join(filepath, filename)
-                if os.path.isfile(f):
-                    with open(f, "r") as file:
-                        data.append(file.readlines()[0])
-        else:
-            with open(filepath, "r") as file:
-                data = file.readlines()
+    def parse_all(self, filepath, indices):
+        data = read_data_from_file(filepath, indices)
 
         model_name = "bert-base-uncased"
         tokenizer = BertTokenizer.from_pretrained(
@@ -65,13 +47,13 @@ def parse_all(self, filepath, indices, get_sequences=False):
         else:
             model = BertModel.from_pretrained(model_name)
 
-        embeddings = self.create_embeddings(data, model, tokenizer)
+        embeddings = self.create_embeddings(list(data.values()), model, tokenizer)
 
         if self.output_file is not None:
            data = {}
            for i in range(0, embeddings.shape[0]):
                data[indices[i]] = embeddings[i]
-           self.save_embeddings(data)
+           save_embeddings(data, self.output_file)
 
        return embeddings
 
@@ -88,14 +70,13 @@ def create_embeddings(self, data, model, tokenizer):
                     outputs.hidden_states[i][:, 0, :]
                     for i in range(-self.avg_layers, 0)
                 ]
-                cls_embedding = torch.mean(torch.stack(cls_embedding), dim=0)
+                cls_embedding = torch.mean(torch.stack(cls_embedding), dim=0).numpy()
             else:
                 cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
-            embeddings.append(cls_embedding.numpy())
+            embeddings.append(cls_embedding)
+
+        if self.output_file is not None:
+            save_embeddings(embeddings, self.output_file)
 
         embeddings = np.array(embeddings)
         return embeddings.reshape((embeddings.shape[0], embeddings.shape[-1]))
-
-    def save_embeddings(self, data):
-        with open(self.output_file, "wb") as file:
-            pickle.dump(data, file)
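
For context, a minimal usage sketch of the refactored class. This is hedged: the directory layout, index names, and output path below are invented for illustration; per read_data_from_file, the path should carry a trailing slash and indices should name one text file per sample.

from systemds.scuro.representations.bert import Bert

# hypothetical layout: data/text/sample_1.txt, data/text/sample_2.txt
bert = Bert(avg_layers=None, output_file="bert_embeddings.pkl")
embeddings = bert.parse_all("data/text/", indices=["sample_1", "sample_2"])
print(embeddings.shape)  # one 768-dim row per segment for bert-base-uncased
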
Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import pandas as pd
+from sklearn.feature_extraction.text import CountVectorizer
+
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.representations.utils import read_data_from_file, save_embeddings
+
+
+class BoW(UnimodalRepresentation):
+    def __init__(self, ngram_range, min_df, output_file=None):
+        super().__init__("BoW")
+        self.ngram_range = ngram_range
+        self.min_df = min_df
+        self.output_file = output_file
+
+    def parse_all(self, filepath, indices):
+        vectorizer = CountVectorizer(
+            ngram_range=(1, self.ngram_range), min_df=self.min_df
+        )
+
+        segments = read_data_from_file(filepath, indices)
+        X = vectorizer.fit_transform(segments.values())
+        X = X.toarray()
+
+        if self.output_file is not None:
+            df = pd.DataFrame(X)
+            df.index = segments.keys()
+
+            save_embeddings(df, self.output_file)
+
+        return X
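
A standalone sketch of the vectorization step with toy segments (invented data); note that ngram_range is passed to CountVectorizer as the upper bound of a (1, ngram_range) range, so BoW(2, 1) counts unigrams and bigrams:

from sklearn.feature_extraction.text import CountVectorizer

segments = {"s_1": "the cat sat", "s_2": "the cat sat on the mat"}  # toy data
vectorizer = CountVectorizer(ngram_range=(1, 2), min_df=1)
X = vectorizer.fit_transform(segments.values()).toarray()
print(X.shape)  # (2, number of distinct unigrams and bigrams)
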
Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import nltk
+import numpy as np
+from nltk import word_tokenize
+
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.representations.utils import read_data_from_file, save_embeddings
+
+
+def load_glove_embeddings(file_path):
+    embeddings = {}
+    with open(file_path, "r", encoding="utf-8") as f:
+        for line in f:
+            values = line.split()
+            word = values[0]
+            vector = np.asarray(values[1:], dtype="float32")
+            embeddings[word] = vector
+    return embeddings
+
+
+class GloVe(UnimodalRepresentation):
+    def __init__(self, glove_path, output_file=None):
+        super().__init__("GloVe")
+        self.glove_path = glove_path
+        self.output_file = output_file
+
+    def parse_all(self, filepath, indices):
+        glove_embeddings = load_glove_embeddings(self.glove_path)
+        segments = read_data_from_file(filepath, indices)
+
+        embeddings = {}
+        for k, v in segments.items():
+            tokens = word_tokenize(v.lower())
+            embeddings[k] = np.mean(
+                [
+                    glove_embeddings[token]
+                    for token in tokens
+                    if token in glove_embeddings
+                ],
+                axis=0,
+            )
+
+        if self.output_file is not None:
+            save_embeddings(embeddings, self.output_file)
+
+        embeddings = np.array(list(embeddings.values()))
+        return embeddings
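
Each segment embedding is the mean of the GloVe vectors of its in-vocabulary tokens; a self-contained sketch with made-up 4-dimensional vectors standing in for a real glove.*.txt file:

import numpy as np

# stand-in for load_glove_embeddings(file_path)
glove_embeddings = {"hello": np.ones(4, dtype="float32"),
                    "world": np.zeros(4, dtype="float32")}
tokens = ["hello", "world", "unseen"]
vec = np.mean([glove_embeddings[t] for t in tokens if t in glove_embeddings], axis=0)
print(vec)  # [0.5 0.5 0.5 0.5]; the out-of-vocabulary token is skipped

One caveat: a segment whose tokens are all out of vocabulary hands np.mean an empty list, which yields NaN, whereas the W2V class below falls back to a zero vector.
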
Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import pandas as pd
+from sklearn.feature_extraction.text import TfidfVectorizer
+
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.representations.utils import read_data_from_file, save_embeddings
+
+
+class TfIdf(UnimodalRepresentation):
+    def __init__(self, min_df, output_file=None):
+        super().__init__("TF-IDF")
+        self.min_df = min_df
+        self.output_file = output_file
+
+    def parse_all(self, filepath, indices):
+        vectorizer = TfidfVectorizer(min_df=self.min_df)
+
+        segments = read_data_from_file(filepath, indices)
+        X = vectorizer.fit_transform(segments.values())
+        X = X.toarray()
+
+        if self.output_file is not None:
+            df = pd.DataFrame(X)
+            df.index = segments.keys()
+
+            save_embeddings(df, self.output_file)
+
+        return X
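
As with BoW, an isolated sketch of what parse_all computes (toy segments, invented):

from sklearn.feature_extraction.text import TfidfVectorizer

segments = {"s_1": "the cat sat", "s_2": "the dog sat"}  # toy data
vectorizer = TfidfVectorizer(min_df=1)
X = vectorizer.fit_transform(segments.values()).toarray()
print(X.shape)  # (2, vocabulary size); rows are L2-normalized tf-idf vectors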

src/main/python/systemds/scuro/representations/utils.py

Lines changed: 38 additions & 0 deletions

@@ -18,6 +18,8 @@
 # under the License.
 #
 # -------------------------------------------------------------
+import os
+import pickle
 
 import numpy as np
 
@@ -33,3 +35,39 @@ def pad_sequences(sequences, maxlen=None, dtype="float32", value=0):
         result[i, : len(data)] = data
 
     return result
+
+
+def get_segments(data, key_prefix):
+    segments = {}
+    counter = 1
+    for line in data:
+        line = line.replace("\n", "")
+        segments[key_prefix + str(counter)] = line
+        counter += 1
+
+    return segments
+
+
+def read_data_from_file(filepath, indices):
+    data = {}
+
+    is_dir = True if os.path.isdir(filepath) else False
+
+    if is_dir:
+        files = os.listdir(filepath)
+
+        # get file extension
+        _, ext = os.path.splitext(files[0])
+        for key in indices:
+            with open(filepath + key + ext) as segm:
+                data.update(get_segments(segm, key + "_"))
+    else:
+        with open(filepath) as file:
+            data.update(get_segments(file, ""))
+
+    return data
+
+
+def save_embeddings(data, file_name):
+    with open(file_name, "wb") as file:
+        pickle.dump(data, file)
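
The key scheme matters for the representations above: for a directory input, line n of <index><ext> is stored under the key "<index>_n" with a 1-based counter. A small sketch of get_segments on made-up file contents:

from systemds.scuro.representations.utils import get_segments

lines = ["first segment\n", "second segment\n"]  # made-up file contents
print(get_segments(lines, "sample_1_"))
# {'sample_1_1': 'first segment', 'sample_1_2': 'second segment'}
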
Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import numpy as np
+
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.representations.utils import read_data_from_file, save_embeddings
+from gensim.models import Word2Vec
+from nltk.tokenize import word_tokenize
+import nltk
+
+
+def get_embedding(sentence, model):
+    vectors = []
+    for word in sentence:
+        if word in model.wv:
+            vectors.append(model.wv[word])
+
+    return np.mean(vectors, axis=0) if vectors else np.zeros(model.vector_size)
+
+
+class W2V(UnimodalRepresentation):
+    def __init__(self, vector_size, min_count, window, output_file=None):
+        super().__init__("Word2Vec")
+        self.vector_size = vector_size
+        self.min_count = min_count
+        self.window = window
+        self.output_file = output_file
+
+    def parse_all(self, filepath, indices):
+        segments = read_data_from_file(filepath, indices)
+        embeddings = {}
+        t = [word_tokenize(s.lower()) for s in segments.values()]
+        model = Word2Vec(
+            sentences=t,
+            vector_size=self.vector_size,
+            window=self.window,
+            min_count=self.min_count,
+        )
+
+        for k, v in segments.items():
+            tokenized_words = word_tokenize(v.lower())
+            embeddings[k] = get_embedding(tokenized_words, model)
+
+        if self.output_file is not None:
+            save_embeddings(embeddings, self.output_file)
+
+        return np.array(list(embeddings.values()))
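
A hedged end-to-end sketch of the same steps on a toy corpus (sentences and parameters are invented; word_tokenize requires the NLTK "punkt" data to be downloaded):

import numpy as np
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize

sentences = [word_tokenize(s.lower()) for s in ["The cat sat.", "The dog sat."]]
model = Word2Vec(sentences=sentences, vector_size=16, window=2, min_count=1)

# mean-pool the word vectors of one tokenized sentence, as get_embedding does
vectors = [model.wv[w] for w in sentences[0] if w in model.wv]
vec = np.mean(vectors, axis=0) if vectors else np.zeros(model.vector_size)
print(vec.shape)  # (16,)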

src/main/python/tests/scuro/data_generator.py

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ def __init__(self, modalities, path, balanced=True):
         self.balanced = balanced
 
         for modality in modalities:
-            mod_path = f"{self.path}/{modality.name.lower()}"
+            mod_path = f"{self.path}/{modality.name.lower()}/"
             os.mkdir(mod_path)
             modality.file_path = mod_path
         self.labels = []
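
The trailing slash matters because read_data_from_file (above) builds per-sample paths by plain string concatenation, filepath + key + ext, not os.path.join:

# with the fix:    "data/text/" + "sample_1" + ".txt" -> "data/text/sample_1.txt"
# before the fix:  "data/text"  + "sample_1" + ".txt" -> "data/textsample_1.txt"
# (paths are illustrative)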
