Skip to content

Commit c40c7a0

Browse files
Pause and cancel for task master (#123)
* cancel embedding * model * model * model merge * model update
1 parent d8c1922 commit c40c7a0

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

controller.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def run_encoding(
288288
send_project_update(project_id, f"notification_created:{user_id}", True)
289289
embedding.delete_tensors(embedding_id, with_commit=True)
290290
chunk = 0
291+
embedding_canceled = False
291292
for pair in generate_batches(
292293
project_id,
293294
record_ids,
@@ -301,9 +302,19 @@ def run_encoding(
301302

302303
record_ids_batched = pair["record_ids"]
303304
attribute_values_encoded_batch = pair["embeddings"]
304-
if not embedding.get(project_id, embedding_id):
305+
embedding_entity = embedding.get(project_id, embedding_id)
306+
if not embedding_entity:
305307
logger.info(f"Aborted {embedding_name}")
306308
break
309+
elif embedding_entity.state == enums.EmbeddingState.FAILED.value:
310+
embedding_canceled = True
311+
send_project_update(
312+
project_id,
313+
f"embedding:{embedding_id}:state:{enums.EmbeddingState.FAILED.value}",
314+
)
315+
logger.info(f"Canceled {embedding_name}")
316+
break
317+
307318
embedding.create_tensors(
308319
project_id,
309320
embedding_id,
@@ -401,7 +412,7 @@ def run_encoding(
401412
doc_ock.post_embedding_failed(user_id, f"{model}-{platform}")
402413
return status.HTTP_500_INTERNAL_SERVER_ERROR
403414

404-
if embedding.get(project_id, embedding_id):
415+
if embedding.get(project_id, embedding_id) and not embedding_canceled:
405416
for warning_type, idx_list in embedder.get_warnings().items():
406417
# use last record with warning as example
407418
example_record_id = record_ids[idx_list[-1]]

0 commit comments

Comments
 (0)