-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscibert_embed.py
More file actions
26 lines (22 loc) · 1.34 KB
/
scibert_embed.py
File metadata and controls
26 lines (22 loc) · 1.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from transformers import AutoTokenizer, AutoModel
# Load the SciBERT tokenizer and encoder once at import time.
# output_hidden_states=True is required so scibert_embeddings() can pool
# from an intermediate layer rather than only the final one.
_SCIBERT_NAME = 'allenai/scibert_scivocab_uncased'
scibert_tokenizer = AutoTokenizer.from_pretrained(_SCIBERT_NAME)
scibert_model = AutoModel.from_pretrained(_SCIBERT_NAME, output_hidden_states=True)
scibert_model.eval()  # inference mode: disables dropout for deterministic embeddings
def scibert_embeddings(text):
    """Compute a single SciBERT embedding for the first string in *text*.

    Parameters
    ----------
    text : list[str]
        A batch of input strings. NOTE(review): only the first entry is
        actually embedded — the `[0]` index below discards the rest of the
        batch. Callers appear to pass single-item lists; verify.

    Returns
    -------
    torch.Tensor
        1-D tensor of size hidden_dim (768 for SciBERT): the mean of the
        second-to-last layer's token embeddings, excluding the [CLS] and
        [SEP] special tokens.
    """
    # Fix: the original used truncation=False, which crashes with an
    # index error for inputs longer than the model's 512-token positional
    # limit. truncation=True caps sequences at the tokenizer's model max
    # length; behavior for shorter inputs is unchanged.
    encoding = scibert_tokenizer.batch_encode_plus(
        text,
        padding=True,
        truncation=True,
        return_tensors="pt",
        add_special_tokens=True,
    )
    input_ids = encoding['input_ids']
    attention_mask = encoding["attention_mask"]
    with torch.no_grad():
        outputs = scibert_model(input_ids, attention_mask=attention_mask)
    # Use the named attribute instead of positional outputs[2]: the tuple
    # layout varies with model config / transformers version, while
    # .hidden_states is stable given output_hidden_states=True.
    hidden_states = outputs.hidden_states
    # Stack the per-layer tuple into (num_layers, batch, seq_len, hidden),
    # then take the second-to-last layer and the first (only) sentence.
    token_embeddings = torch.stack(hidden_states, dim=0)[-2][0]
    # Drop [CLS] (first) and [SEP] (last) before pooling.
    # NOTE(review): with a multi-item batch, padding means the first row's
    # last position may be [PAD] rather than [SEP]; this matches the
    # original single-sentence assumption — confirm callers pass one string.
    token_embeddings = token_embeddings[1:-1, :]
    # Mean-pool the remaining token vectors into one keyphrase embedding.
    # (An empty string would leave zero tokens here and yield NaNs, as in
    # the original.)
    keyphrase_embedding = torch.mean(token_embeddings, dim=0)
    return keyphrase_embedding