diff --git a/tests/test_web/test_webapi.py b/tests/test_web/test_webapi.py
index 2f63da9d20..db25d63964 100644
--- a/tests/test_web/test_webapi.py
+++ b/tests/test_web/test_webapi.py
@@ -1,6 +1,8 @@
 # Tests webapi and things that depend on it
 from __future__ import annotations
 
+import json
+
 import numpy as np
 import pytest
 import responses
@@ -296,6 +298,163 @@ def mock_get_run_info(monkeypatch, set_api_key):
     )
 
 
+@pytest.fixture
+def mock_batch_upload_single(monkeypatch, set_api_key):
+    """Mocks the batch upload endpoint for a single task."""
+    # mock folder retrieval
+    responses.add(
+        responses.GET,
+        f"{Env.current.web_api_endpoint}/tidy3d/project",
+        match=[matchers.query_param_matcher({"projectName": PROJECT_NAME})],
+        json={"data": {"projectId": FOLDER_ID, "projectName": PROJECT_NAME}},
+        status=200,
+    )
+
+    # mock batch endpoint - returns a single task
+    def batch_request_matcher(request):
+        json_data = json.loads(request.body)
+        assert "tasks" in json_data
+        assert "batchType" in json_data
+        assert "groupName" in json_data
+        for task in json_data["tasks"]:
+            assert "groupName" in task
+            assert task["groupName"] == json_data["groupName"]
+        return True, None
+
+    responses.add(
+        responses.POST,
+        f"{Env.current.web_api_endpoint}/tidy3d/projects/{FOLDER_ID}/batch-tasks",
+        match=[batch_request_matcher],
+        json={"batchId": "batch_123", "tasks": [{"taskId": "task_id_0", "taskName": "task_0"}]},
+        status=200,
+    )
+
+    # mock task detail endpoints for the single task
+    responses.add(
+        responses.GET,
+        f"{Env.current.web_api_endpoint}/tidy3d/tasks/task_id_0",
+        json={
+            "data": {
+                "taskId": "task_id_0",
+                "taskName": "task_0",
+                "createdAt": CREATED_AT,
+                "fileType": "Gz",
+                "resourcePath": "output/task_id_0.json",
+                "solverVersion": None,
+                "taskType": TaskType.FDTD.name,
+            }
+        },
+        status=200,
+    )
+
+    responses.add(
+        responses.GET,
+        f"{Env.current.web_api_endpoint}/tidy3d/tasks/task_id_0/detail",
+        json={
+            "data": {
+                "taskId": "task_id_0",
+                "taskName": "task_0",
+                "createdAt": CREATED_AT,
+                "realFlexUnit": FLEX_UNIT,
+                "estFlexUnit": EST_FLEX_UNIT,
+                "taskType": TaskType.FDTD.name,
+                "metadataStatus": "processed",
+                "status": "draft",
+                "s3Storage": 1.0,
+            }
+        },
+        status=200,
+    )
+
+    def mock_upload_file(*args, **kwargs):
+        pass
+
+    monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file)
+
+
+@pytest.fixture
+def mock_batch_upload_triple(monkeypatch, set_api_key):
+    """Mocks the batch upload endpoint for three tasks."""
+    # mock folder retrieval
+    responses.add(
+        responses.GET,
+        f"{Env.current.web_api_endpoint}/tidy3d/project",
+        match=[matchers.query_param_matcher({"projectName": PROJECT_NAME})],
+        json={"data": {"projectId": FOLDER_ID, "projectName": PROJECT_NAME}},
+        status=200,
+    )
+
+    def batch_request_matcher(request):
+        json_data = json.loads(request.body)
+        assert "tasks" in json_data
+        assert "batchType" in json_data
+        assert "groupName" in json_data
+        for task in json_data["tasks"]:
+            assert "groupName" in task
+            assert task["groupName"] == json_data["groupName"]
+        return True, None
+
+    responses.add(
+        responses.POST,
+        f"{Env.current.web_api_endpoint}/tidy3d/projects/{FOLDER_ID}/batch-tasks",
+        match=[batch_request_matcher],
+        json={
+            "batchId": "batch_123",
+            "tasks": [
+                {"taskId": "task_id_0", "taskName": "task_0"},
+                {"taskId": "task_id_1", "taskName": "task_1"},
+                {"taskId": "task_id_2", "taskName": "task_2"},
+            ],
+        },
+        status=200,
+    )
+
+    for i in range(3):
+        task_name = f"task_{i}"
+        task_id = f"task_id_{i}"
+
+        responses.add(
+            responses.GET,
+            f"{Env.current.web_api_endpoint}/tidy3d/tasks/{task_id}",
+            json={
+                "data": {
+                    "taskId": task_id,
+                    "taskName": task_name,
+                    "createdAt": CREATED_AT,
+                    "fileType": "Gz",
+                    "resourcePath": f"output/{task_id}.json",
+                    "solverVersion": None,
+                    "taskType": TaskType.FDTD.name,
+                }
+            },
+            status=200,
+        )
+
+        responses.add(
+            responses.GET,
+            f"{Env.current.web_api_endpoint}/tidy3d/tasks/{task_id}/detail",
+            json={
+                "data": {
+                    "taskId": task_id,
+                    "taskName": task_name,
+                    "createdAt": CREATED_AT,
+                    "realFlexUnit": FLEX_UNIT,
+                    "estFlexUnit": EST_FLEX_UNIT,
+                    "taskType": TaskType.FDTD.name,
+                    "metadataStatus": "processed",
+                    "status": "draft",
+                    "s3Storage": 1.0,
+                }
+            },
+            status=200,
+        )
+
+    def mock_upload_file(*args, **kwargs):
+        pass
+
+    monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file)
+
+
 @pytest.fixture
 def mock_webapi(
     mock_upload, mock_metadata, mock_get_info, mock_start, mock_monitor, mock_download, mock_load
@@ -628,6 +789,86 @@ def test_batch(mock_webapi, mock_job_status, mock_load, tmp_path):
     assert b2.real_cost() == FLEX_UNIT * len(sims)
 
 
+@responses.activate
+def test_batch_with_endpoint(mock_batch_upload_triple, tmp_path):
+    """Test batch submission through the new batch endpoint."""
+
+    sims = {f"task_{i}": make_sim() for i in range(3)}
+
+    batch = Batch(simulations=sims, folder_name=PROJECT_NAME, use_batch_endpoint=True)
+
+    assert batch.use_batch_endpoint is True
+
+    # access the jobs property to trigger batch submission
+    jobs = batch.jobs
+
+    # verify jobs were created with pre-assigned task_ids
+    assert len(jobs) == 3
+    for i, (task_name, job) in enumerate(jobs.items()):
+        assert task_name == f"task_{i}"
+        assert job.task_id == f"task_id_{i}"
+        assert job._cached_properties.get("task_id") == f"task_id_{i}"
+
+    # test that serialization preserves the flag
+    fname = str(tmp_path / "batch_endpoint.json")
+    batch.to_file(fname)
+    batch_loaded = Batch.from_file(fname)
+
+    assert batch_loaded.use_batch_endpoint is True
+    assert len(batch_loaded.jobs) == 3
+
+
+@responses.activate
+def test_batch_backward_compatibility(mock_webapi, mock_job_status, mock_load, tmp_path):
+    """Test that the default behavior remains unchanged (backward compatibility)."""
+    sims = {TASK_NAME: make_sim()}
+
+    # create a batch without specifying use_batch_endpoint (should default to False)
+    batch = Batch(simulations=sims, folder_name=PROJECT_NAME)
+
+    # verify the default is False
+    assert batch.use_batch_endpoint is False
+
+    # access jobs to trigger the normal flow
+    jobs = batch.jobs
+    assert len(jobs) == 1
+
+    # run and verify it works as before
+    batch.run(path_dir=str(tmp_path))
+    assert batch.real_cost() == FLEX_UNIT * len(sims)
+
+
+@responses.activate
+def test_batch_endpoint_integration(
+    mock_batch_upload_single, mock_webapi, mock_job_status, tmp_path
+):
+    """Test that both batch submission modes produce compatible results."""
+    sim = make_sim()
+
+    # old way
+    batch_old = Batch(
+        simulations={"task_0": sim}, folder_name=PROJECT_NAME, use_batch_endpoint=False
+    )
+    fname_old = str(tmp_path / "batch_old.json")
+    batch_old.to_file(fname_old)
+
+    # new way
+    batch_new = Batch(
+        simulations={"task_0": sim}, folder_name=PROJECT_NAME, use_batch_endpoint=True
+    )
+    fname_new = str(tmp_path / "batch_new.json")
+    batch_new.to_file(fname_new)
+
+    # load and verify both work
+    batch_old_loaded = Batch.from_file(fname_old)
+    batch_new_loaded = Batch.from_file(fname_new)
+
+    assert batch_old_loaded.use_batch_endpoint is False
+    assert batch_new_loaded.use_batch_endpoint is True
+    assert len(batch_old_loaded.jobs) == 1
+    assert len(batch_new_loaded.jobs) == 1
+
+
 @responses.activate
 def test_create_output_dirs(mock_webapi, tmp_path, monkeypatch):
     """Test that Job and Batch create output directories if they don't exist."""
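The tests above exercise the following user-facing flow. A minimal sketch of it under stated assumptions: `make_sim()` stands in for any helper returning a simulation object (mirroring the test module), and the file name is illustrative.

```python
from tidy3d.web import Batch

# three simulations submitted through the new single-call batch endpoint
sims = {f"task_{i}": make_sim() for i in range(3)}
batch = Batch(simulations=sims, folder_name="default", use_batch_endpoint=True)

# accessing `jobs` triggers the one-shot batch submission; each Job comes back
# with a task_id pre-assigned from the server response
for task_name, job in batch.jobs.items():
    print(task_name, job.task_id)

# the flag round-trips through serialization
batch.to_file("batch_endpoint.json")
assert Batch.from_file("batch_endpoint.json").use_batch_endpoint is True
```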
diff --git a/tidy3d/web/api/container.py b/tidy3d/web/api/container.py
index fb5b1915bb..0db19077cc 100644
--- a/tidy3d/web/api/container.py
+++ b/tidy3d/web/api/container.py
@@ -22,7 +22,7 @@
 from tidy3d.web.core.constants import TaskId, TaskName
 from tidy3d.web.core.task_core import Folder
 from tidy3d.web.core.task_info import RunInfo, TaskInfo
-from tidy3d.web.core.types import PayType
+from tidy3d.web.core.types import BatchType, PayType
 
 from .tidy3d_stub import SimulationDataType, SimulationType
 
@@ -560,6 +560,19 @@ class Batch(WebContainer):
         description="Specify the payment method.",
     )
 
+    use_batch_endpoint: bool = pd.Field(
+        False,
+        title="Use Batch Endpoint",
+        description="Use the new batch submission endpoint for improved performance. "
+        "When ``True``, submits all simulations in a single API call.",
+    )
+
+    batch_type: BatchType = pd.Field(
+        BatchType.DEFAULT,
+        title="Batch Type",
+        description="Internal batch type used for server-side optimization.",
+    )
+
     jobs_cached: dict[TaskName, Job] = pd.Field(
         None,
         title="Jobs (Cached)",
@@ -617,6 +630,16 @@ def jobs(self) -> dict[TaskName, Job]:
         if self.jobs_cached is not None:
             return self.jobs_cached
 
+        if self.use_batch_endpoint:
+            jobs = self._create_jobs_from_batch()
+            # store in the cache to avoid re-submitting the batch
+            object.__setattr__(self, "jobs_cached", jobs)
+            return jobs
+
+        return self._create_individual_jobs()
+
+    def _create_individual_jobs(self) -> dict[TaskName, Job]:
+        """Create jobs one at a time using the existing approach."""
         # the type of job to upload (to generalize to subclasses)
         JobType = self._job_type
         self_dict = self.dict()
@@ -641,6 +664,53 @@ def jobs(self) -> dict[TaskName, Job]:
             jobs[task_name] = job
         return jobs
 
+    def _submit_batch(self) -> tuple[str, dict[str, str]]:
+        """Submit all simulations through the batch endpoint."""
+        return web.upload_batch(
+            simulations=self.simulations,
+            folder_name=self.folder_name,
+            callback_url=self.callback_url,
+            verbose=self.verbose,
+            simulation_type=self.simulation_type,
+            parent_tasks=self.parent_tasks,
+            source_required=True,
+            solver_version=self.solver_version,
+            reduce_simulation=self.reduce_simulation,
+            batch_type=self.batch_type.value,
+        )
+
+    def _create_jobs_from_batch(self) -> dict[TaskName, Job]:
+        """Create jobs from the batch submission response."""
+        batch_id, task_ids = self._submit_batch()
+
+        # create Job objects with pre-assigned task_ids
+        JobType = self._job_type
+        self_dict = self.dict()
+
+        jobs = {}
+        for task_name, task_id in task_ids.items():
+            job_kwargs = {}
+
+            for key in JobType._upload_fields:
+                if key in self_dict:
+                    job_kwargs[key] = self_dict.get(key)
+
+            job_kwargs["task_name"] = task_name
+            job_kwargs["simulation"] = self.simulations[task_name]
+            job_kwargs["verbose"] = False
+            job_kwargs["solver_version"] = self.solver_version
+            job_kwargs["pay_type"] = self.pay_type
+            job_kwargs["reduce_simulation"] = self.reduce_simulation
+            job_kwargs["task_id_cached"] = task_id
+
+            if self.parent_tasks and task_name in self.parent_tasks:
+                job_kwargs["parent_tasks"] = self.parent_tasks[task_name]
+
+            job = JobType(**job_kwargs)
+            jobs[task_name] = job
+
+        return jobs
+
     def to_file(self, fname: str) -> None:
         """Exports :class:`Tidy3dBaseModel` instance to .yaml, .json, or .hdf5 file
 
@@ -709,12 +779,23 @@ def start(self) -> None:
 
         Note
         ----
+        The current implementation starts each job individually using a ThreadPoolExecutor.
+        This could be enhanced with a dedicated batch start endpoint when using the
+        batch endpoint flow (``use_batch_endpoint=True``).
+
+        Future enhancement for batch endpoint mode:
+        - Could route through ``BatchTask.submit()``, which becomes atomic once a batch submit endpoint exists
+        - Benefits: a single API call, server-side coordination, better error handling
+        - Would eliminate the need for client-side thread pool management
+
         To monitor the running simulations, can call :meth:`Batch.monitor`.
         """
         if self.verbose:
             console = get_logging_console()
             console.log(f"Started working on Batch containing {self.num_jobs} tasks.")
 
+        # TODO: for batch endpoint mode (use_batch_endpoint=True), consider using
+        # BatchTask.submit() once a batch submit endpoint replaces its per-task loop
         with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
             for _, job in self.jobs.items():
                 executor.submit(job.start)
@@ -734,7 +815,19 @@ def get_run_info(self) -> dict[TaskName, RunInfo]:
         return run_info_dict
 
     def monitor(self) -> None:
-        """Monitor progress of each of the running tasks."""
+        """Monitor progress of each of the running tasks.
+
+        Note
+        ----
+        The current implementation monitors each task individually. For batch endpoint
+        mode (``use_batch_endpoint=True``), this could be enhanced with batch monitoring
+        endpoints for better performance.
+
+        Future enhancement:
+        - Batch status endpoint: ``GET /tidy3d/projects/{folder_id}/batch/{batch_id}/status``
+        - Benefits: a single API call for all task statuses, reduced polling overhead
+        - Could provide aggregate progress information and batch-level status
+        """
 
         def pbar_description(
             task_name: str, status: str, max_name_length: int, status_width: int
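One detail in the `jobs` property above deserves a remark: `Batch` is an immutable pydantic model, so the batch path caches its result with `object.__setattr__` instead of plain attribute assignment. A standalone sketch of why that bypass is needed; the `Frozen` model here is hypothetical, and tidy3d's base model behaves analogously through the pydantic v1 API.

```python
from typing import Optional

import pydantic.v1 as pd  # pydantic v2's v1 compatibility layer, as used by tidy3d


class Frozen(pd.BaseModel):
    cache: Optional[dict] = None

    class Config:
        allow_mutation = False


model = Frozen()
try:
    model.cache = {"task_0": "task_id_0"}  # rejected by pydantic
except TypeError as err:
    print(err)  # "Frozen" is immutable and does not support item assignment

# object.__setattr__ bypasses pydantic's guarded __setattr__; this is how
# Batch.jobs stores jobs_cached right after a batch submission
object.__setattr__(model, "cache", {"task_0": "task_id_0"})
print(model.cache)
```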
diff --git a/tidy3d/web/api/webapi.py b/tidy3d/web/api/webapi.py
index c45012499b..d1dad464c5 100644
--- a/tidy3d/web/api/webapi.py
+++ b/tidy3d/web/api/webapi.py
@@ -26,9 +26,9 @@
     TaskId,
 )
 from tidy3d.web.core.environment import Env
-from tidy3d.web.core.task_core import Folder, SimulationTask
+from tidy3d.web.core.task_core import BatchTask, Folder, Task
 from tidy3d.web.core.task_info import ChargeType, TaskInfo
-from tidy3d.web.core.types import PayType
+from tidy3d.web.core.types import BatchType, PayType
 
 from .connect_util import REFRESH_TIME, get_grid_points_str, get_time_steps_str, wait_for_connection
 from .tidy3d_stub import SimulationDataType, SimulationType, Tidy3dStub, Tidy3dStubData
@@ -263,7 +263,7 @@ def upload(
 
     task_type = stub.get_type()
 
-    task = SimulationTask.create(
+    task = Task.create(
         task_type, task_name, folder_name, callback_url, simulation_type, parent_tasks, "Gz"
     )
     if verbose:
@@ -302,6 +302,134 @@ def upload(
     return task.task_id
 
 
+@wait_for_connection
+def upload_batch(
+    simulations: dict[str, SimulationType],
+    folder_name: str = "default",
+    callback_url: Optional[str] = None,
+    verbose: bool = True,
+    simulation_type: str = "tidy3d",
+    parent_tasks: Optional[dict[str, list[str]]] = None,
+    source_required: bool = True,
+    solver_version: Optional[str] = None,
+    reduce_simulation: Literal["auto", True, False] = "auto",
+    batch_type: str = BatchType.DEFAULT.value,
+) -> tuple[str, dict[str, str]]:
+    """
+    Upload multiple simulations to the server in a single batch, without starting them.
+
+    Parameters
+    ----------
+    simulations : dict[str, Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`]]
+        Mapping of task names to simulation objects.
+    folder_name : str
+        Name of folder to store tasks on web UI.
+    callback_url : str = None
+        Http PUT url to receive simulation finish event.
+    verbose : bool = True
+        If ``True``, will print progressbars and status, otherwise, will run silently.
+    simulation_type : str = "tidy3d"
+        Type of simulation being uploaded.
+    parent_tasks : dict[str, list[str]] = None
+        Mapping of task names to lists of parent task ids.
+    source_required : bool = True
+        If ``True``, simulations without sources will raise an error before being uploaded.
+    solver_version : str = None
+        Target solver version.
+    reduce_simulation : Literal["auto", True, False] = "auto"
+        Whether to reduce structures in the simulation to the simulation domain only.
+    batch_type : str = BatchType.DEFAULT.value
+        Internal batch type used for server-side optimization.
+
+    Returns
+    -------
+    tuple[str, dict[str, str]]
+        Batch ID and mapping of task names to task IDs.
+    """
+    if verbose:
+        console = get_logging_console()
+        console.log(f"Uploading batch of {len(simulations)} simulations to '{folder_name}'...")
+
+    # validate all simulations first
+    for task_name, simulation in simulations.items():
+        if isinstance(simulation, (ModeSolver, ModeSimulation)):
+            simulation = get_reduced_simulation(simulation, reduce_simulation)
+            simulations[task_name] = simulation
+
+        stub = Tidy3dStub(simulation=simulation)
+        stub.validate_pre_upload(source_required=source_required)
+
+    # create batch on server
+    batch_task = BatchTask.create(
+        simulations=simulations,
+        folder_name=folder_name,
+        callback_url=callback_url,
+        simulation_type=simulation_type,
+        parent_tasks=parent_tasks,
+        batch_type=batch_type,
+    )
+
+    if verbose:
+        console.log(f"Created batch '{batch_task.batch_id}' with {len(batch_task.task_ids)} tasks.")
+
+    # upload simulation files for each task
+    for task_name, task_id in batch_task.task_ids.items():
+        simulation = simulations[task_name]
+        stub = Tidy3dStub(simulation=simulation)
+        task_type = stub.get_type()
+
+        remote_sim_file = SIM_FILE_HDF5_GZ
+        if task_type == "MODE_SOLVER":
+            remote_sim_file = MODE_FILE_HDF5_GZ
+
+        # get the task object to upload the simulation
+        task = Task.get(task_id, verbose=False)
+        task.upload_simulation(
+            stub=stub,
+            verbose=False,
+            progress_callback=None,
+            remote_sim_file=remote_sim_file,
+        )
+
+        if solver_version is not None:
+            estimate_cost(task_id=task_id, solver_version=solver_version, verbose=False)
+
+    if verbose:
+        console.log(f"Batch upload complete. Batch ID: '{batch_task.batch_id}'")
+
+    return batch_task.batch_id, batch_task.task_ids
+
+
+# TODO: Future batch webapi functions
+# ====================================
+# The following batch operations are currently handled individually but could benefit
+# from dedicated batch endpoints for improved performance and atomicity:
+#
+# def start_batch(batch_id: str, folder_name: str = "default", **kwargs) -> None:
+#     """Start all tasks in a batch simultaneously."""
+#     # Would use: POST /tidy3d/projects/{folder_id}/batch/{batch_id}/submit
+#     # Benefits: atomic batch submission, server-side coordination
+#
+# def delete_batch(batch_id: str, folder_name: str = "default") -> None:
+#     """Delete an entire batch and all associated tasks."""
+#     # Would use: DELETE /tidy3d/projects/{folder_id}/batch/{batch_id}
+#     # Benefits: atomic batch deletion, server-side cleanup
+#
+# def get_batch_info(batch_id: str, folder_name: str = "default") -> BatchInfo:
+#     """Get batch status and metadata."""
+#     # Would use: GET /tidy3d/projects/{folder_id}/batch/{batch_id}
+#     # Benefits: single API call for batch overview, aggregate status
+#
+# def monitor_batch(batch_id: str, folder_name: str = "default") -> dict[str, str]:
+#     """Monitor progress of all tasks in a batch."""
+#     # Would use: GET /tidy3d/projects/{folder_id}/batch/{batch_id}/status
+#     # Benefits: efficient batch progress tracking, reduced API calls
+#
+# These functions would complement the existing upload_batch() function to provide
+# a complete batch operations API that minimizes HTTP overhead and enables
+# server-side optimizations.
+
+
 def get_reduced_simulation(simulation, reduce_simulation):
     """
     Adjust the given simulation object based on the reduce_simulation parameter. Currently only
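Until those endpoints exist, batch-level monitoring can only be composed client-side from the per-task calls that already exist. A minimal polling sketch built on the `get_status` defined in this module; the helper name, terminal-status set, and interval are illustrative assumptions, not part of this patch.

```python
import time

from tidy3d.web.api.webapi import get_status


def poll_batch_statuses(task_ids: dict[str, str], interval: float = 5.0) -> dict[str, str]:
    """Poll every task in a batch until all reach a terminal status."""
    terminal = {"success", "error", "diverged"}  # assumed terminal states
    while True:
        # one GET per task per cycle -- exactly what a batch status
        # endpoint would collapse into a single call
        statuses = {name: get_status(tid) for name, tid in task_ids.items()}
        if all(status in terminal for status in statuses.values()):
            return statuses
        time.sleep(interval)
```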
""" - task = SimulationTask(taskId=task_id) + task = Task(taskId=task_id) return task.get_running_info() @@ -448,7 +576,7 @@ def get_status(task_id) -> str: if status == "error": try: # Try to obtain the error message - task = SimulationTask(taskId=task_id) + task = Task(taskId=task_id) with tempfile.NamedTemporaryFile(suffix=".json") as tmp_file: task.get_error_json(to_file=tmp_file.name) with open(tmp_file.name) as f: @@ -660,7 +788,7 @@ def download( if task_type == "MODE_SOLVER": remote_data_file = MODE_DATA_HDF5_GZ - task = SimulationTask(taskId=task_id) + task = Task(taskId=task_id) task.get_sim_data_hdf5( path, verbose=verbose, @@ -684,7 +812,7 @@ def download_json(task_id: TaskId, path: str = SIM_FILE_JSON, verbose: bool = Tr """ - task = SimulationTask(taskId=task_id) + task = Task(taskId=task_id) task.get_simulation_json(path, verbose=verbose) @@ -716,7 +844,7 @@ def download_hdf5( if task_type == "MODE_SOLVER": remote_sim_file = MODE_FILE_HDF5_GZ - task = SimulationTask(taskId=task_id) + task = Task(taskId=task_id) task.get_simulation_hdf5( path, verbose=verbose, progress_callback=progress_callback, remote_sim_file=remote_sim_file ) @@ -743,7 +871,7 @@ def load_simulation( Simulation loaded from downloaded json file. """ - task = SimulationTask.get(task_id) + task = Task.get(task_id) task.get_simulation_json(path, verbose=verbose) return Tidy3dStub.from_file(path) @@ -772,7 +900,7 @@ def download_log( ---- To load downloaded results into data, call :meth:`load` with option ``replace_existing=False``. """ - task = SimulationTask(taskId=task_id) + task = Task(taskId=task_id) task.get_log(path, verbose=verbose, progress_callback=progress_callback) @@ -845,7 +973,7 @@ def delete(task_id: TaskId, versions: bool = False) -> TaskInfo: TaskInfo Object containing information about status, size, credits of task. """ - task = SimulationTask(taskId=task_id) + task = Task(taskId=task_id) task.delete(versions=versions) return TaskInfo(**{"taskId": task.task_id, **task.dict()}) @@ -891,7 +1019,7 @@ def abort(task_id: TaskId) -> TaskInfo: Object containing information about status, size, credits of task. 
""" - task = SimulationTask.get(task_id) + task = Task.get(task_id) if not task: raise ValueError("Task not found.") task.abort() @@ -985,7 +1113,7 @@ def estimate_cost( print(f'The estimated maximum cost is {estimated_cost:.3f} Flex Credits.') """ - task = SimulationTask.get(task_id) + task = Task.get(task_id) if not task: raise ValueError("Task not found.") diff --git a/tidy3d/web/core/task_core.py b/tidy3d/web/core/task_core.py index 376aef301b..2196a10880 100644 --- a/tidy3d/web/core/task_core.py +++ b/tidy3d/web/core/task_core.py @@ -5,6 +5,8 @@ import os import pathlib import tempfile +import uuid +from abc import ABC from datetime import datetime from typing import Callable, Optional, Union @@ -25,7 +27,14 @@ from .http_util import http from .s3utils import download_file, download_gz_file, upload_file from .stub import TaskStub -from .types import PayType, Queryable, ResourceLifecycle, Submittable, Tidy3DResource +from .types import ( + BatchType, + PayType, + Queryable, + ResourceLifecycle, + Submittable, + Tidy3DResource, +) class Folder(Tidy3DResource, Queryable, extra=Extra.allow): @@ -115,13 +124,13 @@ def list_tasks(self) -> list[Tidy3DResource]: Returns ------- - tasks : List[:class:`.SimulationTask`] + tasks : List[:class:`.Task`] List of tasks in this folder """ resp = http.get(f"tidy3d/projects/{self.folder_id}/tasks") return ( parse_obj_as( - list[SimulationTask], + list[Task], resp, ) if resp @@ -129,34 +138,15 @@ def list_tasks(self) -> list[Tidy3DResource]: ) -class SimulationTask(ResourceLifecycle, Submittable, extra=Extra.allow): - """Interface for managing the running of a :class:`.Simulation` task on server.""" +class BaseTask(ResourceLifecycle, ABC): + """Base class for all task types with shared server communication functionality.""" - task_id: Optional[str] = Field( - ..., - title="task_id", - description="Task ID number, set when the task is uploaded, leave as None.", - alias="taskId", - ) folder_id: Optional[str] = Field( None, title="folder_id", description="Folder ID number, set when the task is uploaded, leave as None.", alias="folderId", ) - status: Optional[str] = Field(title="status", description="Simulation task status.") - - real_flex_unit: float = Field( - None, title="real FlexCredits", description="Billed FlexCredits.", alias="realCost" - ) - - created_at: Optional[datetime] = Field( - title="created_at", description="Time at which this task was created.", alias="createdAt" - ) - - task_type: Optional[str] = Field( - title="task_type", description="The type of task.", alias="taskType" - ) folder_name: Optional[str] = Field( "default", @@ -173,6 +163,229 @@ class SimulationTask(ResourceLifecycle, Submittable, extra=Extra.allow): "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", ) + @classmethod + def _get_folder(cls, folder_name: str) -> Folder: + """Get or create folder by name.""" + return Folder.get(folder_name, create=True) + + @classmethod + def _validate_simulations( + cls, + simulations: dict[str, td.components.base_sim.simulation.AbstractSimulation], + source_required: bool = True, + ) -> None: + """Validate simulations before upload.""" + for simulation in simulations.values(): + from tidy3d.web.api.tidy3d_stub import Tidy3dStub + + stub = Tidy3dStub(simulation=simulation) + stub.validate_pre_upload(source_required=source_required) + + +class BatchTask(BaseTask): + """Batch simulation task for handling multiple simulations.""" + + batch_id: str = Field( + ..., + title="Batch ID", + description="Unique identifier for the batch", + 
diff --git a/tidy3d/web/core/task_core.py b/tidy3d/web/core/task_core.py
index 376aef301b..2196a10880 100644
--- a/tidy3d/web/core/task_core.py
+++ b/tidy3d/web/core/task_core.py
@@ -5,6 +5,8 @@
 import os
 import pathlib
 import tempfile
+import uuid
+from abc import ABC
 from datetime import datetime
 from typing import Callable, Optional, Union
 
@@ -25,7 +27,14 @@
 from .http_util import http
 from .s3utils import download_file, download_gz_file, upload_file
 from .stub import TaskStub
-from .types import PayType, Queryable, ResourceLifecycle, Submittable, Tidy3DResource
+from .types import (
+    BatchType,
+    PayType,
+    Queryable,
+    ResourceLifecycle,
+    Submittable,
+    Tidy3DResource,
+)
 
 
 class Folder(Tidy3DResource, Queryable, extra=Extra.allow):
@@ -115,13 +124,13 @@ def list_tasks(self) -> list[Tidy3DResource]:
 
         Returns
         -------
-        tasks : List[:class:`.SimulationTask`]
+        tasks : List[:class:`.Task`]
             List of tasks in this folder
         """
         resp = http.get(f"tidy3d/projects/{self.folder_id}/tasks")
         return (
             parse_obj_as(
-                list[SimulationTask],
+                list[Task],
                 resp,
             )
             if resp
@@ -129,34 +138,15 @@ def list_tasks(self) -> list[Tidy3DResource]:
         )
 
 
-class SimulationTask(ResourceLifecycle, Submittable, extra=Extra.allow):
-    """Interface for managing the running of a :class:`.Simulation` task on server."""
+class BaseTask(ResourceLifecycle, ABC):
+    """Base class for all task types with shared server communication functionality."""
 
-    task_id: Optional[str] = Field(
-        ...,
-        title="task_id",
-        description="Task ID number, set when the task is uploaded, leave as None.",
-        alias="taskId",
-    )
     folder_id: Optional[str] = Field(
         None,
         title="folder_id",
         description="Folder ID number, set when the task is uploaded, leave as None.",
         alias="folderId",
     )
-    status: Optional[str] = Field(title="status", description="Simulation task status.")
-
-    real_flex_unit: float = Field(
-        None, title="real FlexCredits", description="Billed FlexCredits.", alias="realCost"
-    )
-
-    created_at: Optional[datetime] = Field(
-        title="created_at", description="Time at which this task was created.", alias="createdAt"
-    )
-
-    task_type: Optional[str] = Field(
-        title="task_type", description="The type of task.", alias="taskType"
-    )
 
     folder_name: Optional[str] = Field(
         "default",
@@ -173,6 +163,229 @@ class SimulationTask(ResourceLifecycle, Submittable, extra=Extra.allow):
         "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.",
     )
 
+    @classmethod
+    def _get_folder(cls, folder_name: str) -> Folder:
+        """Get or create a folder by name."""
+        return Folder.get(folder_name, create=True)
+
+    @classmethod
+    def _validate_simulations(
+        cls,
+        simulations: dict[str, td.components.base_sim.simulation.AbstractSimulation],
+        source_required: bool = True,
+    ) -> None:
+        """Validate simulations before upload."""
+        from tidy3d.web.api.tidy3d_stub import Tidy3dStub
+
+        for simulation in simulations.values():
+            stub = Tidy3dStub(simulation=simulation)
+            stub.validate_pre_upload(source_required=source_required)
+
+
+class BatchTask(BaseTask):
+    """Batch simulation task for handling multiple simulations."""
+
+    batch_id: str = Field(
+        ...,
+        title="Batch ID",
+        description="Unique identifier for the batch.",
+    )
+    task_ids: dict[str, str] = Field(
+        ...,
+        title="Task IDs",
+        description="Mapping of task names to task IDs.",
+    )
+
+    @classmethod
+    def create(
+        cls,
+        simulations: dict[str, td.components.base_sim.simulation.AbstractSimulation],
+        folder_name: str = "default",
+        callback_url: Optional[str] = None,
+        simulation_type: str = "tidy3d",
+        parent_tasks: Optional[dict[str, list[str]]] = None,
+        file_type: str = "Gz",
+        batch_type: str = BatchType.DEFAULT.value,
+    ) -> BatchTask:
+        """Create multiple simulation tasks in a single batch on the server.
+
+        Parameters
+        ----------
+        simulations : dict[str, AbstractSimulation]
+            Mapping of task names to simulation objects.
+        folder_name : str
+            The name of the folder to store the tasks. Default is "default".
+        callback_url : str
+            Http PUT url to receive simulation finish event.
+        simulation_type : str
+            Type of simulation being uploaded.
+        parent_tasks : dict[str, list[str]]
+            Mapping of task names to lists of parent task ids.
+        file_type : str
+            The simulation file type: Json, Hdf5, or Gz.
+        batch_type : str
+            Internal batch type used for server-side optimization.
+
+        Returns
+        -------
+        BatchTask
+            Object containing the batch_id and a mapping of task names to task IDs.
+        """
+        # Handle backwards compatibility
+        if simulation_type is None:
+            simulation_type = "tidy3d"
+
+        folder = cls._get_folder(folder_name)
+
+        cls._validate_simulations(simulations)
+
+        from tidy3d.web.api.tidy3d_stub import Tidy3dStub
+
+        # generate a unique group name using a short UUID
+        short_uuid = str(uuid.uuid4())[:8]
+        group_name = f"batch_{short_uuid}"
+
+        batch_data = []
+        for task_name, sim in simulations.items():
+            stub = Tidy3dStub(simulation=sim)
+            task_data = {
+                "taskName": task_name,
+                "taskType": stub.get_type(),
+                "callbackUrl": callback_url,
+                "simulationType": simulation_type,
+                "fileType": file_type,
+                "groupName": group_name,
+            }
+            if parent_tasks and task_name in parent_tasks:
+                task_data["parentTasks"] = parent_tasks[task_name]
+            batch_data.append(task_data)
+
+        request_data = {
+            "tasks": batch_data,
+            "batchType": batch_type,
+            "groupName": group_name,
+        }
+
+        resp = http.post(
+            f"tidy3d/projects/{folder.folder_id}/batch-tasks",
+            request_data,
+        )
+
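+        # expected response shape (cf. the endpoint mock in tests/test_web/test_webapi.py):
+        #   {"batchId": "...", "tasks": [{"taskId": "...", "taskName": "..."}, ...]}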
+        task_ids = {task["taskName"]: task["taskId"] for task in resp["tasks"]}
+
+        return cls(
+            batch_id=resp["batchId"],
+            task_ids=task_ids,
+            folder_id=folder.folder_id,
+            folder_name=folder_name,
+            callback_url=callback_url,
+        )
+
+    def delete(self) -> None:
+        """Delete all tasks in this batch.
+
+        Note
+        ----
+        The current implementation deletes each task individually. This could be
+        enhanced with a dedicated batch delete endpoint for better performance and
+        atomicity.
+
+        Future enhancement:
+        - Batch delete endpoint: ``DELETE /tidy3d/projects/{folder_id}/batch/{batch_id}``
+        - Benefits: atomic batch deletion, server-side cleanup optimizations
+        - Would allow deleting an entire batch and all associated tasks in a single
+          operation, and could support options like deleting only completed tasks or
+          preserving data
+        """
+        # TODO: replace with the batch delete endpoint when available:
+        # DELETE /tidy3d/projects/{folder_id}/batch/{batch_id}
+        for task_id in self.task_ids.values():
+            http.delete(f"tidy3d/tasks/{task_id}")
+
+    def submit(self) -> None:
+        """Submit all tasks in the batch to start running.
+
+        Note
+        ----
+        The current implementation submits each task individually. This could be
+        enhanced with a dedicated batch submit endpoint for better performance and
+        atomicity.
+
+        Future enhancement:
+        - Batch submit endpoint: ``POST /tidy3d/projects/{folder_id}/batch/{batch_id}/submit``
+        - Benefits: reduced HTTP overhead, atomic batch submission, server-side
+          optimizations
+        - Would allow submitting all tasks in a single API call with batch-level
+          parameters
+        """
+        # TODO: replace with the batch submit endpoint when available:
+        # POST /tidy3d/projects/{folder_id}/batch/{batch_id}/submit
+        for task_id in self.task_ids.values():
+            http.post(f"tidy3d/tasks/{task_id}/submit")
+
+    @classmethod
+    def get(cls, batch_id: str, verbose: bool = True) -> BatchTask:
+        """Get a batch from the server by batch_id.
+
+        Note
+        ----
+        This is a placeholder implementation, as batch retrieval endpoints do not
+        exist yet. A current workaround would require storing the batch_id and
+        task_ids locally, or recovering them through individual task queries.
+
+        Future enhancement:
+        - Batch retrieval endpoint: ``GET /tidy3d/projects/{folder_id}/batch/{batch_id}``
+        - Should return: batch_id, the task_ids mapping, batch metadata, and a
+          status summary
+        - Could include an aggregate batch status (all_completed, some_failed, etc.)
+        - May support filtering (only completed tasks, only failed tasks, etc.)
+
+        Alternative implementation approaches:
+        1. Add a batch_id field to individual tasks for reverse lookup.
+        2. Store batch metadata in a separate batch management system.
+        3. Use task naming conventions to group related tasks.
+
+        Parameters
+        ----------
+        batch_id : str
+            Unique identifier of the batch on the server.
+        verbose : bool
+            If ``True``, will print progress; otherwise, will run silently.
+
+        Returns
+        -------
+        BatchTask
+            BatchTask object containing batch info.
+        """
+        # TODO: implement once a batch retrieval endpoint is available:
+        # GET /tidy3d/projects/{folder_id}/batch/{batch_id}
+        # For now, this placeholder satisfies the abstract method requirement.
+        raise NotImplementedError(
+            "Batch retrieval is not yet implemented. "
+            "A future batch retrieval endpoint (GET /tidy3d/projects/{folder_id}/batch/{batch_id}) "
+            "would return batch metadata and the task ID mapping."
+        )
+
+
+class Task(BaseTask, Submittable, extra=Extra.allow):
+    """Interface for managing the running of a :class:`.Simulation` task on server."""
+
+    task_id: Optional[str] = Field(
+        ...,
+        title="task_id",
+        description="Task ID number, set when the task is uploaded, leave as None.",
+        alias="taskId",
+    )
+    status: Optional[str] = Field(title="status", description="Simulation task status.")
+
+    real_flex_unit: float = Field(
+        None, title="real FlexCredits", description="Billed FlexCredits.", alias="realCost"
+    )
+
+    created_at: Optional[datetime] = Field(
+        title="created_at", description="Time at which this task was created.", alias="createdAt"
+    )
+
+    task_type: Optional[str] = Field(
+        title="task_type", description="The type of task.", alias="taskType"
+    )
+
     # simulation_type: str = pd.Field(
     #     None,
     #     title="Simulation Type",
@@ -209,7 +422,7 @@ def create(
         simulation_type: str = "tidy3d",
         parent_tasks: Optional[list[str]] = None,
         file_type: str = "Gz",
-    ) -> SimulationTask:
+    ) -> Task:
         """Create a new task on the server.
 
         Parameters
@@ -232,8 +445,8 @@ def create(
 
         Returns
        -------
-        :class:`SimulationTask`
-            :class:`SimulationTask` object containing info about status, size,
+        :class:`Task`
+            :class:`Task` object containing info about status, size,
             credits of task and others.
""" @@ -241,7 +454,7 @@ def create( if simulation_type is None: simulation_type = "tidy3d" - folder = Folder.get(folder_name, create=True) + folder = cls._get_folder(folder_name) resp = http.post( f"tidy3d/projects/{folder.folder_id}/tasks", { @@ -253,10 +466,10 @@ def create( "fileType": file_type, }, ) - return SimulationTask(**resp, taskType=task_type, folder_name=folder_name) + return Task(**resp, taskType=task_type, folder_name=folder_name) @classmethod - def get(cls, task_id: str, verbose: bool = True) -> SimulationTask: + def get(cls, task_id: str, verbose: bool = True) -> Task: """Get task from the server by id. Parameters @@ -268,8 +481,8 @@ def get(cls, task_id: str, verbose: bool = True) -> SimulationTask: Returns ------- - :class:`.SimulationTask` - :class:`.SimulationTask` object containing info about status, + :class:`.Task` + :class:`.Task` object containing info about status, size, credits of task and others. """ try: @@ -278,23 +491,23 @@ def get(cls, task_id: str, verbose: bool = True) -> SimulationTask: td.log.error(f"The requested task ID '{task_id}' does not exist.") raise e - task = SimulationTask(**resp) if resp else None + task = Task(**resp) if resp else None return task @classmethod - def get_running_tasks(cls) -> list[SimulationTask]: + def get_running_tasks(cls) -> list[Task]: """Get a list of running tasks from the server" Returns ------- - List[:class:`.SimulationTask`] - :class:`.SimulationTask` object containing info about status, + List[:class:`.Task`] + :class:`.Task` object containing info about status, size, credits of task and others. """ resp = http.get("tidy3d/py/tasks") if not resp: return [] - return parse_obj_as(list[SimulationTask], resp) + return parse_obj_as(list[Task], resp) def delete(self, versions: bool = False): """Delete current task from server. @@ -688,11 +901,11 @@ def validate_post_upload(self, parent_tasks: Optional[list[str]] = None): ) try: # get mesh task info - mesh_task = SimulationTask.get(parent_tasks[0], verbose=False) + mesh_task = Task.get(parent_tasks[0], verbose=False) assert mesh_task.task_type == "VOLUME_MESH" assert mesh_task.status == "success" # get up-to-date task info - task = SimulationTask.get(self.task_id, verbose=False) + task = Task.get(self.task_id, verbose=False) if task.fileMd5 != mesh_task.childFileMd5: raise ValidationError( "Simulation stored in parent task 'VolumeMesher' does not match the " @@ -706,3 +919,7 @@ def validate_post_upload(self, parent_tasks: Optional[list[str]] = None): except Exception as e: raise WebError(f"Provided 'parent_tasks' failed validation: {e!s}") from e + + +# Backward compatibility alias +SimulationTask = Task diff --git a/tidy3d/web/core/types.py b/tidy3d/web/core/types.py index 2ecf066402..b431dba983 100644 --- a/tidy3d/web/core/types.py +++ b/tidy3d/web/core/types.py @@ -68,3 +68,16 @@ def _missing_(cls, value: object) -> PayType: if key in cls.__members__: return cls.__members__[key] return super()._missing_(value) + + +class BatchType(str, Enum): + """Batch type for server-side batch processing optimization.""" + + BATCH = "BATCH" + INVDES = "INVDES" + PERMUTATION = "Permutation" + PARALLEL = "Parallel" + MONTE_CARLO = "MONTE_CARLO" + RF_SWEEP = "RF_SWEEP" + + DEFAULT = "RF_SWEEP" # noqa: PIE796