-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathget_bert_embeddings.py
More file actions
54 lines (44 loc) · 1.9 KB
/
get_bert_embeddings.py
File metadata and controls
54 lines (44 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# usage:
#
# dataset = QADataset(tokenizer=tokenizer,
# paragraph_tokens=par_tokens,
# question_tokens=que_tokens,
# answer_spans=answer_spans,
# word2index=word2index)
# embeddings = embed_data(dataset.x_data)
import torch
from pytorch_pretrained_bert import BertTokenizer, BertConfig, BertForMaskedLM
from tqdm.auto import tqdm
def get_embeddings(x_data_instance: list, model):
    """Return contextual embeddings for one tokenized instance.

    Each token's vector is the concatenation of its hidden states from
    the top four encoder layers, so the result has shape
    (num_tokens, 4 * hidden_size).

    :param x_data_instance: list of vocabulary indices for one sequence
    :param model: a BertForMaskedLM whose ``.bert`` encoder returns all
        hidden layers
    :return: tensor of shape (len(x_data_instance), 4 * hidden_size)
    """
    token_ids = torch.tensor([x_data_instance])
    # Single-segment input: every position carries segment id 1,
    # matching how the training data was prepared upstream.
    segment_ids = torch.ones_like(token_ids)

    model.eval()  # disable dropout for deterministic outputs
    with torch.no_grad():
        hidden_layers, _ = model.bert(token_ids, segment_ids)

    # (num_layers, 1, seq_len, hidden) -> (num_layers, seq_len, hidden);
    # the batch dimension is always 1 here.
    stacked = torch.stack(hidden_layers, dim=0).squeeze(dim=1)

    # Concatenate the top four layers for all tokens at once; this is
    # the vectorized equivalent of per-token
    # cat(token[-1], token[-2], token[-3], token[-4]).
    return torch.cat(
        (stacked[-1], stacked[-2], stacked[-3], stacked[-4]),
        dim=1,
    )
def embed_data(x_data):
    """Embed every instance in ``x_data`` with the fine-tuned BERT LM.

    Loads the language-model weights from the local ``lm/`` directory,
    runs :func:`get_embeddings` on each instance, and stacks the results
    into a single tensor of shape
    (num_instances, seq_len, 4 * hidden_size).

    NOTE(review): ``torch.stack`` requires every instance to yield the
    same number of tokens; ragged inputs will raise — confirm that the
    caller pads sequences to a common length.

    :param x_data: iterable of token-index lists (one per instance)
    :return: stacked embedding tensor for all instances
    """
    # map_location lets a checkpoint saved on GPU load on a CPU-only
    # machine (and vice versa).
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    output_model_file = "lm/pytorch_model.bin"
    output_config_file = "lm/config.json"
    # Fix: the original built a BertTokenizer here and never used it —
    # dead code that also performed needless disk I/O. Removed.
    config = BertConfig.from_json_file(output_config_file)
    model = BertForMaskedLM(config)
    state_dict = torch.load(output_model_file, map_location=device)
    model.load_state_dict(state_dict)
    entries = []
    for entry in tqdm(x_data, desc='Loading embeddings'):
        entries.append(get_embeddings(entry, model))
    return torch.stack(entries)