-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathtest_transformers_online_macos.py
More file actions
114 lines (94 loc) · 4.28 KB
/
test_transformers_online_macos.py
File metadata and controls
114 lines (94 loc) · 4.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer
class LLMEmbeddingModel:
    """Embed queries and passages with a ``transformers`` ``AutoModel``.

    Texts are tokenized, run through the model, mean-pooled over the
    non-instruction tokens and L2-normalized. Query texts are prefixed with
    an instruction string whose tokens are excluded from the pooling.
    """

    def __init__(self,
                 model_name_or_path,
                 batch_size=128,
                 max_length=1024,
                 gpu_id=0,
                 query_instruction="Given a search query, retrieve passages that answer the question"):
        """Load the model and tokenizer and pick a compute device.

        Args:
            model_name_or_path: Identifier forwarded to ``from_pretrained``.
            batch_size: Maximum number of texts per forward pass in
                ``encode_queries`` / ``encode_passages``.
            max_length: Truncation length forwarded to the tokenizer.
            gpu_id: CUDA device index; only used when CUDA is available.
            query_instruction: Task description placed in front of every
                query. A falsy value falls back to the bare ``"Query:"``
                prefix.
        """
        self.model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, padding_side="right", trust_remote_code=True
        )

        # macOS-friendly device selection: CUDA -> MPS -> CPU.
        self.device = self._select_device(gpu_id)
        self.model.to(self.device).eval()

        self.max_length = max_length
        self.batch_size = batch_size

        if query_instruction:
            self.query_instruction = f"Instruction: {query_instruction} \nQuery:"
        else:
            self.query_instruction = "Query:"
        self.doc_instruction = ""

        print(f"query instruction: {[self.query_instruction]}\ndoc instruction: {[self.doc_instruction]}")
        print(f"Using device: {self.device}")

    @staticmethod
    def _select_device(gpu_id):
        """Return the preferred torch device: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device(f"cuda:{gpu_id}")
        # Guard the attribute lookup: torch builds without an MPS backend
        # module would otherwise raise AttributeError here.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def mean_pooling(self, hidden_state, attention_mask):
        """Average ``hidden_state`` along dim 1 over positions where the mask is 1.

        Args:
            hidden_state: Tensor of per-token vectors; assumed to be
                (batch, seq, dim) -- TODO confirm against the model output.
            attention_mask: (batch, seq) tensor of 0/1 weights.

        Returns:
            A (batch, dim) tensor. Rows whose mask is entirely zero yield a
            zero vector rather than NaN.
        """
        weights = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(hidden_state * weights, dim=1)
        counts = attention_mask.sum(dim=1, keepdim=True).float()
        # Clamp so a fully masked row divides 0 by a tiny number, not by 0.
        return summed / torch.clamp(counts, min=1e-9)

    @torch.no_grad()
    def encode(self, sentences_batch, instruction):
        """Embed one batch of texts, ignoring instruction tokens when pooling.

        Args:
            sentences_batch: Text(s) to embed, already carrying their
                instruction prefix.
            instruction: Either one string shared by the whole batch, or a
                sequence with one instruction string per sentence.

        Returns:
            L2-normalized embeddings, one row per input text.
        """
        inputs = self.tokenizer(
            sentences_batch,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=self.max_length,
            add_special_tokens=True,
        )
        # Move the inputs to the target device.
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

        # The method-level @torch.no_grad() already disables autograd here.
        outputs = self.model(**inputs)
        last_hidden_state = outputs[0]

        instruction_tokens = self.tokenizer(
            instruction,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=True,
        )["input_ids"]

        # The mask is edited only after the forward pass, so the model still
        # attends to the instruction; it is merely left out of the average.
        pooling_mask = inputs["attention_mask"]
        if isinstance(instruction, str):
            pooling_mask[:, :len(instruction_tokens)] = 0
        else:
            # Decide on the argument type rather than on the shape of
            # np.array(tokens): per-sentence token lists are usually ragged,
            # which NumPy cannot turn into a regular 2-D array.
            assert len(instruction) == len(sentences_batch)
            for row, tokens in enumerate(instruction_tokens):
                pooling_mask[row, :len(tokens)] = 0

        embeddings = self.mean_pooling(last_hidden_state, pooling_mask)
        return torch.nn.functional.normalize(embeddings, dim=-1)

    def _encode_in_batches(self, texts, instruction):
        """Encode ``texts`` in chunks of at most ``batch_size`` and concatenate."""
        step = self.batch_size
        # Inputs that fit in one batch (or an unusable batch size) take the
        # single-call path unchanged.
        if not step or step < 1 or len(texts) <= step:
            return self.encode(texts, instruction)
        chunks = [
            self.encode(texts[start:start + step], instruction)
            for start in range(0, len(texts), step)
        ]
        return torch.cat(chunks, dim=0)

    def encode_queries(self, queries):
        """Embed a query string or list of queries with the query instruction prefix."""
        queries = queries if isinstance(queries, list) else [queries]
        queries = [f"{self.query_instruction}{query}" for query in queries]
        return self._encode_in_batches(queries, self.query_instruction)

    def encode_passages(self, passages):
        """Embed a passage string or list of passages with the document instruction prefix."""
        passages = passages if isinstance(passages, list) else [passages]
        passages = [f"{self.doc_instruction}{passage}" for passage in passages]
        return self._encode_in_batches(passages, self.doc_instruction)

    def compute_similarity_for_vectors(self, q_reps, p_reps):
        """Return ``q_reps @ p_reps^T`` over the last two dims of ``p_reps``.

        For 2-D ``p_reps`` swapping dims (-2, -1) is the same as swapping
        dims (0, 1), so a single expression covers both 2-D and batched input.
        """
        return torch.matmul(q_reps, p_reps.transpose(-2, -1))

    def compute_similarity(self, queries, passages):
        """Embed queries and passages and return their score matrix as nested lists.

        Because both sets of embeddings are L2-normalized by ``encode``, the
        dot products are cosine similarities.
        """
        q_reps = self.encode_queries(queries)
        p_reps = self.encode_passages(passages)
        scores = self.compute_similarity_for_vectors(q_reps, p_reps)
        return scores.detach().cpu().tolist()
# Demo: score a single weather query against three candidate passages.
# Constructing the model loads the pretrained weights for the given name.
model_name_or_path = "tencent/Youtu-Embedding"

queries = ["What's the weather like?"]
passages = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium.",
]

model = LLMEmbeddingModel(model_name_or_path)
# One row per query, one column per passage.
scores = model.compute_similarity(queries=queries, passages=passages)
print(f"scores: {scores}")