Skip to content

Commit 15b7b94

Browse files
Azure platform on embeddings (#145)
* Additional data field on the embeddings table * Added azure as embedding platform * Import/export additional data * Submodule dev change --------- Co-authored-by: JWittmeyer <[email protected]>
1 parent 76fb0af commit 15b7b94

File tree

7 files changed

+112
-64
lines changed

7 files changed

+112
-64
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""Additional data on the embeddings table
2+
3+
Revision ID: 0714589d508e
4+
Revises: 73798599a917
5+
Create Date: 2023-08-03 13:21:49.464532
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
11+
12+
# revision identifiers, used by Alembic.
13+
revision = '0714589d508e'
14+
down_revision = '73798599a917'
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade():
20+
# ### commands auto generated by Alembic - please adjust! ###
21+
op.add_column('embedding', sa.Column('additional_data', sa.JSON(), nullable=True))
22+
# ### end Alembic commands ###
23+
24+
25+
def downgrade():
26+
# ### commands auto generated by Alembic - please adjust! ###
27+
op.drop_column('embedding', 'additional_data')
28+
# ### end Alembic commands ###

controller/embedding/manager.py

Lines changed: 44 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -43,25 +43,15 @@ def get_recommended_encoders(is_managed: bool) -> List[Any]:
4343
return recommendations
4444

4545

46+
def create_embedding(project_id: str, embedding_id: str) -> None:
47+
daemon.run(connector.request_embedding, project_id, embedding_id)
4648

47-
def create_embedding(
48-
project_id: str, embedding_id: str
49-
) -> None:
50-
daemon.run(
51-
connector.request_embedding,
52-
project_id,
53-
embedding_id
54-
)
5549

5650
def create_embeddings_one_by_one(
5751
project_id: str,
5852
embeddings_ids: List[str],
5953
) -> None:
60-
daemon.run(
61-
__embed_one_by_one_helper,
62-
project_id,
63-
embeddings_ids
64-
)
54+
daemon.run(__embed_one_by_one_helper, project_id, embeddings_ids)
6555

6656

6757
def request_tensor_upload(project_id: str, embedding_id: str) -> Any:
@@ -74,10 +64,7 @@ def delete_embedding(project_id: str, embedding_id: str) -> None:
7464
connector.request_deleting_embedding(project_id, embedding_id)
7565

7666

77-
def __embed_one_by_one_helper(
78-
project_id: str,
79-
embeddings_ids: List[str]
80-
) -> None:
67+
def __embed_one_by_one_helper(project_id: str, embeddings_ids: List[str]) -> None:
8168
for embedding_id in embeddings_ids:
8269
connector.request_embedding(project_id, embedding_id)
8370
time.sleep(5)
@@ -86,7 +73,12 @@ def __embed_one_by_one_helper(
8673

8774

8875
def get_embedding_name(
89-
project_id: str, attribute_id: str, platform: str, embedding_type: str, model: Optional[str] = None, apiToken: Optional[str] = None
76+
project_id: str,
77+
attribute_id: str,
78+
platform: str,
79+
embedding_type: str,
80+
model: Optional[str] = None,
81+
api_token: Optional[str] = None,
9082
) -> str:
9183
if embedding_type not in [
9284
enums.EmbeddingType.ON_ATTRIBUTE.value,
@@ -109,19 +101,21 @@ def get_embedding_name(
109101
if model:
110102
name += f"-{model}"
111103

112-
if apiToken:
113-
name += f"-{apiToken[:3]}...{apiToken[-4:]}"
114-
104+
if api_token:
105+
name += f"-{api_token[:3]}...{api_token[-4:]}"
106+
115107
return name
116108

117109

118-
def recreate_embeddings(project_id: str, embedding_ids: Optional[List[str]] = None) -> None:
110+
def recreate_embeddings(
111+
project_id: str, embedding_ids: Optional[List[str]] = None
112+
) -> None:
119113
if not embedding_ids:
120114
embeddings = embedding.get_all_embeddings_by_project_id(project_id)
121115
if len(embeddings) == 0:
122116
return
123117
embedding_ids = [str(embed.id) for embed in embeddings]
124-
118+
125119
set_to_wait = False
126120
for embedding_id in embedding_ids:
127121
set_to_wait = True
@@ -130,10 +124,10 @@ def recreate_embeddings(project_id: str, embedding_ids: Optional[List[str]] = No
130124

131125
if set_to_wait:
132126
notification.send_organization_update(
133-
project_id,
134-
f"embedding:{None}:state:{enums.EmbeddingState.WAITING.value}",
135-
)
136-
127+
project_id,
128+
f"embedding:{None}:state:{enums.EmbeddingState.WAITING.value}",
129+
)
130+
137131
for embedding_id in embedding_ids:
138132
new_id = None
139133
try:
@@ -147,13 +141,17 @@ def recreate_embeddings(project_id: str, embedding_ids: Optional[List[str]] = No
147141
embedding_item = general.refresh(embedding_item)
148142
if not embedding_item:
149143
raise Exception("Embedding not found")
150-
elif embedding_item.state == enums.EmbeddingState.FAILED.value or embedding_item.state == enums.EmbeddingState.FINISHED.value:
144+
elif (
145+
embedding_item.state == enums.EmbeddingState.FAILED.value
146+
or embedding_item.state == enums.EmbeddingState.FINISHED.value
147+
):
151148
break
152149
else:
153150
time.sleep(1)
154151
except Exception as e:
155152
print(
156-
f"Error while recreating embedding for {project_id} with id {embedding_id} - {e}", flush=True
153+
f"Error while recreating embedding for {project_id} with id {embedding_id} - {e}",
154+
flush=True,
157155
)
158156
notification.send_organization_update(
159157
project_id,
@@ -162,23 +160,19 @@ def recreate_embeddings(project_id: str, embedding_ids: Optional[List[str]] = No
162160
old_embedding_item = embedding.get(project_id, embedding_id)
163161
if old_embedding_item:
164162
old_embedding_item.state = enums.EmbeddingState.FAILED.value
165-
163+
166164
if new_id:
167165
new_embedding_item = embedding.get(project_id, new_id)
168166
if new_embedding_item:
169167
new_embedding_item.state = enums.EmbeddingState.FAILED.value
170168
general.commit()
171169

172-
173170
notification.send_organization_update(
174171
project_id=project_id, message="embedding:finished:all"
175172
)
176173

177174

178-
179-
def __recreate_embedding(
180-
project_id: str, embedding_id: str
181-
) -> Embedding:
175+
def __recreate_embedding(project_id: str, embedding_id: str) -> Embedding:
182176
old_embedding_item = embedding.get(project_id, embedding_id)
183177
old_id = old_embedding_item.id
184178
new_embedding_item = embedding.create(
@@ -191,25 +185,30 @@ def __recreate_embedding(
191185
model=old_embedding_item.model,
192186
platform=old_embedding_item.platform,
193187
api_token=old_embedding_item.api_token,
194-
with_commit=False
188+
additional_data=old_embedding_item.additional_data,
189+
with_commit=False,
195190
)
196191
embedding.delete(project_id, embedding_id, with_commit=False)
197192
embedding.delete_tensors(embedding_id, with_commit=False)
198193
general.commit()
199194

200-
if new_embedding_item.platform == enums.EmbeddingPlatform.OPENAI.value or new_embedding_item.platform == enums.EmbeddingPlatform.COHERE.value:
201-
agreement_item = agreement.get_by_xfkey(project_id, old_id, enums.AgreementType.EMBEDDING.value)
195+
if (
196+
new_embedding_item.platform == enums.EmbeddingPlatform.OPENAI.value
197+
or new_embedding_item.platform == enums.EmbeddingPlatform.COHERE.value
198+
or new_embedding_item.platform == enums.EmbeddingPlatform.AZURE.value
199+
):
200+
agreement_item = agreement.get_by_xfkey(
201+
project_id, old_id, enums.AgreementType.EMBEDDING.value
202+
)
202203
if not agreement_item:
203204
new_embedding_item.state = enums.EmbeddingState.FAILED.value
204205
general.commit()
205-
raise Exception(f"No agreement found for embedding {new_embedding_item.name}")
206+
raise Exception(
207+
f"No agreement found for embedding {new_embedding_item.name}"
208+
)
206209
agreement_item.xfkey = new_embedding_item.id
207210
general.commit()
208211

209212
connector.request_deleting_embedding(project_id, old_id)
210-
daemon.run(
211-
connector.request_embedding,
212-
project_id,
213-
new_embedding_item.id
214-
)
215-
return new_embedding_item
213+
daemon.run(connector.request_embedding, project_id, new_embedding_item.id)
214+
return new_embedding_item

controller/task_queue/handler/embedding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def get_task_functions() -> Tuple[Callable, Callable, int]:
1818

1919

2020
def __start_task(task: Dict[str, Any]) -> bool:
21-
2221
# check task still relevant
2322
task_db_obj = task_queue_db_bo.get(task["id"])
2423
if task_db_obj is None or task_db_obj.is_active:
@@ -47,6 +46,7 @@ def __start_task(task: Dict[str, Any]) -> bool:
4746
terms_accepted = task["task_info"]["terms_accepted"]
4847

4948
filter_attributes = task["task_info"]["filter_attributes"]
49+
additional_data = task["task_info"]["additional_data"]
5050
embedding_item = embedding_db_bo.create(
5151
project_id,
5252
attribute_id,
@@ -58,6 +58,7 @@ def __start_task(task: Dict[str, Any]) -> bool:
5858
platform=platform,
5959
api_token=api_token,
6060
filter_attributes=filter_attributes,
61+
additional_data=additional_data,
6162
)
6263
if (
6364
platform == enums.EmbeddingPlatform.OPENAI.value
@@ -80,7 +81,6 @@ def __start_task(task: Dict[str, Any]) -> bool:
8081

8182

8283
def __check_finished(task: Dict[str, Any]) -> bool:
83-
8484
embedding_item = embedding_db_bo.get_embedding_by_name(
8585
task["project_id"], task["task_info"]["embedding_name"]
8686
)

controller/transfer/project_transfer_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,9 @@ def __transform_embedding_by_name(embedding_name: str):
397397
model=embedding_item.get(
398398
"model",
399399
),
400+
additional_data=embedding_item.get(
401+
"additional_data",
402+
),
400403
)
401404
embedding_ids[
402405
embedding_item.get(
@@ -1174,6 +1177,7 @@ def get_project_export_dump(
11741177
"finished_at": embedding_item.finished_at,
11751178
"platform": embedding_item.platform,
11761179
"model": embedding_item.model,
1180+
"additional_data": embedding_item.additional_data,
11771181
}
11781182
for embedding_item in embeddings
11791183
]

graphql_api/mutation/embedding.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,18 @@ def mutate(self, info, project_id: str, attribute_id: str, config: Dict[str, Any
3131
api_token = config.get("apiToken")
3232
terms_text = config.get("termsText")
3333
terms_accepted = config.get("termsAccepted")
34+
additional_data = None
35+
if config.get("base") is not None:
36+
additional_data = {
37+
"base": config.get("base"),
38+
"type": config.get("type"),
39+
"version": config.get("version"),
40+
}
3441

3542
# prototyping logic, this will be part of config after ui integration
3643
relevant_attribute_list = attribute_do.get_all_possible_names_for_qdrant(
3744
project_id
3845
)
39-
4046
task_queue_manager.add_task(
4147
project_id,
4248
TaskType.EMBEDDING,
@@ -53,6 +59,7 @@ def mutate(self, info, project_id: str, attribute_id: str, config: Dict[str, Any
5359
"terms_text": terms_text,
5460
"terms_accepted": terms_accepted,
5561
"filter_attributes": relevant_attribute_list,
62+
"additional_data": additional_data,
5663
},
5764
)
5865
notification.send_organization_update(

graphql_api/query/embedding.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44

55
from controller.misc import manager as misc
66
from controller.auth import manager as auth
7-
from graphql_api.types import EmbeddingPlatform, Encoder, LanguageModel, RecordTokenizationTask
7+
from graphql_api.types import (
8+
EmbeddingPlatform,
9+
Encoder,
10+
LanguageModel,
11+
RecordTokenizationTask,
12+
)
813
from submodules.model import enums
914
from submodules.model.business_objects import tokenization, task_queue
1015
from util import spacy_util
@@ -57,23 +62,28 @@ def resolve_project_tokenization(
5762
def resolve_embedding_platforms(self, info) -> List[EmbeddingPlatform]:
5863
return [
5964
{
60-
"platform": enums.EmbeddingPlatform.HUGGINGFACE.value,
61-
"terms": None,
62-
"link": None
65+
"platform": enums.EmbeddingPlatform.HUGGINGFACE.value,
66+
"terms": None,
67+
"link": None,
6368
},
6469
{
65-
"platform": enums.EmbeddingPlatform.COHERE.value,
66-
"terms": "Please note that by enabling this third-party API, you are stating that you accept its addition as a sub-processor under the terms of our Data Processing Agreement. Please be aware that the Cohere API policies may conflict with your internal data and privacy policies. For more information please check: @@PLACEHOLDER@@. For questions you can contact us at [email protected].",
67-
"link": "https://openai.com/policies/api-data-usage-policies"
68-
},
70+
"platform": enums.EmbeddingPlatform.COHERE.value,
71+
"terms": "Please note that by enabling this third-party API, you are stating that you accept its addition as a sub-processor under the terms of our Data Processing Agreement. Please be aware that the Cohere API policies may conflict with your internal data and privacy policies. For more information please check: @@PLACEHOLDER@@. For questions you can contact us at [email protected].",
72+
"link": "https://cohere.com/terms-of-use",
73+
},
74+
{
75+
"platform": enums.EmbeddingPlatform.OPENAI.value,
76+
"terms": "Please note that by enabling this third-party API, you are stating that you accept its addition as a sub-processor under the terms of our Data Processing Agreement. Please be aware that the OpenAI API policies may conflict with your internal data and privacy policies. For more information please check: @@PLACEHOLDER@@. For questions you can contact us at [email protected].",
77+
"link": "https://openai.com/policies/api-data-usage-policies",
78+
},
6979
{
70-
"platform": enums.EmbeddingPlatform.OPENAI.value,
71-
"terms": "Please note that by enabling this third-party API, you are stating that you accept its addition as a sub-processor under the terms of our Data Processing Agreement. Please be aware that the OpenAI API policies may conflict with your internal data and privacy policies. For more information please check: @@PLACEHOLDER@@. For questions you can contact us at [email protected].",
72-
"link": "https://openai.com/policies/api-data-usage-policies"
80+
"platform": enums.EmbeddingPlatform.PYTHON.value,
81+
"terms": None,
82+
"link": None,
7383
},
7484
{
75-
"platform": enums.EmbeddingPlatform.PYTHON.value,
76-
"terms": None,
77-
"link": None
85+
"platform": enums.EmbeddingPlatform.AZURE.value,
86+
"terms": "Please note that by enabling this third-party API, you are stating that you accept its addition as a sub-processor under the terms of our Data Processing Agreement. Please be aware that the Azure API policies may conflict with your internal data and privacy policies. For more information please check: @@PLACEHOLDER@@. For questions you can contact us at [email protected].",
87+
"link": "https://www.microsoft.com/en-us/legal/terms-of-use",
7888
},
79-
]
89+
]

submodules/model

0 commit comments

Comments
 (0)