Skip to content

Commit f8463d5

Browse files
authored
add compute embeddings script (#75)
1 parent f9859cd commit f8463d5

File tree

3 files changed

+372
-0
lines changed

3 files changed

+372
-0
lines changed

embeddings/computeEmbeddings.py

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
# !pip install pydgraph pybars3 sentence_transformers mistralai openai
2+
import sys
3+
import json
4+
import os
5+
import re
6+
import pydgraph
7+
from openai import OpenAI
8+
from mistralai.client import MistralClient
9+
10+
from pybars import Compiler
11+
from sentence_transformers import SentenceTransformer
12+
13+
14+
# Example of embeddings.json
15+
# {
16+
# "embeddings" : [
17+
# {
18+
# "entityType":"Product",
19+
# "attribute":"product_embedding",
20+
# "index":"hnsw(metric: \"euclidean\")",
21+
# "provider": "huggingface",
22+
# "model":"sentence-transformers/all-MiniLM-L6-v2",
23+
# "config" : {
24+
# "dqlQuery" : "{ title:Product.title }",
25+
# "template": "{{title}} "
26+
# },
27+
# "disabled": false
28+
# }
29+
# ]
30+
# }
31+
#
32+
# provider : huggingface or openai or mistral
33+
# model : model name from the provider doing embeddings
34+
35+
# TODO
36+
# check code for template loops.
37+
38+
compiler = Compiler()
39+
global client # dgrpah client is a global variable
40+
41+
42+
# DGRAPH_CONNECTION must be defined in env variables
43+
# E.g dgraph://young-wind.us-east-1.aws.cloud.dgraph.io:443?sslmode=verify-ca&apikey=...
44+
assert "DGRAPH_CONNECTION" in os.environ, "DGRAPH_CONNECTION must be defined as a connection string"
45+
dgraph_cnx = os.environ["DGRAPH_CONNECTION"]
46+
47+
48+
# TRANSFORMER_API_KEY must be defined in env variables
49+
# client stub for on-prem requires grpc host:port without protocol
50+
# client stub for cloud requires the grpc endpoint of graphql endpoint or base url of the cluster
51+
# to run on a self-hosted env, unset ADMIN_KEY and set DGRAPH_GRPC
52+
53+
54+
def setClient():
55+
global client
56+
client = pydgraph.open(dgraph_cnx)
57+
58+
59+
60+
def clearIndex(predicate):
61+
print(f"remove index for {predicate}")
62+
schema = f"{predicate}: float32vector ."
63+
op = pydgraph.Operation(schema=schema)
64+
alter = client.alter(op)
65+
print(alter)
66+
67+
68+
def computeIndex(predicate, index):
69+
print(f"create index for {predicate} {index}")
70+
schema = f"{predicate}: float32vector @index({index}) ."
71+
op = pydgraph.Operation(schema=schema)
72+
alter = client.alter(op)
73+
print(alter)
74+
75+
76+
def huggingfaceEmbeddings(model, sentences):
77+
embeddings = model.encode(sentences)
78+
return embeddings.tolist()
79+
80+
81+
def computeEmbedding(
82+
predicate, data, template, provider, modelName, model, llm, dimensions
83+
):
84+
# data is an array of objects contaiing uid and other predicates
85+
# create an array of text
86+
# get the embeddings
87+
# produce a RDF text
88+
# data is a list of object having uid and other predicates used in the template
89+
90+
nquad_list = []
91+
sentences = [template(e) for e in data]
92+
93+
if "huggingface" == provider:
94+
embeddings = huggingfaceEmbeddings(model, sentences)
95+
elif "openai" == provider:
96+
if dimensions is not None:
97+
openaidata = llm.embeddings.create(
98+
input=sentences,
99+
model=modelName,
100+
encoding_format="float",
101+
dimensions=dimensions,
102+
)
103+
else:
104+
openaidata = llm.embeddings.create(
105+
input=sentences, model=modelName, encoding_format="float"
106+
)
107+
embeddings = [e.embedding for e in openaidata.data]
108+
elif "mistral" == provider:
109+
mistraldata = llm.embeddings(model=modelName, input=sentences)
110+
embeddings = [e.embedding for e in mistraldata.data]
111+
112+
# embeddings is a list of vectors in the same order as the input data
113+
try:
114+
for i in range(0, len(data)):
115+
uid = data[i]["uid"]
116+
nquad_list.append(f'<{uid}> <{predicate}> "{embeddings[i]}" .')
117+
# (prompt="{body[uid]}")
118+
except Exception:
119+
print(embeddings)
120+
return nquad_list
121+
122+
123+
def mutate_rdf(nquads, client):
124+
ret = {}
125+
body = "\n".join(nquads)
126+
if len(nquads) > 0:
127+
txn = client.txn()
128+
try:
129+
res = txn.mutate(set_nquads=body)
130+
txn.commit()
131+
ret["nquads"] = (len(nquads),)
132+
ret["total_ns"] = res.latency.total_ns
133+
except Exception as inst:
134+
print(inst)
135+
finally:
136+
txn.discard()
137+
return ret
138+
139+
140+
def buildEmbeddings(embedding_def, only_missing=True, filehandle=sys.stdout):
141+
global client
142+
entity = embedding_def["entityType"]
143+
config = embedding_def["config"]
144+
provider = embedding_def["provider"]
145+
modelName = embedding_def["model"]
146+
dimensions = embedding_def["dimensions"] if "dimensions" in embedding_def else None
147+
index = embedding_def["index"]
148+
149+
if "huggingface" == provider:
150+
model = SentenceTransformer(modelName)
151+
llmclient = None
152+
else:
153+
model = None
154+
if "openai" == provider:
155+
llmclient = OpenAI(
156+
# This is the default and can be omitted
157+
api_key=os.environ.get("OPENAI_API_KEY"),
158+
)
159+
elif "mistral" == provider:
160+
assert "MISTRAL_API_KEY" in os.environ, "MISTRAL_API_KEY must be defined"
161+
llmclient = MistralClient(api_key=os.environ.get("MISTRAL_API_KEY"))
162+
163+
predicate = f"{embedding_def['entityType']}.{embedding_def['attribute']}"
164+
165+
total = 0
166+
167+
template = compiler.compile(config["template"])
168+
# inject uid in the query
169+
# querypart = re.sub(r'([a-zA-Z_]+)',rf"\1:{entity}.\1",config['query'])
170+
querypart = config["dqlQuery"]
171+
querypart = querypart.replace("{", "{ uid ", 1)
172+
print(querypart)
173+
# remove index by updating DQL schema
174+
clearIndex(predicate)
175+
print(
176+
f"compute embeddings for {predicate} using model {modelName} from {provider}"
177+
)
178+
if only_missing:
179+
filter = f"@filter( NOT has({predicate}))"
180+
else:
181+
filter = ""
182+
# Run query.
183+
after = ""
184+
while True:
185+
print(".")
186+
txn = client.txn(read_only=True)
187+
query = (
188+
f"{{list(func: type({entity}),first:100 {after}) {filter} {querypart} }}"
189+
)
190+
try:
191+
res = txn.query(query)
192+
data = json.loads(res.json)
193+
except Exception as inst:
194+
print(type(inst)) # the exception type
195+
print(inst.args) # arguments stored in .args
196+
print(inst)
197+
break
198+
finally:
199+
txn.discard()
200+
201+
if len(data["list"]) > 0:
202+
last_uid = data["list"][-1]["uid"]
203+
after = f",after:{last_uid}"
204+
else:
205+
break
206+
207+
nquads = computeEmbedding(
208+
predicate,
209+
data["list"],
210+
template,
211+
provider,
212+
modelName,
213+
model,
214+
llmclient,
215+
dimensions,
216+
)
217+
if filehandle is None:
218+
mutate_rdf(nquads, client)
219+
else:
220+
filehandle.write("\n".join(nquads))
221+
total += len(data["list"])
222+
223+
computeIndex(predicate, index)
224+
return total
225+
226+
227+
def replace_env(matchobj):
228+
key = matchobj.group(1)
229+
assert key in os.environ, (
230+
"config file is using a key not defined as environment variable: " + key
231+
)
232+
return os.environ.get(key)
233+
234+
235+
236+
print("using connection string ")
237+
print(dgraph_cnx)
238+
if len(sys.argv) == 2:
239+
outputfile = sys.argv[1]
240+
print(f"Produce RDF file in {outputfile}")
241+
else:
242+
outputfile = None
243+
print("Mutate embeddings in cluster.")
244+
245+
confirm = input("Continue (y/n)?")
246+
247+
if confirm == "y":
248+
q = input("Generate only missing embedding (y/n)?")
249+
only_missing = q == "y"
250+
251+
re_env = re.compile(r"{{env.(\w*)}}")
252+
setClient()
253+
254+
with open("./embeddings.json") as f:
255+
data = f.read()
256+
raw = re_env.sub(replace_env, data)
257+
embeddings = json.loads(raw)
258+
259+
definitions = embeddings["embeddings"]
260+
261+
if outputfile is not None:
262+
out = open(outputfile, "w")
263+
else:
264+
out = None
265+
for embedding_def in definitions:
266+
total = buildEmbeddings(embedding_def, only_missing, out)
267+
print(
268+
f"{total} embeddings for {embedding_def['entityType']}.{embedding_def['attribute']}"
269+
)
270+
if out is not None:
271+
out.close()

embeddings/embeddings.json

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
{
2+
"embeddings" : [
3+
{
4+
"entityType":"Product",
5+
"attribute":"embedding",
6+
"index":"hnsw(metric: \"euclidean\")",
7+
"provider":"huggingface",
8+
"model":"sentence-transformers/all-MiniLM-L6-v2",
9+
"config" : {
10+
"dqlQuery" : "{ title:Product.title }",
11+
"template": "{{title}} "
12+
},
13+
"disabled": false
14+
},
15+
{
16+
"entityType":"Service",
17+
"attribute":"embedding",
18+
"index":"hnsw(metric: \"euclidean\")",
19+
"provider":"huggingface",
20+
"model":"sentence-transformers/all-MiniLM-L6-v2",
21+
"config" : {
22+
"dqlQuery" : "{ title:Service.title }",
23+
"template": "{{title}} "
24+
},
25+
"disabled": false
26+
}
27+
]
28+
}

embeddings/reset-index.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# !pip install pydgraph python_graphql_client
2+
import sys
3+
import json
4+
import os
5+
import re
6+
import pydgraph
7+
8+
9+
# reset the index all embedding predicates or of one provided predicate name
10+
# Should be replaced by Deploying the GraphQL Schema without indexes and then deploying with the Indexes.
11+
#
12+
global client # dgrpah client is a global variable
13+
14+
assert "DGRAPH_GRPC" in os.environ, "DGRAPH_GRPC must be defined"
15+
dgraph_grpc = os.environ["DGRAPH_GRPC"]
16+
if "cloud.dgraph" in dgraph_grpc:
17+
assert "DGRAPH_ADMIN_KEY" in os.environ, "DGRAPH_ADMIN_KEY must be defined"
18+
APIAdminKey = os.environ["DGRAPH_ADMIN_KEY"]
19+
else:
20+
APIAdminKey = None
21+
22+
# TRANSFORMER_API_KEY must be defined in env variables
23+
# client stub for on-prem requires grpc host:port without protocol
24+
# client stub for cloud requires the grpc endpoint of graphql endpoint or base url of the cluster
25+
# to run on a self-hosted env, unset ADMIN_KEY and set DGRAPH_GRPC
26+
27+
def setClient():
28+
global client
29+
if APIAdminKey is None:
30+
client_stub = pydgraph.DgraphClientStub(dgraph_grpc)
31+
else:
32+
client_stub = pydgraph.DgraphClientStub.from_cloud(dgraph_grpc,APIAdminKey )
33+
client = pydgraph.DgraphClient(client_stub)
34+
35+
def clearIndex(predicate):
36+
print(f"remove index for {predicate}")
37+
schema = f"{predicate}: float32vector ."
38+
op = pydgraph.Operation(schema=schema)
39+
alter = client.alter(op)
40+
print(alter)
41+
def computeIndex(predicate,index):
42+
print(f"create index for {predicate} {index}")
43+
schema = f"{predicate}: float32vector @index({index}) ."
44+
op = pydgraph.Operation(schema=schema)
45+
alter = client.alter(op)
46+
print(alter)
47+
48+
49+
50+
if len(sys.argv) == 3:
51+
requested_predicate = sys.argv[2]
52+
print(f"Reindexing {requested_predicate} in {dgraph_grpc}")
53+
else:
54+
requested_predicate = None
55+
print(f"Reindexing all embeddings predicates in {dgraph_grpc}")
56+
57+
58+
confirm = input("Continue (y/n)?")
59+
60+
61+
62+
if confirm == "y":
63+
setClient()
64+
65+
with open("./embeddings.json") as f:
66+
data = f.read()
67+
hm_config = json.loads(data)
68+
for embedding_def in hm_config['embeddings']:
69+
predicate = f"{embedding_def['entityType']}.{embedding_def['attribute']}"
70+
if requested_predicate == None or requested_predicate == predicate:
71+
index = embedding_def['index']
72+
clearIndex(predicate)
73+
computeIndex(predicate,index)

0 commit comments

Comments
 (0)