Commit 2cba23b (parent: b242810)

Improve the BERT model's ability to handle long texts, addressing the 512-token limit

File tree: 4 files changed, +61 −9 lines

  .gitignore
  flask4modelcache.py
  modelcache/adapter/adapter.py
  modelcache/embedding/data2vec.py

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -147,3 +147,4 @@ dmypy.json
 /embedding_npy
 /flask_server
 *.bin
+*ini
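A note on the new ignore pattern: in gitignore glob syntax, `*ini` matches any file or directory whose name merely ends in "ini" (for example config.ini, but also a file named mini); if only .ini files are meant, `*.ini` would be the stricter spelling.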

flask4modelcache.py

Lines changed: 5 additions & 3 deletions
@@ -1,10 +1,9 @@
 # -*- coding: utf-8 -*-
-import json
+import time
+from datetime import datetime
 from flask import Flask, request
 import logging
-from datetime import datetime
 import configparser
-import time
 import json
 from modelcache import cache
 from modelcache.adapter import adapter
@@ -105,10 +104,12 @@ def user_backend():
 
     if request_type == 'query':
         try:
+            start_time = time.time()
             response = adapter.ChatCompletion.create_query(
                 scope={"model": model},
                 query=query
             )
+            delta_time = '{}s'.format(round(time.time() - start_time, 2))
             if response is None:
                 result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '',
                           "answer": ''}
@@ -120,6 +121,7 @@ def user_backend():
             hit_query = response_hitquery(response)
             result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time,
                       "hit_query": hit_query, "answer": answer}
+            delta_time_log = round(time.time() - start_time, 2)
             future = executor.submit(save_query_info, result, model, query, delta_time_log)
         except Exception as e:
             result = {"errorCode": 202, "errorDesc": e, "cacheHit": False, "delta_time": 0,

modelcache/adapter/adapter.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ def create_insert(cls, *args, **kwargs):
             logging.info('adapt_insert_e: {}'.format(e))
             return 'adapt_insert_exception'
 
+
     @classmethod
     def create_remove(cls, *args, **kwargs):
         try:

modelcache/embedding/data2vec.py

Lines changed: 54 additions & 6 deletions
@@ -28,18 +28,66 @@ def __init__(self, model: str = "sentence-transformers/all-MiniLM-L6-v2"):
         config = AutoConfig.from_pretrained(model)
         self.__dimension = config.hidden_size
 
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.tokenizer = BertTokenizer.from_pretrained(model, local_files_only=True)
         self.model = BertModel.from_pretrained(model, local_files_only=True)
 
     def to_embeddings(self, data, **_):
         encoded_input = self.tokenizer(data, padding=True, truncation=True, return_tensors='pt')
-        with torch.no_grad():
-            model_output = self.model(**encoded_input)
+        num_tokens = sum(map(len, encoded_input['input_ids']))
 
-        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
-        sentence_embeddings = sentence_embeddings.squeeze(0).detach().numpy()
-        embedding_array = np.array(sentence_embeddings).astype("float32")
-        return embedding_array
+        if num_tokens <= 512:
+            with torch.no_grad():
+                encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
+                model_output = self.model(**encoded_input)
+            sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+            sentence_embeddings = sentence_embeddings.squeeze(0).detach().cpu().numpy()
+            embedding_array = np.array(sentence_embeddings).astype("float32")
+            return embedding_array
+        else:
+            window_size = 510
+            start = 0
+            input_ids = encoded_input['input_ids']
+            input_ids = input_ids[:, 1:-1]
+            start_token = self.tokenizer.cls_token
+            end_token = self.tokenizer.sep_token
+            start_token_id = self.tokenizer.convert_tokens_to_ids(start_token)
+            end_token_id = self.tokenizer.convert_tokens_to_ids(end_token)
+            begin_element = torch.tensor([[start_token_id]])
+            end_element = torch.tensor([[end_token_id]])
+
+            embedding_array_list = list()
+            while start < num_tokens:
+                # Calculate the ending position of the sliding window.
+                end = start + window_size
+                # If the ending position exceeds the length, adjust it to the length.
+                if end > num_tokens:
+                    end = num_tokens
+                # Retrieve the data within the sliding window.
+                input_ids_window = input_ids[:, start:end]
+                # Insert a new element at position 0.
+                input_ids_window = torch.cat([begin_element, input_ids_window[:, 0:]], dim=1)
+                # Insert a new element at the last position.
+                input_ids_window = torch.cat([input_ids_window, end_element], dim=1)
+                input_ids_window_length = sum(map(len, input_ids_window))
+                token_type_ids = torch.tensor([[0] * input_ids_window_length])
+                attention_mask = torch.tensor([[1] * input_ids_window_length])
+
+                # Concatenate new input_ids
+                encoded_input_window = {'input_ids': input_ids_window, 'token_type_ids': token_type_ids,
+                                        'attention_mask': attention_mask}
+                with torch.no_grad():
+                    encoded_input_window = {k: v.to(self.device) for k, v in encoded_input_window.items()}
+                    model_output_window = self.model(**encoded_input_window)
+
+                sentence_embeddings_window = mean_pooling(model_output_window, encoded_input_window['attention_mask'])
+                sentence_embeddings_window = sentence_embeddings_window.squeeze(0).detach().cpu().numpy()
+                embedding_array_window = np.array(sentence_embeddings_window).astype("float32")
+                embedding_array_list.append(embedding_array_window)
+                start = end
+
+            embedding_array = np.mean(embedding_array_list, axis=0)
+            return embedding_array
 
     def post_proc(self, token_embeddings, inputs):
         attention_mask = inputs["attention_mask"]
