
Commit 98ab818

Option to edit records (#60)

* Re-embed endpoint
* Fixes token-level embedding calculation

Co-authored-by: FelixKirschKern <[email protected]>

1 parent ab1b4b5 commit 98ab818

File tree: app.py, controller.py, data/data_type.py

3 files changed: +124 / -12 lines

3 files changed

+124
-12
lines changed

app.py

Lines changed: 10 additions & 0 deletions
@@ -139,6 +139,16 @@ def upload_tensor_data(
     return responses.PlainTextResponse(status_code=status.HTTP_200_OK)


+@app.post("/re_embed_records/{project_id}")
+def re_embed_record(
+    project_id: str, request: data_type.EmbeddingRebuildRequest
+) -> responses.PlainTextResponse:
+    session_token = general.get_ctx_token()
+    controller.re_embed_records(project_id, request.changes)
+    general.remove_and_refresh_session(session_token)
+    return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
+
+
 @app.put("/config_changed")
 def config_changed() -> responses.PlainTextResponse:
     config_handler.refresh_config()
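
As a usage sketch (not part of the commit): this is how a client might call the new endpoint. The base URL, project id, embedding id, record id, and attribute name below are hypothetical placeholders; the request body mirrors the EmbeddingRebuildRequest structure documented in data/data_type.py further down.

import requests

# assumption: the service is reachable at this address; adjust to your deployment
BASE_URL = "http://localhost:8000"
PROJECT_ID = "<project_id>"  # placeholder

# one list of changes per embedding id; sub_key is only needed for embedding
# lists and is sent as a string (the server converts it back to an int)
payload = {
    "changes": {
        "<embedding_id>": [
            {"record_id": "<record_id>", "attribute_name": "<attribute_name>"},
            {"record_id": "<record_id>", "attribute_name": "<attribute_name>", "sub_key": "0"},
        ]
    }
}

response = requests.post(f"{BASE_URL}/re_embed_records/{PROJECT_ID}", json=payload)
print(response.status_code)  # 200 if the re-embedding ran through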

controller.py

Lines changed: 105 additions & 10 deletions
@@ -22,7 +22,7 @@
 from spacy.vocab import Vocab
 from data import data_type, doc_ock
 from embedders import Transformer
-from typing import Any, Dict, Iterator, List, Optional
+from typing import Any, Dict, Iterator, List, Optional, Union

 from util import daemon, request_util
 from util.config_handler import get_config_value
@@ -33,6 +33,8 @@
 import pandas as pd
 from submodules.s3 import controller as s3
 import openai
+import gc
+

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -77,7 +79,7 @@ def get_docbins(
         docs = list(doc_bin_loaded.get_docs(vocab))
         for col, doc in zip(record_item.columns, docs):
             if col == attribute_name:
-                result[record_item.record_id] = doc
+                result[str(record_item.record_id)] = doc
     result_list = []
     for record_id in record_ids_batch:
         result_list.append(result[record_id])
@@ -409,14 +411,14 @@ def run_encoding(
     if embedding_type == enums.EmbeddingType.ON_ATTRIBUTE.value:
         request_util.post_embedding_to_neural_search(project_id, embedding_id)

-    if get_config_value("is_managed"):
-        pickle_path = os.path.join(
-            "/inference", project_id, f"embedder-{embedding_id}.pkl"
-        )
-        if not os.path.exists(pickle_path):
-            os.makedirs(os.path.dirname(pickle_path), exist_ok=True)
-            with open(pickle_path, "wb") as f:
-                pickle.dump(embedder, f)
+    # now always persisted, since otherwise record edits wouldn't work for embedded columns
+    pickle_path = os.path.join(
+        "/inference", project_id, f"embedder-{embedding_id}.pkl"
+    )
+    if not os.path.exists(pickle_path):
+        os.makedirs(os.path.dirname(pickle_path), exist_ok=True)
+        with open(pickle_path, "wb") as f:
+            pickle.dump(embedder, f)

     upload_embedding_as_file(project_id, embedding_id)
     embedding.update_embedding_state_finished(
@@ -499,3 +501,96 @@ def upload_embedding_as_file(

 def __is_embedders_internal_model(model_name: str):
     return model_name in ["bag-of-characters", "bag-of-words", "tf-idf"]
+
+
+def re_embed_records(project_id: str, changes: Dict[str, List[Dict[str, str]]]):
+    for embedding_id in changes:
+        if len(changes[embedding_id]) == 0:
+            continue
+
+        embedding_item = embedding.get(project_id, embedding_id)
+        if not embedding_item:
+            continue
+
+        # convert back to int since the request automatically converts it to a string
+        if "sub_key" in changes[embedding_id][0]:
+            for d in changes[embedding_id]:
+                d["sub_key"] = int(d["sub_key"])
+
+        embedder = __setup_tmp_embedder(project_id, embedding_id)
+
+        data_to_embed = None
+        record_ids = None  # either a list or a set, depending on the embedding type
+        attribute_name = changes[embedding_id][0]["attribute_name"]
+
+        if embedding_item.type == enums.EmbeddingType.ON_TOKEN.value:
+            # can't have a sub_key, so records are unique; order is preserved by get_docbins
+            record_ids = [c["record_id"] for c in changes[embedding_id]]
+            data_to_embed = get_docbins(
+                project_id, record_ids, embedder.nlp.vocab, attribute_name
+            )
+        else:
+            # order matters; the data collection request doesn't guarantee it, so we order ourselves
+            record_ids = {c["record_id"] for c in changes[embedding_id]}
+            records = record.get_by_record_ids(project_id, record_ids)
+            records = {str(r.id): r for r in records}
+
+            data_to_embed = [
+                records[c["record_id"]].data[attribute_name]
+                if "sub_key" not in c
+                else records[c["record_id"]].data[attribute_name][c["sub_key"]]
+                for c in changes[embedding_id]
+            ]
+
+        new_tensors = embedder.transform(data_to_embed)
+
+        if len(new_tensors) != len(changes[embedding_id]):
+            raise Exception(
+                f"Number of new tensors ({len(new_tensors)}) doesn't match number of changes ({len(changes[embedding_id])})"
+            )
+
+        # delete old tensors
+        if "sub_key" in changes[embedding_id][0]:
+            embedding.delete_by_record_ids_and_sub_keys(
+                project_id,
+                embedding_id,
+                [(c["record_id"], c["sub_key"]) for c in changes[embedding_id]],
+            )
+        else:
+            embedding.delete_by_record_ids(project_id, embedding_id, record_ids)
+        # add new tensors
+        record_ids_batched = [
+            c["record_id"]
+            if "sub_key" not in c
+            else c["record_id"] + "@" + str(c["sub_key"])
+            for c in changes[embedding_id]
+        ]
+
+        embedding.create_tensors(
+            project_id,
+            embedding_id,
+            record_ids_batched,
+            new_tensors,
+            with_commit=True,
+        )
+
+        upload_embedding_as_file(project_id, embedding_id)
+        request_util.delete_embedding_from_neural_search(embedding_id)
+        request_util.post_embedding_to_neural_search(project_id, embedding_id)
+
+        del embedder
+        time.sleep(0.1)
+        gc.collect()
+        time.sleep(0.1)
+
+
+def __setup_tmp_embedder(project_id: str, embedder_id: str) -> Transformer:
+    embedder_path = os.path.join(
+        "/inference", project_id, f"embedder-{embedder_id}.pkl"
+    )
+    if not os.path.exists(embedder_path):
+        raise Exception(f"Embedder {embedder_id} not found")
+    with open(embedder_path, "rb") as f:
+        embedder = pickle.load(f)
+
+    return embedder
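
A minimal, self-contained sketch of the id convention used when the deleted tensors are re-created above: plain record ids for attribute-level changes, and record_id@sub_key for entries of embedding lists. The sample change list is hypothetical; the helper only mirrors the record_ids_batched list comprehension from re_embed_records.

from typing import List


def batched_tensor_ids(changes_for_embedding: List[dict]) -> List[str]:
    # mirrors record_ids_batched in re_embed_records: list entries get a "@<sub_key>" suffix
    return [
        c["record_id"] if "sub_key" not in c else c["record_id"] + "@" + str(c["sub_key"])
        for c in changes_for_embedding
    ]


# hypothetical change list for one embedding id (sub_key already converted to int)
changes = [
    {"record_id": "rec-1", "attribute_name": "comments", "sub_key": 0},
    {"record_id": "rec-2", "attribute_name": "comments", "sub_key": 3},
    {"record_id": "rec-3", "attribute_name": "headline"},
]
print(batched_tensor_ids(changes))  # ['rec-1@0', 'rec-2@3', 'rec-3']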

data/data_type.py

Lines changed: 9 additions & 2 deletions
@@ -1,9 +1,16 @@
-from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, List
 from pydantic import BaseModel


 class EmbeddingRequest(BaseModel):
     project_id: str
     embedding_id: str

+
+class EmbeddingRebuildRequest(BaseModel):
+    # example request structure:
+    # {"<embedding_id>":[{"record_id":"<record_id>","attribute_name":"<attribute_name>","sub_key":<sub_key>}]}
+    # note that sub_key is optional and only relevant for embedding lists
+    # also note that sub_key is an int but gets converted to a string in the request
+
+    changes: Dict[str, List[Dict[str, str]]]
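
To illustrate the documented structure, a small sketch that re-declares an equivalent model and validates a sample payload with pydantic; the embedding id, record id, and attribute name are placeholders.

from typing import Dict, List

from pydantic import BaseModel


class EmbeddingRebuildRequest(BaseModel):
    # same shape as the model added in this commit
    changes: Dict[str, List[Dict[str, str]]]


sample = {
    "changes": {
        "<embedding_id>": [
            {"record_id": "<record_id>", "attribute_name": "<attribute_name>", "sub_key": "2"}
        ]
    }
}

request = EmbeddingRebuildRequest(**sample)
print(request.changes["<embedding_id>"][0]["sub_key"])  # prints "2" (still a string at this point)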
