forked from zilliztech/GPTCache
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path towhee.py
More file actions
26 lines (21 loc) · 839 Bytes
/
towhee.py
File metadata and controls
26 lines (21 loc) · 839 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from gptcache.util import import_towhee
import_towhee()
import numpy as np
from towhee.dc2 import pipe, ops
class Towhee:
    """Sentence-embedding generator backed by a towhee pipeline.

    The pipeline feeds input text through
    ``ops.sentence_embedding.transformers`` and then through
    ``ops.towhee.np_normalize`` (presumably L2 normalization — confirm
    against the towhee operator docs).

    Suggested models:
        english: paraphrase-albert-small-v2
        chinese: uer/albert-base-chinese-cluecorpussmall
    """

    def __init__(self, model="paraphrase-albert-small-v2"):
        """Build the embedding pipeline and record its output dimension.

        Args:
            model: model name handed to the towhee transformers
                sentence-embedding operator.

        Note:
            The dimension is found by running the pipeline once on the
            string "foo". Construction therefore executes the model a
            single time.
        """
        embed_op = ops.sentence_embedding.transformers(model_name=model)
        normalize_op = ops.towhee.np_normalize()
        self._pipe = (
            pipe.input('text')
            .map('text', 'vec', embed_op)
            .map('vec', 'vec', normalize_op)
            .output('text', 'vec')
        )
        # Probe with a throwaway input to learn the vector length.
        probe_row = self._pipe("foo").get_dict()
        self.__dimension = len(probe_row['vec'])

    def to_embeddings(self, data, **kwargs):
        """Embed ``data`` and return the vector as a float32 numpy array.

        Args:
            data: text passed straight to the pipeline.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            numpy.ndarray with dtype float32.
        """
        row = self._pipe(data).get_dict()
        as_array = np.array(row['vec'])
        return as_array.astype('float32')

    def dimension(self):
        """Return the embedding length measured during construction."""
        return self.__dimension