Skip to content

Commit f8b3507

Browse files
lumburovskalinaJWittmeyerJWittmeyer
authored
Release v1.3.4 (#55)
* Added update notifications for every update of the attribute * Back end for running lf on 10 records * Full record data returned for lf sample records * Adds formatting * PR comments * Submodule change Co-authored-by: JWittmeyer <[email protected]> Co-authored-by: JWittmeyer <[email protected]>
1 parent e18abc2 commit f8b3507

File tree

6 files changed

+192
-14
lines changed

6 files changed

+192
-14
lines changed

controller/attribute/manager.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def update_attribute(
8787
name: str,
8888
source_code: str,
8989
) -> None:
90-
attribute.update(
90+
91+
attribute_item: Attribute = attribute.update(
9192
project_id,
9293
attribute_id,
9394
data_type,
@@ -96,12 +97,10 @@ def update_attribute(
9697
source_code,
9798
with_commit=True,
9899
)
99-
if attribute.get(project_id, attribute_id).state in [
100-
AttributeState.UPLOADED.value,
101-
AttributeState.AUTOMATICALLY_CREATED.value,
102-
AttributeState.USABLE.value,
103-
]:
104-
notification.send_organization_update(project_id, "attributes_updated")
100+
101+
notification.send_organization_update(
102+
project_id=project_id, message=f"calculate_attribute:updated:{str(attribute_item.id)}"
103+
)
105104

106105

107106
def delete_attribute(project_id: str, attribute_id: str) -> None:

controller/payload/manager.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from typing import Any, Optional
2-
1+
from typing import Any, Dict, List, Optional, Tuple
32
from controller.payload import payload_scheduler
4-
from submodules.model import InformationSourcePayload
5-
from submodules.model.business_objects import payload
3+
from graphql_api.types import (
4+
LabelingFunctionSampleRecordWrapper,
5+
LabelingFunctionSampleRecords,
6+
)
7+
from submodules.model import InformationSourcePayload, enums
8+
from submodules.model.business_objects import information_source, payload
69

710

811
def get_payload(project_id: str, payload_id: str) -> InformationSourcePayload:
@@ -39,3 +42,43 @@ def update_payload_status(
3942
project_id: str, payload_id: str, status: str
4043
) -> InformationSourcePayload:
4144
return payload.update_status(project_id, payload_id, status)
45+
46+
47+
def get_labeling_function_on_10_records(
    project_id: str, information_source_id: str
) -> LabelingFunctionSampleRecords:
    """Run a labeling function on sample records and wrap the outcome.

    The sample records and their doc bin are prepared by
    ``payload_scheduler.prepare_sample_records_doc_bin``; the labeling
    function is then executed through
    ``payload_scheduler.run_labeling_function_exec_env``.

    Args:
        project_id: Project the information source belongs to.
        information_source_id: Information source whose code is executed.

    Returns:
        A ``LabelingFunctionSampleRecords`` holding one wrapper per sample
        record, the container log lines and the error flag reported by the
        execution step.
    """
    doc_bin_name, sampled = payload_scheduler.prepare_sample_records_doc_bin(
        project_id=project_id, information_source_id=information_source_id
    )
    execution_result = payload_scheduler.run_labeling_function_exec_env(
        project_id=project_id,
        information_source_id=information_source_id,
        prefixed_doc_bin=doc_bin_name,
    )
    labels_by_record, log_lines, has_errors = execution_result

    # Guarantee a (possibly empty) label list for every sampled record so the
    # lookup below cannot fail for records the function returned nothing for.
    labels_by_record = fill_missing_record_ids(sampled, labels_by_record)

    wrapped_records = []
    for row in sampled:
        # row[0] is used as the record id, row[1] as the full record data.
        wrapped_records.append(
            LabelingFunctionSampleRecordWrapper(
                record_id=row[0],
                full_record_data=row[1],
                calculated_labels=labels_by_record[row[0]],
            )
        )

    return LabelingFunctionSampleRecords(
        records=wrapped_records,
        container_logs=log_lines,
        code_has_errors=has_errors,
    )
76+
77+
78+
def fill_missing_record_ids(
    sample_records: List[Any], calculated_labels: Dict[str, List[Any]]
) -> Dict[str, List[Any]]:
    """Ensure every sample record has an entry in ``calculated_labels``.

    Args:
        sample_records: Indexable record items; element ``0`` of each item is
            used as the record id.
        calculated_labels: Mapping of record id to its calculated labels.
            It is updated in place.

    Returns:
        The same ``calculated_labels`` mapping, with an empty list added for
        each sample record id that had no entry yet. Existing entries are
        left untouched.
    """
    for record_item in sample_records:
        # setdefault only inserts when the id is absent.
        calculated_labels.setdefault(record_item[0], [])

    return calculated_labels

controller/payload/payload_scheduler.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from graphql.error.base import GraphQLError
1414
from submodules.model import enums, events
1515
from submodules.model.business_objects import (
16+
attribute,
1617
information_source,
1718
embedding,
1819
labeling_task,
@@ -33,7 +34,10 @@
3334
get_label_ids_by_names,
3435
)
3536
from submodules.model.business_objects.payload import get_max_token, get
36-
from submodules.model.business_objects.tokenization import get_doc_bin_progress
37+
from submodules.model.business_objects.tokenization import (
38+
get_doc_bin_progress,
39+
get_doc_bin_table_to_json,
40+
)
3741
from submodules.model.models import (
3842
InformationSource,
3943
InformationSourceStatisticsExclusion,
@@ -694,3 +698,108 @@ def add_information_source_statistics_exclusion(
694698
if idx % 2 == 0
695699
]
696700
general.add_all(exclusions, with_commit=True)
701+
702+
703+
def prepare_sample_records_doc_bin(
    project_id: str, information_source_id: str
) -> Tuple[str, List[str]]:
    """Build and upload a doc bin JSON for the project's sample records.

    Args:
        project_id: Project to take the sample records from.
        information_source_id: Used as prefix for the uploaded object name.

    Returns:
        A tuple of the uploaded object name
        (``<information_source_id>_doc_bin.json``, without the project
        prefix) and the sample records as returned by the record layer.
    """
    sampled = record.get_attribute_calculation_sample_records(project_id)

    missing_columns = get_missing_columns_tokenization(project_id)
    # Element 0 of each sample record is used as its id.
    sampled_ids = [row[0] for row in sampled]
    doc_bin_json = get_doc_bin_table_to_json(
        project_id=project_id,
        missing_columns=missing_columns,
        record_ids=sampled_ids,
    )

    organization_id = str(project.get(project_id).organization_id)
    doc_bin_name = f"{information_source_id}_doc_bin.json"
    # The object is stored under the project id as key prefix.
    object_key = project_id + "/" + doc_bin_name
    s3.put_object(organization_id, object_key, doc_bin_json)

    return doc_bin_name, sampled
723+
724+
725+
def run_labeling_function_exec_env(
    project_id: str, information_source_id: str, prefixed_doc_bin: str
) -> Tuple[dict, List[str], bool]:
    """Execute a labeling function in the exec-env container and collect results.

    The function's source code and the project's knowledge base are uploaded
    to s3, access links for them (plus the doc bin and an upload link for the
    result payload) are passed as the container command, and the payload the
    container uploads is read back and parsed as JSON.

    Args:
        project_id: Project the information source belongs to.
        information_source_id: Information source whose source code is run;
            also used as prefix for all temporary object names.
        prefixed_doc_bin: Object name (below ``<project_id>/``) of the doc
            bin to run on. The value ``"docbin_full"`` is never deleted.

    Returns:
        A tuple ``(calculated_labels, container_logs, code_has_errors)``:
        the parsed JSON payload (``{}`` if it could not be read), the decoded
        container log lines, and ``True`` when the payload could not be
        fetched or parsed.
    """
    information_source_item = information_source.get(project_id, information_source_id)

    prefixed_function_name = f"{information_source_id}_fn"
    prefixed_payload = f"{information_source_id}_payload.json"
    prefixed_knowledge_base = f"{information_source_id}_knowledge"
    project_item = project.get(project_id)
    org_id = str(project_item.organization_id)

    doc_bin_key = project_id + "/" + prefixed_doc_bin
    function_key = project_id + "/" + prefixed_function_name
    payload_key = project_id + "/" + prefixed_payload
    knowledge_base_key = project_id + "/" + prefixed_knowledge_base

    code_has_errors = False

    try:
        s3.put_object(org_id, function_key, information_source_item.source_code)
        s3.put_object(
            org_id,
            knowledge_base_key,
            knowledge_base.build_knowledge_base_from_project(project_id),
        )

        tokenization_progress = get_doc_bin_progress(project_id)

        command = [
            s3.create_access_link(org_id, doc_bin_key),
            s3.create_access_link(org_id, function_key),
            s3.create_access_link(org_id, knowledge_base_key),
            tokenization_progress,
            project_item.tokenizer_blank,
            s3.create_file_upload_link(org_id, payload_key),
        ]

        container = client.containers.run(
            image=lf_exec_env_image,
            command=command,
            remove=True,
            detach=True,
            network=exec_env_network,
        )

        # Consuming the log stream to its end also waits for the container
        # to finish before the payload is fetched.
        container_logs = [
            line.decode("utf-8").strip("\n")
            for line in container.logs(
                stream=True, stdout=True, stderr=True, timestamps=True
            )
        ]

        try:
            raw_payload = s3.get_object(org_id, payload_key)
            calculated_labels = json.loads(raw_payload)
        except Exception:
            # A missing or unparsable payload is reported as a code error
            # instead of failing the request; the logs carry the details.
            print("Could not grab data from s3 -- labeling function")
            code_has_errors = True
            calculated_labels = {}
    finally:
        # Always clean up the temporary objects, even when the upload, the
        # container run or the log streaming raised.
        if prefixed_doc_bin != "docbin_full":
            # sample records docbin should be deleted after calculation
            s3.delete_object(org_id, doc_bin_key)
        s3.delete_object(org_id, function_key)
        s3.delete_object(org_id, payload_key)
        s3.delete_object(org_id, knowledge_base_key)

    return calculated_labels, container_logs, code_has_errors
793+
794+
795+
def get_missing_columns_tokenization(project_id: str) -> str:
    """Build the SQL fragment selecting all non-text attributes of a project.

    For every attribute whose data type is not ``TEXT`` one
    ``'<name>',r.data->'<name>'`` pair is produced; pairs are joined with
    ``",\\n"``. An empty string is returned when there is no such attribute.

    Args:
        project_id: Project whose attributes are inspected.

    Returns:
        The joined key/value selector fragment.
    """
    selector_pairs = []
    for attribute_item in attribute.get_all(project_id):
        if attribute_item.data_type == enums.DataTypes.TEXT.value:
            continue
        # The name ends up inside single-quoted SQL literals, so embedded
        # single quotes are doubled to keep the literal well-formed.
        # NOTE(review): assumes the consumer runs this with standard SQL
        # string-literal rules (e.g. PostgreSQL) -- confirm.
        quoted_name = attribute_item.name.replace("'", "''")
        selector_pairs.append(
            "'" + quoted_name + "',r.data->'" + quoted_name + "'"
        )

    return ",\n".join(selector_pairs)

graphql_api/query/payload.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import graphene
22

33
from controller.auth import manager as auth
4-
from graphql_api.types import InformationSourcePayload
4+
from graphql_api.types import InformationSourcePayload, LabelingFunctionSampleRecords
55
from controller.payload import manager
66

77

@@ -13,9 +13,24 @@ class PayloadQuery(graphene.ObjectType):
1313
project_id=graphene.ID(required=True),
1414
)
1515

16+
get_labeling_function_on_10_records = graphene.Field(
17+
LabelingFunctionSampleRecords,
18+
project_id=graphene.ID(required=True),
19+
information_source_id=graphene.ID(required=True),
20+
)
21+
1622
def resolve_payload_by_payload_id(
1723
self, info, payload_id: str, project_id: str
1824
) -> InformationSourcePayload:
1925
auth.check_demo_access(info)
2026
auth.check_project_access(info, project_id)
2127
return manager.get_payload(project_id, payload_id)
28+
29+
def resolve_get_labeling_function_on_10_records(
    self, info, project_id: str, information_source_id: str
) -> LabelingFunctionSampleRecords:
    """Resolve the sample-record run of a labeling function.

    Both access checks are performed before any work is delegated to the
    payload manager.
    """
    auth.check_demo_access(info)
    auth.check_project_access(info, project_id)
    sample_run_result = manager.get_labeling_function_on_10_records(
        project_id, information_source_id
    )
    return sample_run_result

graphql_api/types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,3 +750,15 @@ class LastRunAttributesResult(graphene.ObjectType):
750750
class UserAttributeSampleRecordsResult(graphene.ObjectType):
751751
record_ids = graphene.List(graphene.ID)
752752
calculated_attributes = graphene.List(graphene.String)
753+
754+
755+
class LabelingFunctionSampleRecordWrapper(graphene.ObjectType):
    # One sample record together with the labels calculated for it.
    # (Kept as comments: a class docstring would be picked up as the GraphQL
    # type description.)

    # Id of the sample record.
    record_id = graphene.ID()
    # Labels calculated for this record, exposed as strings.
    calculated_labels = graphene.List(of_type=graphene.String)
    # Full data of the record, exposed as a JSON string.
    full_record_data = graphene.JSONString()
759+
760+
761+
class LabelingFunctionSampleRecords(graphene.ObjectType):
    # Result of running a labeling function on sample records.
    # (Kept as comments: a class docstring would be picked up as the GraphQL
    # type description.)

    # One wrapper per sample record.
    records = graphene.List(of_type=LabelingFunctionSampleRecordWrapper)
    # Log lines of the container run.
    container_logs = graphene.List(of_type=graphene.String)
    # Error flag of the run.
    code_has_errors = graphene.Boolean()

submodules/model

0 commit comments

Comments
 (0)