Skip to content

Commit 758598c

Browse files
committed
Initial commit to speed up submission
1 parent 7bd98c8 commit 758598c

File tree

4 files changed

+88
-28
lines changed

4 files changed

+88
-28
lines changed

src/service/core/workflow/objects.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import collections
2020
import datetime
21+
import json
2122
import math
2223
from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Set
2324
import yaml
@@ -1040,6 +1041,7 @@ def send_submit_workflow_to_queue(self,
10401041

10411042
# Write workflow and group objects to the database
10421043
workflow_obj.insert_to_db()
1044+
task_entries: list[tuple] = []
10431045
for group_obj in workflow_obj.groups:
10441046
group_obj.workflow_id_internal = workflow_obj.workflow_id
10451047
group_obj.spec = \
@@ -1048,13 +1050,23 @@ def send_submit_workflow_to_queue(self,
10481050
group_obj.insert_to_db()
10491051
for task_obj, task_obj_spec in zip(group_obj.tasks, group_obj.spec.tasks):
10501052
task_obj.workflow_id_internal = workflow_obj.workflow_id
1051-
task_obj.insert_to_db(
1052-
gpu_count=task_obj_spec.resources.gpu or 0,
1053-
cpu_count=task_obj_spec.resources.cpu or 0,
1054-
disk_count=common.convert_resource_value_str(
1053+
workflow_uuid = task_obj.workflow_uuid if task_obj.workflow_uuid else ''
1054+
task_entries.append((
1055+
task_obj.workflow_id_internal, task_obj.name, task_obj.group_name,
1056+
task_obj.task_db_key, task_obj.retry_id, task_obj.task_uuid,
1057+
task.TaskGroupStatus.WAITING.name,
1058+
kb_objects.construct_pod_name(workflow_uuid, task_obj.task_uuid),
1059+
None,
1060+
task_obj_spec.resources.gpu or 0,
1061+
task_obj_spec.resources.cpu or 0,
1062+
common.convert_resource_value_str(
10551063
task_obj_spec.resources.storage or '0', 'GiB'),
1056-
memory_count=common.convert_resource_value_str(
1057-
task_obj_spec.resources.memory or '0', 'GiB'))
1064+
common.convert_resource_value_str(
1065+
task_obj_spec.resources.memory or '0', 'GiB'),
1066+
json.dumps(task_obj.exit_actions, default=common.pydantic_encoder),
1067+
task_obj.lead,
1068+
))
1069+
task.Task.batch_insert_to_db(postgres, task_entries)
10581070

10591071
logs = f'{service_url}/api/workflow/{workflow_obj.workflow_id}/logs'
10601072
context = WorkflowServiceContext.get()

src/utils/job/jobs.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def execute(self, context: JobExecutionContext,
222222

223223
self.workflow_id = workflow_obj.workflow_id
224224

225+
task_entries: list[tuple] = []
225226
for group_obj in workflow_obj.groups:
226227
group_obj.workflow_id_internal = workflow_obj.workflow_id
227228
group_obj.spec = \
@@ -230,19 +231,23 @@ def execute(self, context: JobExecutionContext,
230231
group_obj.insert_to_db()
231232
for task_obj, task_obj_spec in zip(group_obj.tasks, group_obj.spec.tasks):
232233
task_obj.workflow_id_internal = workflow_obj.workflow_id
233-
task_obj.insert_to_db(
234-
gpu_count=task_obj_spec.resources.gpu or 0,
235-
cpu_count=task_obj_spec.resources.cpu or 0,
236-
disk_count=common.convert_resource_value_str(
234+
workflow_uuid = task_obj.workflow_uuid if task_obj.workflow_uuid else ''
235+
task_entries.append((
236+
task_obj.workflow_id_internal, task_obj.name, task_obj.group_name,
237+
task_obj.task_db_key, task_obj.retry_id, task_obj.task_uuid,
238+
task.TaskGroupStatus.WAITING.name,
239+
kb_objects.construct_pod_name(workflow_uuid, task_obj.task_uuid),
240+
None,
241+
task_obj_spec.resources.gpu or 0,
242+
task_obj_spec.resources.cpu or 0,
243+
common.convert_resource_value_str(
237244
task_obj_spec.resources.storage or '0', 'GiB'),
238-
memory_count=common.convert_resource_value_str(
239-
task_obj_spec.resources.memory or '0', 'GiB'))
240-
241-
current_timestamp = datetime.datetime.now()
242-
time_elapsed = last_timestamp - current_timestamp
243-
if time_elapsed > progress_iter_freq:
244-
progress_writer.report_progress()
245-
last_timestamp = current_timestamp
245+
common.convert_resource_value_str(
246+
task_obj_spec.resources.memory or '0', 'GiB'),
247+
json.dumps(task_obj.exit_actions, default=common.pydantic_encoder),
248+
task_obj.lead,
249+
))
250+
task.Task.batch_insert_to_db(postgres, task_entries)
246251
progress_writer.report_progress()
247252

248253
# Fetch workflow_obj to get latest info

src/utils/job/task.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,42 @@ class Task(pydantic.BaseModel):
10411041
class Config:
10421042
arbitrary_types_allowed = True
10431043

1044+
@staticmethod
1045+
def batch_insert_to_db(
1046+
database: connectors.PostgresConnector,
1047+
task_entries: List[Tuple],
1048+
batch_size: int = 100,
1049+
):
1050+
"""Batch-insert multiple tasks in a single query.
1051+
1052+
Args:
1053+
database: The Postgres connector instance.
1054+
task_entries: List of tuples, each containing the full set of
1055+
column values for a single task row (same order as insert_to_db).
1056+
batch_size: Maximum number of rows per INSERT statement.
1057+
"""
1058+
if not task_entries:
1059+
return
1060+
1061+
for i in range(0, len(task_entries), batch_size):
1062+
chunk = task_entries[i:i + batch_size]
1063+
values_clause = ','.join(
1064+
['(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)']
1065+
* len(chunk)
1066+
)
1067+
flat_args: List[Any] = []
1068+
for entry in chunk:
1069+
flat_args.extend(entry)
1070+
1071+
insert_cmd = f'''
1072+
INSERT INTO tasks
1073+
(workflow_id, name, group_name, task_db_key, retry_id, task_uuid,
1074+
status, pod_name, failure_message, gpu_count, cpu_count,
1075+
disk_count, memory_count, exit_actions, lead)
1076+
VALUES {values_clause} ON CONFLICT DO NOTHING;
1077+
'''
1078+
database.execute_commit_command(insert_cmd, tuple(flat_args))
1079+
10441080
def insert_to_db(self, gpu_count: float, cpu_count: float, disk_count: float,
10451081
memory_count: float, status: TaskGroupStatus = TaskGroupStatus.WAITING,
10461082
failure_message: str | None = None):

src/utils/job/workflow.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,10 @@ def validate_credentials(self, user: str):
530530
workflow_config = database.get_workflow_configs()
531531
dataset_config = database.get_dataset_configs()
532532
image_hash_map: Dict[str, str] = {}
533+
default_user_bucket = connectors.UserProfile.fetch_from_db(database, user).bucket
534+
default_service_bucket = dataset_config.default_bucket
535+
user_creds = database.get_all_data_creds(user)
536+
generic_cred_cache: Dict[str, Any] = {}
533537
for group in self.groups:
534538
for group_task in group.tasks:
535539
response = self.validate_registry(
@@ -562,13 +566,18 @@ def validate_credentials(self, user: str):
562566
self.validate_data(
563567
user, dataset_config, group_task, seen_data_input,
564568
seen_data_output, workflow_config.credential_config.disable_data_validation,
565-
seen_bucket_input, seen_bucket_output)
566-
self.validate_generic_cred(user, database, group_task)
569+
seen_bucket_input, seen_bucket_output,
570+
default_user_bucket, default_service_bucket, user_creds)
571+
self.validate_generic_cred(user, database, group_task,
572+
generic_cred_cache)
567573

568574
def validate_generic_cred(self, user: str, database: connectors.PostgresConnector,
569-
group_task: task.TaskSpec):
575+
group_task: task.TaskSpec,
576+
generic_cred_cache: Dict[str, Any]):
570577
for cred_name, cred_map in group_task.credentials.items():
571-
payload = database.get_generic_cred(user, cred_name)
578+
if cred_name not in generic_cred_cache:
579+
generic_cred_cache[cred_name] = database.get_generic_cred(user, cred_name)
580+
payload = generic_cred_cache[cred_name]
572581
if isinstance(cred_map, str):
573582
continue
574583
elif isinstance(cred_map, Dict):
@@ -622,12 +631,10 @@ def validate_registry(self, user: str,
622631
def validate_data(self, user: str, dataset_config: connectors.DatasetConfig,
623632
group_task: task.TaskSpec, seen_uri_input: Set[str],
624633
seen_uri_output: Set[str], disabled_data: List[str],
625-
seen_bucket_input: Set[str], seen_bucket_output: Set[str]):
626-
627-
postgres = connectors.PostgresConnector.get_instance()
628-
default_user_bucket = connectors.UserProfile.fetch_from_db(postgres, user).bucket
629-
default_service_bucket = postgres.get_dataset_configs().default_bucket
630-
user_creds = postgres.get_all_data_creds(user)
634+
seen_bucket_input: Set[str], seen_bucket_output: Set[str],
635+
default_user_bucket: Optional[str],
636+
default_service_bucket: str,
637+
user_creds: Dict[str, Any]):
631638

632639
def _validate_input_output(data_spec: Union[task.InputType, task.OutputType, task.TaskKPI],
633640
is_input: bool):

0 commit comments

Comments (0)