@@ -28,18 +28,66 @@ def __init__(self, model: str = "sentence-transformers/all-MiniLM-L6-v2"):
         config = AutoConfig.from_pretrained(model)
         self.__dimension = config.hidden_size
 
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.tokenizer = BertTokenizer.from_pretrained(model, local_files_only=True)
         self.model = BertModel.from_pretrained(model, local_files_only=True)
 
     def to_embeddings(self, data, **_):
         encoded_input = self.tokenizer(data, padding=True, truncation=True, return_tensors='pt')
-        with torch.no_grad():
-            model_output = self.model(**encoded_input)
+        num_tokens = sum(map(len, encoded_input['input_ids']))
 
-        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
-        sentence_embeddings = sentence_embeddings.squeeze(0).detach().numpy()
-        embedding_array = np.array(sentence_embeddings).astype("float32")
-        return embedding_array
+        if num_tokens <= 512:
+            with torch.no_grad():
+                encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
+                model_output = self.model(**encoded_input)
+            sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+            sentence_embeddings = sentence_embeddings.squeeze(0).detach().cpu().numpy()
+            embedding_array = np.array(sentence_embeddings).astype("float32")
+            return embedding_array
+        else:
+            window_size = 510
+            start = 0
+            input_ids = encoded_input['input_ids']
+            input_ids = input_ids[:, 1:-1]
+            start_token = self.tokenizer.cls_token
+            end_token = self.tokenizer.sep_token
+            start_token_id = self.tokenizer.convert_tokens_to_ids(start_token)
+            end_token_id = self.tokenizer.convert_tokens_to_ids(end_token)
+            begin_element = torch.tensor([[start_token_id]])
+            end_element = torch.tensor([[end_token_id]])
+
+            embedding_array_list = list()
+            while start < num_tokens:
+                # Calculate the ending position of the sliding window.
+                end = start + window_size
+                # If the ending position exceeds the sequence length, clamp it to the length.
+                if end > num_tokens:
+                    end = num_tokens
+                # Take the tokens covered by the current window.
+                input_ids_window = input_ids[:, start:end]
+                # Prepend the [CLS] token id to the window.
+                input_ids_window = torch.cat([begin_element, input_ids_window[:, 0:]], dim=1)
+                # Append the [SEP] token id to the window.
+                input_ids_window = torch.cat([input_ids_window, end_element], dim=1)
+                input_ids_window_length = sum(map(len, input_ids_window))
+                token_type_ids = torch.tensor([[0] * input_ids_window_length])
+                attention_mask = torch.tensor([[1] * input_ids_window_length])
+
+                # Assemble the model inputs for this window.
+                encoded_input_window = {'input_ids': input_ids_window, 'token_type_ids': token_type_ids,
+                                        'attention_mask': attention_mask}
+                with torch.no_grad():
+                    encoded_input_window = {k: v.to(self.device) for k, v in encoded_input_window.items()}
+                    model_output_window = self.model(**encoded_input_window)
+
+                sentence_embeddings_window = mean_pooling(model_output_window, encoded_input_window['attention_mask'])
+                sentence_embeddings_window = sentence_embeddings_window.squeeze(0).detach().cpu().numpy()
+                embedding_array_window = np.array(sentence_embeddings_window).astype("float32")
+                embedding_array_list.append(embedding_array_window)
+                start = end
+
+            embedding_array = np.mean(embedding_array_list, axis=0)
+            return embedding_array
 
     def post_proc(self, token_embeddings, inputs):
         attention_mask = inputs["attention_mask"]
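For reference, a minimal usage sketch of the changed `to_embeddings` path. It is not part of the commit: the wrapper class name (`HuggingfaceEmbedding` below) is hypothetical, and the sketch assumes the hunk above sits in a BERT-based embedding class whose constructor takes the `model` name shown in the hunk header and whose weights are already cached locally, since loading uses `local_files_only=True`.

```python
import numpy as np

# Hypothetical class name standing in for the class this diff modifies.
emb = HuggingfaceEmbedding(model="sentence-transformers/all-MiniLM-L6-v2")

short_text = "What is the capital of France?"
long_text = "background paragraph " * 1000  # meant to tokenize to far more than 512 tokens

short_vec = emb.to_embeddings(short_text)  # single forward pass (num_tokens <= 512)
long_vec = emb.to_embeddings(long_text)    # intended to hit the sliding-window branch:
                                           # 510-token windows re-wrapped in [CLS]/[SEP],
                                           # embedded separately, then averaged

# Both branches return a float32 vector with the model's hidden size
# (384 for all-MiniLM-L6-v2).
assert short_vec.dtype == np.float32
assert short_vec.shape == long_vec.shape
```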