Skip to content

feat(webapi): batch task submission via dedicated endpoint #2694

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 241 additions & 0 deletions tests/test_web/test_webapi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Tests webapi and things that depend on it
from __future__ import annotations

import json

import numpy as np
import pytest
import responses
Expand Down Expand Up @@ -296,6 +298,165 @@ def mock_get_run_info(monkeypatch, set_api_key):
)


@pytest.fixture
def mock_batch_upload_single(monkeypatch, set_api_key):
"""Mocks batch upload endpoint for single task."""
# Mock folder retrieval
responses.add(
responses.GET,
f"{Env.current.web_api_endpoint}/tidy3d/project",
match=[matchers.query_param_matcher({"projectName": PROJECT_NAME})],
json={"data": {"projectId": FOLDER_ID, "projectName": PROJECT_NAME}},
status=200,
)

# mock batch endpoint - returns single task
def batch_request_matcher(request):
json_data = json.loads(request.body)
assert "tasks" in json_data
assert "batchType" in json_data
assert "groupName" in json_data
for task in json_data["tasks"]:
assert "groupName" in task
assert task["groupName"] == json_data["groupName"]
return True, None

responses.add(
responses.POST,
f"{Env.current.web_api_endpoint}/tidy3d/projects/{FOLDER_ID}/batch-tasks",
match=[batch_request_matcher],
json={"batchId": "batch_123", "tasks": [{"taskId": "task_id_0", "taskName": "task_0"}]},
status=200,
)

# mock task detail endpoints for the single task
responses.add(
responses.GET,
f"{Env.current.web_api_endpoint}/tidy3d/tasks/task_id_0",
json={
"data": {
"taskId": "task_id_0",
"taskName": "task_0",
"createdAt": CREATED_AT,
"fileType": "Gz",
"resourcePath": "output/task_id_0.json",
"solverVersion": None,
"taskType": TaskType.FDTD.name,
}
},
status=200,
)

responses.add(
responses.GET,
f"{Env.current.web_api_endpoint}/tidy3d/tasks/task_id_0/detail",
json={
"data": {
"taskId": "task_id_0",
"taskName": "task_0",
"createdAt": CREATED_AT,
"realFlexUnit": FLEX_UNIT,
"estFlexUnit": EST_FLEX_UNIT,
"taskType": TaskType.FDTD.name,
"metadataStatus": "processed",
"status": "draft",
"s3Storage": 1.0,
}
},
status=200,
)

def mock_upload_file(*args, **kwargs):
pass

monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file)


@pytest.fixture
def mock_batch_upload_triple(monkeypatch, set_api_key):
"""Mocks batch upload endpoint for three tasks."""
# Mock folder retrieval
responses.add(
responses.GET,
f"{Env.current.web_api_endpoint}/tidy3d/project",
match=[matchers.query_param_matcher({"projectName": PROJECT_NAME})],
json={"data": {"projectId": FOLDER_ID, "projectName": PROJECT_NAME}},
status=200,
)

def batch_request_matcher(request):
import json
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Redundant import of json module - already imported at line 4

Suggested change
import json
def batch_request_matcher(request):
json_data = json.loads(request.body)
assert "tasks" in json_data
assert "batchType" in json_data
assert "groupName" in json_data
for task in json_data["tasks"]:
assert "groupName" in task
assert task["groupName"] == json_data["groupName"]
return True, None


json_data = json.loads(request.body)
assert "tasks" in json_data
assert "batchType" in json_data
assert "groupName" in json_data
for task in json_data["tasks"]:
assert "groupName" in task
assert task["groupName"] == json_data["groupName"]
return True, None

responses.add(
responses.POST,
f"{Env.current.web_api_endpoint}/tidy3d/projects/{FOLDER_ID}/batch-tasks",
match=[batch_request_matcher],
json={
"batchId": "batch_123",
"tasks": [
{"taskId": "task_id_0", "taskName": "task_0"},
{"taskId": "task_id_1", "taskName": "task_1"},
{"taskId": "task_id_2", "taskName": "task_2"},
],
},
status=200,
)

for i in range(3):
task_name = f"task_{i}"
task_id = f"task_id_{i}"

responses.add(
responses.GET,
f"{Env.current.web_api_endpoint}/tidy3d/tasks/{task_id}",
json={
"data": {
"taskId": task_id,
"taskName": task_name,
"createdAt": CREATED_AT,
"fileType": "Gz",
"resourcePath": f"output/{task_id}.json",
"solverVersion": None,
"taskType": TaskType.FDTD.name,
}
},
status=200,
)

responses.add(
responses.GET,
f"{Env.current.web_api_endpoint}/tidy3d/tasks/{task_id}/detail",
json={
"data": {
"taskId": task_id,
"taskName": task_name,
"createdAt": CREATED_AT,
"realFlexUnit": FLEX_UNIT,
"estFlexUnit": EST_FLEX_UNIT,
"taskType": TaskType.FDTD.name,
"metadataStatus": "processed",
"status": "draft",
"s3Storage": 1.0,
}
},
status=200,
)

def mock_upload_file(*args, **kwargs):
pass

monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file)


@pytest.fixture
def mock_webapi(
mock_upload, mock_metadata, mock_get_info, mock_start, mock_monitor, mock_download, mock_load
Expand Down Expand Up @@ -628,6 +789,86 @@ def test_batch(mock_webapi, mock_job_status, mock_load, tmp_path):
assert b2.real_cost() == FLEX_UNIT * len(sims)


@responses.activate
def test_batch_with_endpoint(mock_batch_upload_triple, tmp_path):
"""Test batch with new batch endpoint."""

sims = {f"task_{i}": make_sim() for i in range(3)}

batch = Batch(simulations=sims, folder_name=PROJECT_NAME, use_batch_endpoint=True)

assert batch.use_batch_endpoint is True

# access jobs property to trigger batch submission
jobs = batch.jobs

# verify jobs were created with pre-assigned task_ids
assert len(jobs) == 3
for i, (task_name, job) in enumerate(jobs.items()):
assert task_name == f"task_{i}"
assert job.task_id == f"task_id_{i}"
assert job._cached_properties.get("task_id") == f"task_id_{i}"
Comment on lines +809 to +810
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Accessing _cached_properties directly may be fragile - consider using a public method if this internal access pattern is common


# test serialization preserves the flag
fname = str(tmp_path / "batch_endpoint.json")
batch.to_file(fname)
batch_loaded = Batch.from_file(fname)

assert batch_loaded.use_batch_endpoint is True
assert len(batch_loaded.jobs) == 3


@responses.activate
def test_batch_backward_compatibility(mock_webapi, mock_job_status, mock_load, tmp_path):
"""Test that default behavior remains unchanged (backward compatibility)."""
sims = {TASK_NAME: make_sim()}

# Create batch without specifying use_batch_endpoint (should default to False)
batch = Batch(simulations=sims, folder_name=PROJECT_NAME)

# Verify default is False
assert batch.use_batch_endpoint is False

# Access jobs to trigger normal flow
jobs = batch.jobs
assert len(jobs) == 1

# Run and verify it works as before
batch.run(path_dir=str(tmp_path))
assert batch.real_cost() == FLEX_UNIT * len(sims)


@responses.activate
def test_batch_endpoint_integration(
mock_batch_upload_single, mock_webapi, mock_job_status, tmp_path
):
"""Test both batch endpoint modes produce compatible results."""
sim = make_sim()

# old way
batch_old = Batch(
simulations={"task_0": sim}, folder_name=PROJECT_NAME, use_batch_endpoint=False
)
fname_old = str(tmp_path / "batch_old.json")
batch_old.to_file(fname_old)

# new way
batch_new = Batch(
simulations={"task_0": sim}, folder_name=PROJECT_NAME, use_batch_endpoint=True
)
fname_new = str(tmp_path / "batch_new.json")
batch_new.to_file(fname_new)

# load and verify both work
batch_old_loaded = Batch.from_file(fname_old)
batch_new_loaded = Batch.from_file(fname_new)

assert batch_old_loaded.use_batch_endpoint is False
assert batch_new_loaded.use_batch_endpoint is True
assert len(batch_old_loaded.jobs) == 1
assert len(batch_new_loaded.jobs) == 1


@responses.activate
def test_create_output_dirs(mock_webapi, tmp_path, monkeypatch):
"""Test that Job and Batch create output directories if they don't exist."""
Expand Down
Loading