Skip to content

Commit 8f9544e

Browse files
committed
feat(web): add priority parameter for vGPU task queue management
1 parent 327755b commit 8f9544e

File tree

10 files changed

+88
-23
lines changed

10 files changed

+88
-23
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
### Added
1111
- Add support for `np.unwrap` in `tidy3d.plugins.autograd`.
1212
- Add Nunley variant to germanium material library based on Nunley et al. 2016 data.
13+
- Added `priority` parameter to `web.run()` and related functions to allow vGPU users to set task priority (1-10) in the queue.
1314

1415
### Changed
1516
- Switched to an analytical gradient calculation for spatially-varying pole-residue models (`CustomPoleResidue`).

tests/test_web/test_tidy3d_task.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def test_submit(set_api_key):
218218
"workerGroup": None,
219219
"enableCaching": Env.current.enable_caching,
220220
"payType": PayType.AUTO,
221+
"priority": None,
221222
}
222223
)
223224
],

tests/test_web/test_webapi.py

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -182,29 +182,36 @@ def mock_get_info(monkeypatch, set_api_key):
182182
def mock_start(monkeypatch, set_api_key, mock_get_info):
183183
"""Mocks webapi.start."""
184184

185-
responses.add(
186-
responses.POST,
187-
f"{Env.current.web_api_endpoint}/tidy3d/tasks/{TASK_ID}/submit",
188-
match=[
189-
matchers.json_params_matcher(
190-
{
191-
"solverVersion": None,
192-
"workerGroup": None,
193-
"protocolVersion": td.version.__version__,
194-
"enableCaching": Env.current.enable_caching,
195-
"payType": PayType.AUTO,
185+
def add_mock_response(priority=None):
186+
expected_body = {
187+
"solverVersion": None,
188+
"workerGroup": None,
189+
"protocolVersion": td.version.__version__,
190+
"enableCaching": Env.current.enable_caching,
191+
"payType": PayType.AUTO,
192+
"priority": priority,
193+
}
194+
195+
responses.add(
196+
responses.POST,
197+
f"{Env.current.web_api_endpoint}/tidy3d/tasks/{TASK_ID}/submit",
198+
match=[matchers.json_params_matcher(expected_body)],
199+
json={
200+
"data": {
201+
"taskId": TASK_ID,
202+
"taskName": TASK_NAME,
203+
"createdAt": CREATED_AT,
196204
}
197-
)
198-
],
199-
json={
200-
"data": {
201-
"taskId": TASK_ID,
202-
"taskName": TASK_NAME,
203-
"createdAt": CREATED_AT,
204-
}
205-
},
206-
status=200,
207-
)
205+
},
206+
status=200,
207+
)
208+
209+
# Add response for calls without priority
210+
add_mock_response(None)
211+
212+
# Add responses for calls with specific priority values
213+
for priority in [1, 5, 10]:
214+
add_mock_response(priority)
208215

209216

210217
@pytest.fixture
@@ -322,6 +329,39 @@ def test_start(mock_start):
322329
start(TASK_ID)
323330

324331

332+
@responses.activate
333+
@pytest.mark.parametrize("priority", [1, 5, 10, None])
334+
def test_start_with_valid_priority(mock_start, priority):
335+
"""Test start with valid priority values."""
336+
start(TASK_ID, priority=priority)
337+
338+
339+
@responses.activate
340+
@pytest.mark.parametrize("priority", [0, -1, 11, 15])
341+
def test_start_with_invalid_priority(mock_start, priority):
342+
"""Test start with invalid priority values."""
343+
with pytest.raises(ValueError, match="Priority must be between '1' and '10' if specified."):
344+
start(TASK_ID, priority=priority)
345+
346+
347+
@responses.activate
348+
@pytest.mark.parametrize("priority", [5, None])
349+
def test_run_with_valid_priority(mock_webapi, monkeypatch, priority):
350+
"""Test run with valid priority parameter."""
351+
monkeypatch.setattr(f"{api_path}.load", lambda *args, **kwargs: True)
352+
sim = make_sim()
353+
run(sim, TASK_NAME, folder_name=PROJECT_NAME, priority=priority)
354+
355+
356+
@responses.activate
357+
@pytest.mark.parametrize("priority", [0, -1, 11, 15])
358+
def test_run_with_invalid_priority(mock_webapi, priority):
359+
"""Test run with invalid priority values."""
360+
sim = make_sim()
361+
with pytest.raises(ValueError, match="Priority must be between '1' and '10' if specified."):
362+
run(sim, TASK_NAME, folder_name=PROJECT_NAME, priority=priority)
363+
364+
325365
@responses.activate
326366
def test_get_run_info(mock_get_run_info):
327367
assert get_run_info(TASK_ID) == (100, 0)

tests/test_web/test_webapi_eme.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def mock_start(monkeypatch, set_api_key, mock_get_info):
131131
"protocolVersion": td.version.__version__,
132132
"enableCaching": Env.current.enable_caching,
133133
"payType": PayType.AUTO,
134+
"priority": None,
134135
}
135136
)
136137
],

tests/test_web/test_webapi_heat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def mock_start(monkeypatch, set_api_key, mock_get_info):
128128
"protocolVersion": td.version.__version__,
129129
"enableCaching": Env.current.enable_caching,
130130
"payType": PayType.AUTO,
131+
"priority": None,
131132
}
132133
)
133134
],

tests/test_web/test_webapi_mode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def mock_start(monkeypatch, set_api_key, mock_get_info):
164164
"protocolVersion": td.version.__version__,
165165
"enableCaching": Env.current.enable_caching,
166166
"payType": PayType.AUTO,
167+
"priority": None,
167168
}
168169
)
169170
],

tests/test_web/test_webapi_mode_sim.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def mock_start(monkeypatch, set_api_key, mock_get_info):
160160
"protocolVersion": td.version.__version__,
161161
"enableCaching": Env.current.enable_caching,
162162
"payType": PayType.AUTO,
163+
"priority": None,
163164
}
164165
)
165166
],

tidy3d/web/api/autograd/autograd.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def run(
108108
max_num_adjoint_per_fwd: int = MAX_NUM_ADJOINT_PER_FWD,
109109
reduce_simulation: typing.Literal["auto", True, False] = "auto",
110110
pay_type: typing.Union[PayType, str] = PayType.AUTO,
111+
priority: typing.Optional[int] = None,
111112
) -> SimulationDataType:
112113
"""
113114
Submits a :class:`.Simulation` to server, starts running, monitors progress, downloads,
@@ -147,6 +148,8 @@ def run(
147148
Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver.
148149
pay_type: typing.Union[PayType, str] = PayType.AUTO
149150
Which method to pay for the simulation.
151+
priority: int = None
152+
Task priority for vGPU queue (1=lowest, 10=highest).
150153
Returns
151154
-------
152155
Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`]
@@ -191,6 +194,8 @@ def run(
191194
:meth:`tidy3d.web.api.container.Batch.monitor`
192195
Monitor progress of each of the running tasks.
193196
"""
197+
if priority is not None and (priority < 1 or priority > 10):
198+
raise ValueError("Priority must be between '1' and '10' if specified.")
194199
if is_valid_for_autograd(simulation):
195200
return _run(
196201
simulation=simulation,
@@ -225,6 +230,7 @@ def run(
225230
parent_tasks=parent_tasks,
226231
reduce_simulation=reduce_simulation,
227232
pay_type=pay_type,
233+
priority=priority,
228234
)
229235

230236

tidy3d/web/api/webapi.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def run(
8383
parent_tasks: Optional[list[str]] = None,
8484
reduce_simulation: Literal["auto", True, False] = "auto",
8585
pay_type: Union[PayType, str] = PayType.AUTO,
86+
priority: Optional[int] = None,
8687
) -> SimulationDataType:
8788
"""
8889
Submits a :class:`.Simulation` to server, starts running, monitors progress, downloads,
@@ -117,7 +118,8 @@ def run(
117118
Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver.
118119
pay_type: Union[PayType, str] = PayType.AUTO
119120
Which method to pay the simulation.
120-
121+
priority: int = None
122+
Task priority for vGPU queue (1=lowest, 10=highest).
121123
Returns
122124
-------
123125
Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`]
@@ -179,6 +181,7 @@ def run(
179181
solver_version=solver_version,
180182
worker_group=worker_group,
181183
pay_type=pay_type,
184+
priority=priority,
182185
)
183186
monitor(task_id, verbose=verbose)
184187
data = load(
@@ -372,6 +375,7 @@ def start(
372375
solver_version: Optional[str] = None,
373376
worker_group: Optional[str] = None,
374377
pay_type: Union[PayType, str] = PayType.AUTO,
378+
priority: Optional[int] = None,
375379
) -> None:
376380
"""Start running the simulation associated with task.
377381
@@ -386,17 +390,22 @@ def start(
386390
worker group
387391
pay_type: Union[PayType, str] = PayType.AUTO
388392
Which method to pay the simulation
393+
priority: int = None
394+
Task priority for vGPU queue (1=lowest, 10=highest).
389395
Note
390396
----
391397
To monitor progress, can call :meth:`monitor` after starting simulation.
392398
"""
399+
if priority is not None and (priority < 1 or priority > 10):
400+
raise ValueError("Priority must be between '1' and '10' if specified.")
393401
task = SimulationTask.get(task_id)
394402
if not task:
395403
raise ValueError("Task not found.")
396404
task.submit(
397405
solver_version=solver_version,
398406
worker_group=worker_group,
399407
pay_type=pay_type,
408+
priority=priority,
400409
)
401410

402411

tidy3d/web/core/task_core.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ def submit(
429429
solver_version: Optional[str] = None,
430430
worker_group: Optional[str] = None,
431431
pay_type: Union[PayType, str] = PayType.AUTO,
432+
priority: Optional[int] = None,
432433
):
433434
"""Kick off this task.
434435
@@ -444,6 +445,8 @@ def submit(
444445
worker group
445446
pay_type: Union[PayType, str] = PayType.AUTO
446447
Which method to pay the simulation.
448+
priority: int = None
449+
Task priority for vGPU queue (1=lowest, 10=highest).
447450
"""
448451
pay_type = PayType(pay_type) if not isinstance(pay_type, PayType) else pay_type
449452

@@ -460,6 +463,7 @@ def submit(
460463
"protocolVersion": protocol_version,
461464
"enableCaching": Env.current.enable_caching,
462465
"payType": pay_type.value,
466+
"priority": priority,
463467
},
464468
)
465469

0 commit comments

Comments
 (0)