|
22 | 22 | from spacy.vocab import Vocab |
23 | 23 | from data import data_type, doc_ock |
24 | 24 | from embedders import Transformer |
25 | | -from typing import Any, Dict, Iterator, List, Optional |
| 25 | +from typing import Any, Dict, Iterator, List, Optional, Union |
26 | 26 |
|
27 | 27 | from util import daemon, request_util |
28 | 28 | from util.config_handler import get_config_value |
|
33 | 33 | import pandas as pd |
34 | 34 | from submodules.s3 import controller as s3 |
35 | 35 | import openai |
| 36 | +import gc |
| 37 | + |
36 | 38 |
|
37 | 39 | logging.basicConfig(level=logging.INFO) |
38 | 40 | logger = logging.getLogger(__name__) |
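The new `gc` import backs the explicit cleanup at the end of `re_embed_records` further down in this diff. A minimal sketch of that pattern, assuming the embedder holds large in-memory model weights (`release` is a hypothetical helper, not part of this diff; the sleeps and the `collect()` call are a best-effort nudge, not a guarantee):

```python
import gc
import time

def release(embedder) -> None:
    # Drop the last strong reference, then ask the collector to reclaim
    # the embedder's memory promptly. The short sleeps give background
    # threads a moment to settle before and after collection.
    del embedder
    time.sleep(0.1)
    gc.collect()
    time.sleep(0.1)
```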
@@ -77,7 +79,7 @@ def get_docbins( |
77 | 79 | docs = list(doc_bin_loaded.get_docs(vocab)) |
78 | 80 | for col, doc in zip(record_item.columns, docs): |
79 | 81 | if col == attribute_name: |
80 | | - result[record_item.record_id] = doc |
| 82 | + result[str(record_item.record_id)] = doc |
81 | 83 | result_list = [] |
82 | 84 | for record_id in record_ids_batch: |
83 | 85 | result_list.append(result[record_id]) |
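Why the added `str(...)` matters: the ids in `record_ids_batch` are strings, while `record_item.record_id` is presumably a richer id object, so keying the dict by the raw id makes the string lookup above miss. A minimal sketch under that assumption (`uuid.UUID` stands in for the actual id type):

```python
import uuid

rid = uuid.UUID("12345678-1234-5678-1234-567812345678")

# Before the fix: keys are id objects, so string lookups miss.
result = {rid: "doc"}
assert str(rid) not in result

# After the fix: keys are normalized with str(), matching record_ids_batch.
result = {str(rid): "doc"}
assert str(rid) in result
```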
@@ -409,14 +411,14 @@ def run_encoding( |
409 | 411 | if embedding_type == enums.EmbeddingType.ON_ATTRIBUTE.value: |
410 | 412 | request_util.post_embedding_to_neural_search(project_id, embedding_id) |
411 | 413 |
|
412 | | - if get_config_value("is_managed"): |
413 | | - pickle_path = os.path.join( |
414 | | - "/inference", project_id, f"embedder-{embedding_id}.pkl" |
415 | | - ) |
416 | | - if not os.path.exists(pickle_path): |
417 | | - os.makedirs(os.path.dirname(pickle_path), exist_ok=True) |
418 | | - with open(pickle_path, "wb") as f: |
419 | | - pickle.dump(embedder, f) |
| 414 | +    # now always written, since otherwise record edits wouldn't work for embedded columns |
| 415 | + pickle_path = os.path.join( |
| 416 | + "/inference", project_id, f"embedder-{embedding_id}.pkl" |
| 417 | + ) |
| 418 | + if not os.path.exists(pickle_path): |
| 419 | + os.makedirs(os.path.dirname(pickle_path), exist_ok=True) |
| 420 | + with open(pickle_path, "wb") as f: |
| 421 | + pickle.dump(embedder, f) |
420 | 422 |
|
421 | 423 | upload_embedding_as_file(project_id, embedding_id) |
422 | 424 | embedding.update_embedding_state_finished( |
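With the `is_managed` gate removed, the pickle is now written in every deployment, so `/inference/<project_id>` has to be writable everywhere. Note also that the `os.path.exists` guard only protects the `makedirs` call, which `exist_ok=True` already makes idempotent; an equivalent, slightly tighter sketch (`persist_embedder` is a hypothetical helper, not part of this diff):

```python
import os
import pickle

def persist_embedder(embedder, project_id: str, embedding_id: str) -> str:
    # exist_ok=True makes makedirs idempotent, so no exists check is needed.
    pickle_path = os.path.join("/inference", project_id, f"embedder-{embedding_id}.pkl")
    os.makedirs(os.path.dirname(pickle_path), exist_ok=True)
    with open(pickle_path, "wb") as f:
        pickle.dump(embedder, f)
    return pickle_path
```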
@@ -499,3 +501,96 @@ def upload_embedding_as_file( |
499 | 501 |
|
500 | 502 | def __is_embedders_internal_model(model_name: str): |
501 | 503 | return model_name in ["bag-of-characters", "bag-of-words", "tf-idf"] |
| 504 | + |
| 505 | + |
| 506 | +def re_embed_records(project_id: str, changes: Dict[str, List[Dict[str, str]]]): |
| 507 | + for embedding_id in changes: |
| 508 | + if len(changes[embedding_id]) == 0: |
| 509 | + continue |
| 510 | + |
| 511 | + embedding_item = embedding.get(project_id, embedding_id) |
| 512 | + if not embedding_item: |
| 513 | + continue |
| 514 | + |
| 515 | +    # convert sub_key back to int, since the request layer automatically converts it to a string |
| 516 | + if "sub_key" in changes[embedding_id][0]: |
| 517 | + for d in changes[embedding_id]: |
| 518 | + d["sub_key"] = int(d["sub_key"]) |
| 519 | + |
| 520 | + embedder = __setup_tmp_embedder(project_id, embedding_id) |
| 521 | + |
| 522 | + data_to_embed = None |
| 523 | + record_ids = None # Either list or set depending on embedding type |
| 524 | + attribute_name = changes[embedding_id][0]["attribute_name"] |
| 525 | + |
| 526 | + if embedding_item.type == enums.EmbeddingType.ON_TOKEN.value: |
| 527 | +    # ON_TOKEN embeddings can't have a sub_key, so record ids are unique and get_docbins preserves their order |
| 528 | + record_ids = [c["record_id"] for c in changes[embedding_id]] |
| 529 | + data_to_embed = get_docbins( |
| 530 | + project_id, record_ids, embedder.nlp.vocab, attribute_name |
| 531 | + ) |
| 532 | + else: |
| 533 | +    # order matters: the data collection request doesn't preserve it, so we restore it ourselves |
| 534 | + record_ids = {c["record_id"] for c in changes[embedding_id]} |
| 535 | + records = record.get_by_record_ids(project_id, record_ids) |
| 536 | + records = {str(r.id): r for r in records} |
| 537 | + |
| 538 | + data_to_embed = [ |
| 539 | + records[c["record_id"]].data[attribute_name] |
| 540 | + if "sub_key" not in c |
| 541 | + else records[c["record_id"]].data[attribute_name][c["sub_key"]] |
| 542 | + for c in changes[embedding_id] |
| 543 | + ] |
| 544 | + |
| 545 | + new_tensors = embedder.transform(data_to_embed) |
| 546 | + |
| 547 | + if len(new_tensors) != len(changes[embedding_id]): |
| 548 | + raise Exception( |
| 549 | + f"Number of new tensors ({len(new_tensors)}) doesn't match number of changes ({len(changes[embedding_id])})" |
| 550 | + ) |
| 551 | + |
| 552 | + # delete old |
| 553 | + if "sub_key" in changes[embedding_id][0]: |
| 554 | + embedding.delete_by_record_ids_and_sub_keys( |
| 555 | + project_id, |
| 556 | + embedding_id, |
| 557 | + [(c["record_id"], c["sub_key"]) for c in changes[embedding_id]], |
| 558 | + ) |
| 559 | + else: |
| 560 | + embedding.delete_by_record_ids(project_id, embedding_id, record_ids) |
| 561 | +    # add new; tensors for list attributes are keyed as "<record_id>@<sub_key>" |
| 562 | + record_ids_batched = [ |
| 563 | + c["record_id"] |
| 564 | + if "sub_key" not in c |
| 565 | + else c["record_id"] + "@" + str(c["sub_key"]) |
| 566 | + for c in changes[embedding_id] |
| 567 | + ] |
| 568 | + |
| 569 | + embedding.create_tensors( |
| 570 | + project_id, |
| 571 | + embedding_id, |
| 572 | + record_ids_batched, |
| 573 | + new_tensors, |
| 574 | + with_commit=True, |
| 575 | + ) |
| 576 | + |
| 577 | + upload_embedding_as_file(project_id, embedding_id) |
| 578 | + request_util.delete_embedding_from_neural_search(embedding_id) |
| 579 | + request_util.post_embedding_to_neural_search(project_id, embedding_id) |
| 580 | + |
| 581 | + del embedder |
| 582 | + time.sleep(0.1) |
| 583 | + gc.collect() |
| 584 | + time.sleep(0.1) |
| 585 | + |
| 586 | + |
| 587 | +def __setup_tmp_embedder(project_id: str, embedder_id: str) -> Transformer: |
| 588 | + embedder_path = os.path.join( |
| 589 | + "/inference", project_id, f"embedder-{embedder_id}.pkl" |
| 590 | + ) |
| 591 | + if not os.path.exists(embedder_path): |
| 592 | + raise Exception(f"Embedder {embedder_id} not found") |
| 593 | + with open(embedder_path, "rb") as f: |
| 594 | + embedder = pickle.load(f) |
| 595 | + |
| 596 | + return embedder |
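For reference, `re_embed_records` expects `changes` to map embedding ids to lists of change dicts with a `record_id`, an `attribute_name`, and, only for embedded-list attributes, a `sub_key` that arrives as a string (hence the `int(...)` conversion above). A hypothetical call with illustrative ids:

```python
changes = {
    "embedding-id-1": [
        # plain attribute change
        {"record_id": "record-id-1", "attribute_name": "headline"},
        # embedded-list attribute change: sub_key arrives as a string
        {"record_id": "record-id-2", "attribute_name": "comments", "sub_key": "0"},
    ],
}
re_embed_records("project-id-1", changes)
```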