Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libs/labelbox/src/labelbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,4 @@
PromptResponseClassification,
)
from lbox.exceptions import *
from labelbox.schema.taskstatus import TaskStatus
48 changes: 46 additions & 2 deletions libs/labelbox/src/labelbox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from labelbox.schema.slice import CatalogSlice, ModelSlice
from labelbox.schema.task import DataUpsertTask, Task
from labelbox.schema.user import User
from labelbox.schema.taskstatus import TaskStatus

logger = logging.getLogger(__name__)

Expand All @@ -90,6 +91,9 @@ class Client:
top-level data objects (Projects, Datasets).
"""

# Class variable to cache task types
_cancelable_task_types = None

def __init__(
self,
api_key=None,
Expand Down Expand Up @@ -2390,9 +2394,31 @@ def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]:
task._user = user
return task

def _get_cancelable_task_types(self):
"""Internal method that returns a list of task types that can be canceled.

The result is cached after the first call to avoid unnecessary API requests.

Returns:
List[str]: List of cancelable task types in snake_case format
"""
if self._cancelable_task_types is None:
query = """query GetCancelableTaskTypes {
cancelableTaskTypes
}"""

result = self.execute(query).get("cancelableTaskTypes", [])
# Reformat to kebab case
self._cancelable_task_types = [
utils.snake_case(task_type).replace("_", "-")
for task_type in result
]

return self._cancelable_task_types

def cancel_task(self, task_id: str) -> bool:
"""
Cancels a task with the given ID.
Cancels a task with the given ID if the task type is cancelable and the task is in progress.

Args:
task_id (str): The ID of the task to cancel.
Expand All @@ -2401,8 +2427,26 @@ def cancel_task(self, task_id: str) -> bool:
bool: True if the task was successfully cancelled.

Raises:
LabelboxError: If the task could not be cancelled.
LabelboxError: If the task could not be cancelled, if the task type is not cancelable,
or if the task is not in progress.
ResourceNotFoundError: If the task does not exist (raised by get_task_by_id).
"""
# Get the task object to check its type and status
task = self.get_task_by_id(task_id)

# Check if task type is cancelable
cancelable_types = self._get_cancelable_task_types()
if task.type not in cancelable_types:
raise LabelboxError(
f"Task type '{task.type}' cannot be cancelled. Cancelable types are: {cancelable_types}"
)

# Check if task is in progress
if task.status_as_enum != TaskStatus.In_Progress:
raise LabelboxError(
f"Task cannot be cancelled because it is not in progress. Current status: {task.status}"
)

mutation_str = """
mutation CancelTaskPyApi($id: ID!) {
cancelBulkOperationJob(id: $id) {
Expand Down
1 change: 1 addition & 0 deletions libs/labelbox/src/labelbox/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
import labelbox.schema.catalog
import labelbox.schema.ontology_kind
import labelbox.schema.project_overview
import labelbox.schema.taskstatus
1 change: 1 addition & 0 deletions libs/labelbox/src/labelbox/schema/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self, *args, **kwargs):
projects = Relationship.ToMany("Project", True)
webhooks = Relationship.ToMany("Webhook", False)
resource_tags = Relationship.ToMany("ResourceTags", False)
tasks = Relationship.ToMany("Task", False, "tasks")

def invite_user(
self,
Expand Down
4 changes: 4 additions & 0 deletions libs/labelbox/src/labelbox/schema/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from labelbox.schema.internal.datarow_upload_constants import (
DOWNLOAD_RESULT_PAGE_SIZE,
)
from labelbox.schema.taskstatus import TaskStatus

if TYPE_CHECKING:
from labelbox import User
Expand Down Expand Up @@ -45,6 +46,9 @@ class Task(DbObject):
created_at = Field.DateTime("created_at")
name = Field.String("name")
status = Field.String("status")
status_as_enum = Field.Enum(
TaskStatus, "status_as_enum", "status"
) # additional status for filtering
completion_percentage = Field.Float("completion_percentage")
result_url = Field.String("result_url", "result")
errors_url = Field.String("errors_url", "errors")
Expand Down
25 changes: 25 additions & 0 deletions libs/labelbox/src/labelbox/schema/taskstatus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from enum import Enum


class TaskStatus(str, Enum):
In_Progress = "IN_PROGRESS"
Complete = "COMPLETE"
Canceling = "CANCELLING"
Canceled = "CANCELED"
Failed = "FAILED"
Unknown = "UNKNOWN"

@classmethod
def _missing_(cls, value):
"""Handle missing or unknown task status values.
If a task status value is not found in the enum, this method returns
the Unknown status instead of raising an error.
Args:
value: The status value that doesn't match any enum member
Returns:
TaskStatus.Unknown: The default status for unrecognized values
"""
return cls.Unknown
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time

from labelbox import DataRow, ExportTask, StreamType
from labelbox import DataRow, ExportTask, StreamType, Task, TaskStatus


class TestExportDataRow:
Expand Down Expand Up @@ -135,3 +135,33 @@ def test_cancel_export_task(
# Verify the task was cancelled
cancelled_task = client.get_task_by_id(export_task.uid)
assert cancelled_task.status in ["CANCELING", "CANCELED"]

def test_task_filter(self, client, data_row, wait_for_data_row_processing):
organization = client.get_organization()
user = client.get_user()

export_task = DataRow.export(
client=client,
data_rows=[data_row],
task_name="TestExportDataRow:test_task_filter",
)

# Check if task is listed "in progress" in organization's tasks
org_tasks_in_progress = organization.tasks(
where=Task.status_as_enum == TaskStatus.In_Progress
)
retrieved_task_in_progress = next(
(t for t in org_tasks_in_progress if t.uid == export_task.uid), ""
)
assert getattr(retrieved_task_in_progress, "uid", "") == export_task.uid

export_task.wait_till_done()

# Check if task is listed "complete" in user's created tasks
user_tasks_complete = user.created_tasks(
where=Task.status_as_enum == TaskStatus.Complete
)
retrieved_task_complete = next(
(t for t in user_tasks_complete if t.uid == export_task.uid), ""
)
assert getattr(retrieved_task_complete, "uid", "") == export_task.uid
Loading