|
| 1 | +# !pip install pydgraph pybars3 sentence_transformers mistralai openai |
| 2 | +import sys |
| 3 | +import json |
| 4 | +import os |
| 5 | +import re |
| 6 | +import pydgraph |
| 7 | +from openai import OpenAI |
| 8 | +from mistralai.client import MistralClient |
| 9 | + |
| 10 | +from pybars import Compiler |
| 11 | +from sentence_transformers import SentenceTransformer |
| 12 | + |
| 13 | + |
| 14 | +# Example of embeddings.json |
| 15 | +# { |
| 16 | +# "embeddings" : [ |
| 17 | +# { |
| 18 | +# "entityType":"Product", |
| 19 | +# "attribute":"product_embedding", |
| 20 | +# "index":"hnsw(metric: \"euclidean\")", |
| 21 | +# "provider": "huggingface", |
| 22 | +# "model":"sentence-transformers/all-MiniLM-L6-v2", |
| 23 | +# "config" : { |
| 24 | +# "dqlQuery" : "{ title:Product.title }", |
| 25 | +# "template": "{{title}} " |
| 26 | +# }, |
| 27 | +# "disabled": false |
| 28 | +# } |
| 29 | +# ] |
| 30 | +# } |
| 31 | +# |
| 32 | +# provider : huggingface or openai or mistral |
| 33 | +# model : model name from the provider doing embeddings |
| 34 | + |
| 35 | +# TODO |
| 36 | +# check code for template loops. |
| 37 | + |
| 38 | +compiler = Compiler() |
| 39 | +global client # dgrpah client is a global variable |
| 40 | + |
| 41 | + |
| 42 | +# DGRAPH_CONNECTION must be defined in env variables |
| 43 | +# E.g dgraph://young-wind.us-east-1.aws.cloud.dgraph.io:443?sslmode=verify-ca&apikey=... |
| 44 | +assert "DGRAPH_CONNECTION" in os.environ, "DGRAPH_CONNECTION must be defined as a connection string" |
| 45 | +dgraph_cnx = os.environ["DGRAPH_CONNECTION"] |
| 46 | + |
| 47 | + |
| 48 | +# TRANSFORMER_API_KEY must be defined in env variables |
| 49 | +# client stub for on-prem requires grpc host:port without protocol |
| 50 | +# client stub for cloud requires the grpc endpoint of graphql endpoint or base url of the cluster |
| 51 | +# to run on a self-hosted env, unset ADMIN_KEY and set DGRAPH_GRPC |
| 52 | + |
| 53 | + |
| 54 | +def setClient(): |
| 55 | + global client |
| 56 | + client = pydgraph.open(dgraph_cnx) |
| 57 | + |
| 58 | + |
| 59 | + |
| 60 | +def clearIndex(predicate): |
| 61 | + print(f"remove index for {predicate}") |
| 62 | + schema = f"{predicate}: float32vector ." |
| 63 | + op = pydgraph.Operation(schema=schema) |
| 64 | + alter = client.alter(op) |
| 65 | + print(alter) |
| 66 | + |
| 67 | + |
| 68 | +def computeIndex(predicate, index): |
| 69 | + print(f"create index for {predicate} {index}") |
| 70 | + schema = f"{predicate}: float32vector @index({index}) ." |
| 71 | + op = pydgraph.Operation(schema=schema) |
| 72 | + alter = client.alter(op) |
| 73 | + print(alter) |
| 74 | + |
| 75 | + |
| 76 | +def huggingfaceEmbeddings(model, sentences): |
| 77 | + embeddings = model.encode(sentences) |
| 78 | + return embeddings.tolist() |
| 79 | + |
| 80 | + |
| 81 | +def computeEmbedding( |
| 82 | + predicate, data, template, provider, modelName, model, llm, dimensions |
| 83 | +): |
| 84 | + # data is an array of objects contaiing uid and other predicates |
| 85 | + # create an array of text |
| 86 | + # get the embeddings |
| 87 | + # produce a RDF text |
| 88 | + # data is a list of object having uid and other predicates used in the template |
| 89 | + |
| 90 | + nquad_list = [] |
| 91 | + sentences = [template(e) for e in data] |
| 92 | + |
| 93 | + if "huggingface" == provider: |
| 94 | + embeddings = huggingfaceEmbeddings(model, sentences) |
| 95 | + elif "openai" == provider: |
| 96 | + if dimensions is not None: |
| 97 | + openaidata = llm.embeddings.create( |
| 98 | + input=sentences, |
| 99 | + model=modelName, |
| 100 | + encoding_format="float", |
| 101 | + dimensions=dimensions, |
| 102 | + ) |
| 103 | + else: |
| 104 | + openaidata = llm.embeddings.create( |
| 105 | + input=sentences, model=modelName, encoding_format="float" |
| 106 | + ) |
| 107 | + embeddings = [e.embedding for e in openaidata.data] |
| 108 | + elif "mistral" == provider: |
| 109 | + mistraldata = llm.embeddings(model=modelName, input=sentences) |
| 110 | + embeddings = [e.embedding for e in mistraldata.data] |
| 111 | + |
| 112 | + # embeddings is a list of vectors in the same order as the input data |
| 113 | + try: |
| 114 | + for i in range(0, len(data)): |
| 115 | + uid = data[i]["uid"] |
| 116 | + nquad_list.append(f'<{uid}> <{predicate}> "{embeddings[i]}" .') |
| 117 | + # (prompt="{body[uid]}") |
| 118 | + except Exception: |
| 119 | + print(embeddings) |
| 120 | + return nquad_list |
| 121 | + |
| 122 | + |
| 123 | +def mutate_rdf(nquads, client): |
| 124 | + ret = {} |
| 125 | + body = "\n".join(nquads) |
| 126 | + if len(nquads) > 0: |
| 127 | + txn = client.txn() |
| 128 | + try: |
| 129 | + res = txn.mutate(set_nquads=body) |
| 130 | + txn.commit() |
| 131 | + ret["nquads"] = (len(nquads),) |
| 132 | + ret["total_ns"] = res.latency.total_ns |
| 133 | + except Exception as inst: |
| 134 | + print(inst) |
| 135 | + finally: |
| 136 | + txn.discard() |
| 137 | + return ret |
| 138 | + |
| 139 | + |
| 140 | +def buildEmbeddings(embedding_def, only_missing=True, filehandle=sys.stdout): |
| 141 | + global client |
| 142 | + entity = embedding_def["entityType"] |
| 143 | + config = embedding_def["config"] |
| 144 | + provider = embedding_def["provider"] |
| 145 | + modelName = embedding_def["model"] |
| 146 | + dimensions = embedding_def["dimensions"] if "dimensions" in embedding_def else None |
| 147 | + index = embedding_def["index"] |
| 148 | + |
| 149 | + if "huggingface" == provider: |
| 150 | + model = SentenceTransformer(modelName) |
| 151 | + llmclient = None |
| 152 | + else: |
| 153 | + model = None |
| 154 | + if "openai" == provider: |
| 155 | + llmclient = OpenAI( |
| 156 | + # This is the default and can be omitted |
| 157 | + api_key=os.environ.get("OPENAI_API_KEY"), |
| 158 | + ) |
| 159 | + elif "mistral" == provider: |
| 160 | + assert "MISTRAL_API_KEY" in os.environ, "MISTRAL_API_KEY must be defined" |
| 161 | + llmclient = MistralClient(api_key=os.environ.get("MISTRAL_API_KEY")) |
| 162 | + |
| 163 | + predicate = f"{embedding_def['entityType']}.{embedding_def['attribute']}" |
| 164 | + |
| 165 | + total = 0 |
| 166 | + |
| 167 | + template = compiler.compile(config["template"]) |
| 168 | + # inject uid in the query |
| 169 | + # querypart = re.sub(r'([a-zA-Z_]+)',rf"\1:{entity}.\1",config['query']) |
| 170 | + querypart = config["dqlQuery"] |
| 171 | + querypart = querypart.replace("{", "{ uid ", 1) |
| 172 | + print(querypart) |
| 173 | + # remove index by updating DQL schema |
| 174 | + clearIndex(predicate) |
| 175 | + print( |
| 176 | + f"compute embeddings for {predicate} using model {modelName} from {provider}" |
| 177 | + ) |
| 178 | + if only_missing: |
| 179 | + filter = f"@filter( NOT has({predicate}))" |
| 180 | + else: |
| 181 | + filter = "" |
| 182 | + # Run query. |
| 183 | + after = "" |
| 184 | + while True: |
| 185 | + print(".") |
| 186 | + txn = client.txn(read_only=True) |
| 187 | + query = ( |
| 188 | + f"{{list(func: type({entity}),first:100 {after}) {filter} {querypart} }}" |
| 189 | + ) |
| 190 | + try: |
| 191 | + res = txn.query(query) |
| 192 | + data = json.loads(res.json) |
| 193 | + except Exception as inst: |
| 194 | + print(type(inst)) # the exception type |
| 195 | + print(inst.args) # arguments stored in .args |
| 196 | + print(inst) |
| 197 | + break |
| 198 | + finally: |
| 199 | + txn.discard() |
| 200 | + |
| 201 | + if len(data["list"]) > 0: |
| 202 | + last_uid = data["list"][-1]["uid"] |
| 203 | + after = f",after:{last_uid}" |
| 204 | + else: |
| 205 | + break |
| 206 | + |
| 207 | + nquads = computeEmbedding( |
| 208 | + predicate, |
| 209 | + data["list"], |
| 210 | + template, |
| 211 | + provider, |
| 212 | + modelName, |
| 213 | + model, |
| 214 | + llmclient, |
| 215 | + dimensions, |
| 216 | + ) |
| 217 | + if filehandle is None: |
| 218 | + mutate_rdf(nquads, client) |
| 219 | + else: |
| 220 | + filehandle.write("\n".join(nquads)) |
| 221 | + total += len(data["list"]) |
| 222 | + |
| 223 | + computeIndex(predicate, index) |
| 224 | + return total |
| 225 | + |
| 226 | + |
| 227 | +def replace_env(matchobj): |
| 228 | + key = matchobj.group(1) |
| 229 | + assert key in os.environ, ( |
| 230 | + "config file is using a key not defined as environment variable: " + key |
| 231 | + ) |
| 232 | + return os.environ.get(key) |
| 233 | + |
| 234 | + |
| 235 | + |
| 236 | +print("using connection string ") |
| 237 | +print(dgraph_cnx) |
| 238 | +if len(sys.argv) == 2: |
| 239 | + outputfile = sys.argv[1] |
| 240 | + print(f"Produce RDF file in {outputfile}") |
| 241 | +else: |
| 242 | + outputfile = None |
| 243 | + print("Mutate embeddings in cluster.") |
| 244 | + |
| 245 | +confirm = input("Continue (y/n)?") |
| 246 | + |
| 247 | +if confirm == "y": |
| 248 | + q = input("Generate only missing embedding (y/n)?") |
| 249 | + only_missing = q == "y" |
| 250 | + |
| 251 | + re_env = re.compile(r"{{env.(\w*)}}") |
| 252 | + setClient() |
| 253 | + |
| 254 | + with open("./embeddings.json") as f: |
| 255 | + data = f.read() |
| 256 | + raw = re_env.sub(replace_env, data) |
| 257 | + embeddings = json.loads(raw) |
| 258 | + |
| 259 | + definitions = embeddings["embeddings"] |
| 260 | + |
| 261 | + if outputfile is not None: |
| 262 | + out = open(outputfile, "w") |
| 263 | + else: |
| 264 | + out = None |
| 265 | + for embedding_def in definitions: |
| 266 | + total = buildEmbeddings(embedding_def, only_missing, out) |
| 267 | + print( |
| 268 | + f"{total} embeddings for {embedding_def['entityType']}.{embedding_def['attribute']}" |
| 269 | + ) |
| 270 | + if out is not None: |
| 271 | + out.close() |
0 commit comments