@@ -13,6 +13,7 @@
13 | 13 | from graphql.error.base import GraphQLError |
14 | 14 | from submodules.model import enums, events |
15 | 15 | from submodules.model.business_objects import ( |
| 16 | +    attribute,
16 | 17 |     information_source,
17 | 18 |     embedding,
18 | 19 |     labeling_task,
@@ -33,7 +34,10 @@
33 | 34 |     get_label_ids_by_names,
34 | 35 | ) |
35 | 36 | from submodules.model.business_objects.payload import get_max_token, get |
36 | | -from submodules.model.business_objects.tokenization import get_doc_bin_progress |
| 37 | +from submodules.model.business_objects.tokenization import ( |
| 38 | +    get_doc_bin_progress,
| 39 | +    get_doc_bin_table_to_json,
| 40 | +) |
37 | 41 | from submodules.model.models import ( |
38 | 42 |     InformationSource,
39 | 43 |     InformationSourceStatisticsExclusion,
@@ -694,3 +698,108 @@ def add_information_source_statistics_exclusion( |
694 | 698 |         if idx % 2 == 0
695 | 699 |     ]
696 | 700 |     general.add_all(exclusions, with_commit=True)
| 701 | + |
| 702 | + |
| 703 | +def prepare_sample_records_doc_bin( |
| 704 | +    project_id: str, information_source_id: str
| 705 | +) -> Tuple[str, list]:
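| | +    # builds a doc bin JSON for a small sample of records and uploads it to object
| | +    # storage; sample_records are DB rows (r[0] is the record id), hence the loose
| | +    # `list` annotation above rather than List[str]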
| 706 | +    sample_records = record.get_attribute_calculation_sample_records(project_id)
| 707 | +
| 708 | +    sample_records_doc_bin = get_doc_bin_table_to_json(
| 709 | +        project_id=project_id,
| 710 | +        missing_columns=get_missing_columns_tokenization(project_id),
| 711 | +        record_ids=[r[0] for r in sample_records],
| 712 | +    )
| 713 | +    project_item = project.get(project_id)
| 714 | +    org_id = str(project_item.organization_id)
| 715 | +    prefixed_doc_bin = f"{information_source_id}_doc_bin.json"
| 716 | +    s3.put_object(
| 717 | +        org_id,
| 718 | +        project_id + "/" + prefixed_doc_bin,
| 719 | +        sample_records_doc_bin,
| 720 | +    )
| 721 | + |
| 722 | +    return prefixed_doc_bin, sample_records
| 723 | + |
| 724 | + |
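| | +# a typical flow (sketch, not part of this diff): prepare the sample doc bin first,
| | +# then hand its name to the exec-env runner:
| | +#   prefixed_doc_bin, sample_records = prepare_sample_records_doc_bin(project_id, information_source_id)
| | +#   labels, logs, has_errors = run_labeling_function_exec_env(project_id, information_source_id, prefixed_doc_bin)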
| 725 | +def run_labeling_function_exec_env( |
| 726 | +    project_id: str, information_source_id: str, prefixed_doc_bin: str
| 727 | +) -> Tuple[dict, List[str], bool]:
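| | +    # runs the labeling function inside the exec-env container and returns the
| | +    # parsed result payload, the container log lines and an error flag; `dict` is
| | +    # used above because the payload comes from json.loads (flat List[str] logs)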
| 728 | + |
| 729 | +    information_source_item = information_source.get(project_id, information_source_id)
| 730 | +
| 731 | +    prefixed_function_name = f"{information_source_id}_fn"
| 732 | +    prefixed_payload = f"{information_source_id}_payload.json"
| 733 | +    prefixed_knowledge_base = f"{information_source_id}_knowledge"
| 734 | +    project_item = project.get(project_id)
| 735 | +    org_id = str(project_item.organization_id)
| 736 | + |
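| | +    # upload the function's source code and the project's knowledge bases so the
| | +    # exec env can download them via the access links built below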
| 737 | +    s3.put_object(
| 738 | +        org_id,
| 739 | +        project_id + "/" + prefixed_function_name,
| 740 | +        information_source_item.source_code,
| 741 | +    )
| 742 | +
| 743 | +    s3.put_object(
| 744 | +        org_id,
| 745 | +        project_id + "/" + prefixed_knowledge_base,
| 746 | +        knowledge_base.build_knowledge_base_from_project(project_id),
| 747 | +    )
| 748 | + |
| 749 | +    tokenization_progress = get_doc_bin_progress(project_id)
| 750 | + |
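| | +    # positional arguments consumed by the exec-env entrypoint; the order matters
| | +    # and must match what the lf-exec-env image expects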
| 751 | +    command = [
| 752 | +        s3.create_access_link(org_id, project_id + "/" + prefixed_doc_bin),
| 753 | +        s3.create_access_link(org_id, project_id + "/" + prefixed_function_name),
| 754 | +        s3.create_access_link(org_id, project_id + "/" + prefixed_knowledge_base),
| 755 | +        tokenization_progress,
| 756 | +        project_item.tokenizer_blank,
| 757 | +        s3.create_file_upload_link(org_id, project_id + "/" + prefixed_payload),
| 758 | +    ]
| 759 | + |
| 760 | +    container = client.containers.run(
| 761 | +        image=lf_exec_env_image,
| 762 | +        command=command,
| 763 | +        remove=True,
| 764 | +        detach=True,
| 765 | +        network=exec_env_network,
| 766 | +    )
| 767 | + |
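| | +    # stream the logs while the container runs; with remove=True the container is
| | +    # gone once it exits, so the logs have to be consumed here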
| 768 | +    container_logs = [
| 769 | +        line.decode("utf-8").strip("\n")
| 770 | +        for line in container.logs(
| 771 | +            stream=True, stdout=True, stderr=True, timestamps=True
| 772 | +        )
| 773 | +    ]
| 774 | + |
| 775 | +    code_has_errors = False
| 776 | + |
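| | +    # the exec env uploads its results via the file upload link; failing to fetch
| | +    # or parse the payload is treated as an error in the user's function code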
| 777 | +    try:
| 778 | +        payload = s3.get_object(org_id, project_id + "/" + prefixed_payload)
| 779 | +        calculated_labels = json.loads(payload)
| 780 | +    except Exception:
| 781 | +        print("Could not grab data from s3 -- labeling function")
| 782 | +        code_has_errors = True
| 783 | +        calculated_labels = {}
| 784 | + |
| 785 | +    if prefixed_doc_bin != "docbin_full":
| 786 | +        # sample records doc bin is temporary and deleted after the calculation;
| 787 | +        # "docbin_full" is the shared full doc bin and must be kept
| 788 | +        s3.delete_object(org_id, project_id + "/" + prefixed_doc_bin)
| 789 | +    s3.delete_object(org_id, project_id + "/" + prefixed_function_name)
| 790 | +    s3.delete_object(org_id, project_id + "/" + prefixed_payload)
| 791 | +    s3.delete_object(org_id, project_id + "/" + prefixed_knowledge_base)
| 791 | + |
| 792 | +    return calculated_labels, container_logs, code_has_errors
| 793 | + |
| 794 | + |
| 795 | +def get_missing_columns_tokenization(project_id: str) -> str: |
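| | +    # builds the SQL fragment that selects all non-text attributes directly from
| | +    # the record's JSON data (text attributes live in the doc bin instead)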
| 796 | +    missing_columns = [
| 797 | +        attribute_item.name
| 798 | +        for attribute_item in attribute.get_all(project_id)
| 799 | +        if attribute_item.data_type != enums.DataTypes.TEXT.value
| 800 | +    ]
| 801 | +    missing_columns_str = ",\n".join(
| 802 | +        [f"'{k}',r.data->'{k}'" for k in missing_columns]
| 803 | +    )
| 804 | + |
| 805 | +    return missing_columns_str