diff --git a/AUTOGRAD_REFACTOR_PLAN.md b/AUTOGRAD_REFACTOR_PLAN.md
new file mode 100644
index 0000000000..392b59159f
--- /dev/null
+++ b/AUTOGRAD_REFACTOR_PLAN.md
@@ -0,0 +1,378 @@
+# Autograd Refactor Plan: Making Job and Batch Autograd-Compatible
+
+## Background Context
+
+This project supports differentiation through simulation runs using autograd. Currently, only some of the functional web API entry points are supported, namely `web.run()` and `web.run_async()`. All relevant web API code is found in `tidy3d/web/api/`.
+
+### Current Architecture Overview
+
+**Web API Structure (`tidy3d/web/api/`):**
+- **`webapi.py`**: Core HTTP API functions (`run()`, `upload()`, `start()`, `monitor()`, `download()`, etc.) - these are NOT autograd compatible
+- **`autograd/autograd.py`**: Autograd-compatible wrappers around webapi functions with primitives for forward/backward passes
+- **`container.py`**: Object-oriented interfaces (`Job`, `Batch`, `BatchData`) that provide stateful wrappers around webapi functions
+- **`asynchronous.py`**: Simple batch interface that creates a `Batch` and calls `batch.run()`
+
+**Current Autograd Support:**
+- Autograd primitives are defined in `autograd/autograd.py` using `@primitive` decorators
+- These wrap the underlying webapi functions and add gradient computation
+- The autograd versions are exported in `tidy3d/web/__init__.py` (autograd `run` and `run_async` shadow the webapi versions)
+- Tests for autograd are in `tests/test_components/test_autograd.py`
+- Web unit tests use mocking patterns in `tests/test_web/test_webapi.py`
+
+**Key Relationships:**
+- `Job` and `Batch` call webapi functions internally (e.g., `Job.run()` calls `webapi.upload()`, `webapi.start()`, etc.)
+- `autograd.py` ultimately calls `_run_tidy3d()` and `_run_async_tidy3d()`, which create `Job` and `Batch` instances
+- `asynchronous.py` is just a thin wrapper around `Batch.run()`
+- `BatchData` is defined in `container.py` but used widely across the codebase
+
+## Problem Statement
+
+Currently, `Job.run()` and `Batch.run()` are not autograd-differentiable. Users cannot write code like:
+
+```python
+def f(x):
+    sim = make_simulation(x)
+    data = Job(simulation=sim).run()  # ❌ Not differentiable
+    return postprocess(data)
+```
+
+**Why this matters:** Users prefer the object-oriented interface (`Job`, `Batch`) for its convenience and state management, but these classes bypass the autograd-aware `run()` and `run_async()` functions from `autograd.py`.
+
+## Current Circular Dependency Issues
+
+The main blocker is a set of circular dependencies that prevent `Job.run()` and `Batch.run()` from calling the autograd-compatible functions:
+
+1. **autograd.py** imports from **container.py** (`Batch`, `BatchData`, `Job`)
+2. **asynchronous.py** imports from **container.py** (`Batch`, `BatchData`)
+3. **container.py** needs autograd functions but can't import them without completing the cycle
+4. **BatchData** is defined in container.py but used widely across the codebase
+
+**Additional complexity:** The `run_async()` function in `autograd.py` actually calls `Batch` under the hood because `Batch` contains all the batch processing logic. So there's a circular dependency where:
+- `autograd.run_async()` needs `Batch`
+- But `Batch.run()` should call `autograd.run_async()` for autograd compatibility
+
+**BatchData coupling:** `BatchData` is defined in the same module as `Job` and `Batch`, meaning anything that uses `BatchData` also imports the heavyweight container classes, contributing to circular import issues.
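+
+To make the failure mode concrete, here is a self-contained toy reproduction of the cycle. The modules `a` and `b` are stand-ins for `autograd.py` and `container.py`; nothing below is Tidy3D code:
+
+```python
+import pathlib
+import subprocess
+import sys
+import tempfile
+
+with tempfile.TemporaryDirectory() as d:
+    # "a" plays the role of autograd.py, "b" the role of container.py
+    pathlib.Path(d, "a.py").write_text("from b import Batch\ndef run_async(sims): ...\n")
+    pathlib.Path(d, "b.py").write_text("from a import run_async\nclass Batch: ...\n")
+    proc = subprocess.run(
+        [sys.executable, "-c", "import a"], cwd=d, capture_output=True, text=True
+    )
+    print(proc.stderr.strip().splitlines()[-1])
+    # ImportError: cannot import name 'run_async' from partially initialized
+    # module 'a' (most likely due to a circular import)
+```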
+
+## Step-by-Step Plan
+
+### Step 1: Add Autograd Tests for Job and Batch
+
+**Goal**: Create failing tests that demonstrate the desired behavior
+
+**Files to modify**:
+- `tests/test_components/test_autograd.py`
+
+**Tasks**:
+1. Add test cases that mirror existing autograd tests but use `Job.run()` and `Batch.run()`
+2. Extend the `use_emulated_run` fixture to also mock Job and Batch operations (study the existing mocking patterns in `tests/test_web/test_webapi.py` for guidance)
+3. Add tests for both single and batch scenarios:
+   ```python
+   def test_job_autograd_compatibility(use_emulated_run):
+       def objective(params):
+           sim = make_simulation(params)
+           job = Job(simulation=sim, task_name="test")
+           data = job.run()  # Should be differentiable
+           return postprocess(data)
+
+       # Test gradient computation
+       grad = autograd.grad(objective)(test_params)
+       assert grad is not None
+
+   def test_batch_autograd_compatibility(use_emulated_run):
+       def objective(params):
+           sims = {f"task_{i}": make_simulation(p) for i, p in enumerate(params)}
+           batch = Batch(simulations=sims)
+           data = batch.run()  # Should be differentiable
+           return postprocess_batch(data)
+
+       # Test gradient computation
+       grad = autograd.grad(objective)(test_params)
+       assert grad is not None
+   ```
+
+**Expected result**: Tests fail initially, demonstrating the problem
+
+### Step 2: Extract BatchData to Separate Module
+
+**Goal**: Break circular import dependencies by making BatchData independent
+
+**Files to create**:
+- `tidy3d/web/api/batch_data.py`
+
+**Files to modify**:
+- All files that import BatchData (plugins, tests, and the web API modules; see the import updates in this diff)
+
+**Tasks**:
+1. **Create `batch_data.py`**:
+   ```python
+   # tidy3d/web/api/batch_data.py
+   from __future__ import annotations
+   from collections.abc import Mapping
+   from tidy3d.components.base import Tidy3dBaseModel
+   from tidy3d.web.core.constants import TaskName
+   from .tidy3d_stub import SimulationDataType
+
+   class BatchData(Tidy3dBaseModel, Mapping):
+       # Move entire BatchData class here with minimal dependencies;
+       # webapi is imported lazily inside methods (never container)
+       ...
+   ```
+
+2. **Update all imports**:
+   - Replace `from .container import BatchData` with `from .batch_data import BatchData`
+   - Update imports in: `asynchronous.py`, `autograd.py`, `__init__.py`, plugins, tests
+
+3. **Remove BatchData from container.py**:
+   - Keep only Job and Batch classes in container.py
+   - Update Batch.load() to import and create BatchData from the new module
+
+**Expected result**: BatchData is independent, reducing circular dependencies
+
+### Step 3: Reverse Container ↔ Asynchronous Dependency
+
+**Goal**: Make `asynchronous.py` contain the core batch logic, with `Batch` as a thin wrapper
+
+**Files to modify**:
+- `tidy3d/web/api/asynchronous.py`
+- `tidy3d/web/api/container.py`
+
+**Rationale**: The current setup has `asynchronous.py` depending on `Batch`, but we want `Batch.run()` to call `autograd.run_async()`. We need to reverse this so that the core batch upload/start/monitor/download logic lives in `asynchronous.py`, and `Batch` becomes a thin wrapper that calls these functions.
+
+**Tasks**:
+1. **Move batch logic to asynchronous.py** (passing task IDs rather than `Job` objects, so that `asynchronous.py` never has to import `container.py`):
+   ```python
+   # asynchronous.py
+   def upload_batch(simulations: dict[str, SimulationType], **kwargs) -> dict[str, TaskId]:
+       """Core batch upload logic moved from Batch.upload()"""
+
+   def start_batch(task_ids: dict[str, TaskId], **kwargs) -> None:
+       """Core batch start logic moved from Batch.start()"""
+
+   def monitor_batch(task_ids: dict[str, TaskId], **kwargs) -> None:
+       """Core batch monitor logic moved from Batch.monitor()"""
+
+   def download_batch(task_ids: dict[str, TaskId], path_dir: str, **kwargs) -> None:
+       """Core batch download logic moved from Batch.download()"""
+
+   def load_batch(task_ids: dict[str, TaskId], path_dir: str, **kwargs) -> BatchData:
+       """Core batch load logic moved from Batch.load()"""
+   ```
+
+2. **Simplify Batch class**:
+   ```python
+   # container.py
+   class Batch(WebContainer):
+       def upload(self) -> None:
+           """Thin wrapper around asynchronous.upload_batch()"""
+           from .asynchronous import upload_batch
+           upload_batch(self.simulations, **self._get_batch_kwargs())
+
+       def run(self, path_dir: str = DEFAULT_DATA_DIR) -> BatchData:
+           """Thin wrapper that calls asynchronous functions"""
+           from .asynchronous import run_async
+           return run_async(self.simulations, path_dir=path_dir, **self._get_batch_kwargs())
+   ```
+
+3. **Update asynchronous.py to only depend on BatchData**:
+   - Remove `from .container import Batch`
+   - Only import `BatchData` from the new module
+
+**Expected result**: `asynchronous.py` has core logic, `container.py` is a thin wrapper, circular dependency broken
+
+### Step 4: Make Job.run() and Batch.run() Autograd-Compatible
+
+**Goal**: Have container methods call autograd-aware functions
+
+**Files to modify**:
+- `tidy3d/web/api/container.py`
+
+**Tasks**:
+1. **Update Job.run()**:
+   ```python
+   # container.py
+   def run(self, path: str = DEFAULT_DATA_PATH) -> SimulationDataType:
+       """Run Job all the way through and return data - now autograd compatible."""
+       from .autograd.autograd import run  # Import here to avoid circular imports
+
+       # Extract parameters that autograd.run() expects
+       run_kwargs = {
+           'task_name': self.task_name,
+           'folder_name': self.folder_name,
+           'path': path,
+           'callback_url': self.callback_url,
+           'verbose': self.verbose,
+           'simulation_type': self.simulation_type,
+           'parent_tasks': list(self.parent_tasks) if self.parent_tasks else None,
+           'reduce_simulation': self.reduce_simulation,
+           'pay_type': self.pay_type,
+           # Add any other parameters autograd.run() supports
+       }
+
+       return run(self.simulation, **run_kwargs)
+   ```
+
+2. **Update Batch.run()**:
+   ```python
+   # container.py
+   def run(self, path_dir: str = DEFAULT_DATA_DIR) -> BatchData:
+       """Run Batch all the way through and return data - now autograd compatible."""
+       from .autograd.autograd import run_async  # Import here to avoid circular imports
+
+       # Extract parameters that autograd.run_async() expects
+       run_async_kwargs = {
+           'folder_name': self.folder_name,
+           'path_dir': path_dir,
+           'callback_url': self.callback_url,
+           'verbose': self.verbose,
+           'simulation_type': self.simulation_type,
+           'parent_tasks': self.parent_tasks,
+           'reduce_simulation': self.reduce_simulation,
+           'pay_type': self.pay_type,
+           'num_workers': self.num_workers,
+       }
+
+       return run_async(self.simulations, **run_async_kwargs)
+   ```
+
+3. **Handle method conflicts**:
+   - Keep the individual methods (`upload()`, `start()`, `monitor()`, etc.)
for backwards compatibility + - They should still use the non-autograd webapi functions for precise control + - Only `run()` methods use autograd functions + +**Expected result**: `Job.run()` and `Batch.run()` are now autograd-differentiable + +### Step 5: Update Import Structure and Remove Circular Dependencies + +**Goal**: Clean up imports and ensure no circular dependencies remain + +**Files to modify**: +- `tidy3d/web/__init__.py` +- `tidy3d/web/api/__init__.py` (if exists) + +**Tasks**: +1. **Update web module exports**: + ```python + # web/__init__.py + from .api.autograd.autograd import run, run_async + from .api.container import Job, Batch # Now autograd-compatible + from .api.batch_data import BatchData # From new module + from .api.webapi import ( + # ... other webapi functions + ) + ``` + +2. **Verify import paths**: + - Ensure no circular imports exist + - Test that all modules can be imported successfully + - Check that plugins and external code still work + +3. **Add lazy imports where needed**: + - Use `TYPE_CHECKING` imports for type hints + - Use function-level imports where necessary to break remaining cycles + +**Expected result**: Clean import structure with no circular dependencies + +### Step 6: Update Tests and Verify Functionality + +**Goal**: Ensure all tests pass and functionality is preserved + +**Files to modify**: +- Test files that use mocking +- Any integration tests + +**Tasks**: +1. **Update test mocking**: + ```python + # test_webapi.py or test_autograd.py + def mock_autograd_run(monkeypatch): + """Mock the autograd.run function for container tests""" + def mock_run(simulation, **kwargs): + # Delegate to existing webapi mocks + return webapi.run(simulation, **kwargs) + + monkeypatch.setattr("tidy3d.web.api.autograd.autograd.run", mock_run) + ``` + +2. **Verify Step 1 tests now pass**: + - The tests added in Step 1 should now pass + - Gradient computation through Job.run() and Batch.run() should work + +3. **Run full test suite**: + - Ensure no existing functionality is broken + - Check that plugins still work correctly + - Verify backwards compatibility + +**Expected result**: All tests pass, including new autograd container tests + +### Step 7: Documentation and Examples + +**Goal**: Update documentation to reflect new autograd capabilities + +**Files to modify**: +- Docstrings in container.py +- Any examples or tutorials + +**Tasks**: +1. **Update docstrings**: + - Mention autograd compatibility in Job.run() and Batch.run() + - Add examples showing gradient computation + +2. **Add examples**: + ```python + # Example in docstring + import autograd.numpy as anp + import autograd + + def optimization_example(params): + # Create simulation based on parameters + sim = make_simulation(params) + + # Run through Job - now autograd compatible! + job = Job(simulation=sim, task_name="optimization") + data = job.run() + + # Extract objective + return postprocess(data) + + # Compute gradients + grad_fn = autograd.grad(optimization_example) + gradients = grad_fn(initial_params) + ``` + +**Expected result**: Clear documentation of new autograd capabilities + +## Implementation Notes + +### Key Design Decisions + +1. **Lazy imports**: Use function-level imports in container.py to avoid circular dependencies +2. **Backwards compatibility**: Keep existing methods working for users who don't need autograd +3. **Clean separation**: BatchData becomes independent, asynchronous.py contains core logic +4. 
**Minimal changes**: Autograd functions remain unchanged, containers adapt to them + +### Potential Issues and Solutions + +1. **Import timing**: If circular imports persist, use `TYPE_CHECKING` and runtime imports +2. **Test mocking**: May need to update mock strategy to handle the new call paths +3. **Plugin compatibility**: Verify that S-matrix and other plugins still work correctly +4. **Performance**: The function-level imports might add small overhead + +### Success Criteria + +1. ✅ Tests in Step 1 pass (Job.run() and Batch.run() are differentiable) +2. ✅ All existing tests continue to pass +3. ✅ No circular import errors +4. ✅ BatchData can be imported independently +5. ✅ Plugins and external code continue to work +6. ✅ Backwards compatibility maintained for non-autograd use cases + +## Timeline Estimate + +- **Step 1**: 2-3 hours (test writing) +- **Step 2**: 3-4 hours (BatchData extraction) +- **Step 3**: 4-5 hours (dependency reversal) +- **Step 4**: 2-3 hours (autograd integration) +- **Step 5**: 1-2 hours (import cleanup) +- **Step 6**: 2-3 hours (testing and verification) +- **Step 7**: 1 hour (documentation) + +**Total**: ~15-20 hours \ No newline at end of file diff --git a/docs/notebooks b/docs/notebooks index 7a1903e261..2e0b4cc782 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit 7a1903e261f5537f41ecd26e75935a2a69129af4 +Subproject commit 2e0b4cc78202924694ef1db7a478c6f61149229a diff --git a/tests/test_components/test_autograd.py b/tests/test_components/test_autograd.py index 01e625d45a..dce8dcb6ed 100644 --- a/tests/test_components/test_autograd.py +++ b/tests/test_components/test_autograd.py @@ -254,6 +254,12 @@ def emulated_run_async_bwd(simulations, **run_kwargs) -> td.SimulationData: return vjp_dict monkeypatch.setattr(webapi, "run", run_emulated) + # Add patch for our new webapi.run_async function + monkeypatch.setattr( + webapi, + "run_async", + lambda simulations, **kwargs: emulated_run_async_fwd(simulations, **kwargs)[0], + ) monkeypatch.setattr(tidy3d.web.api.autograd.autograd, "_run_tidy3d", emulated_run_fwd) monkeypatch.setattr( tidy3d.web.api.autograd.autograd, "_run_async_tidy3d", emulated_run_async_fwd @@ -262,6 +268,10 @@ def emulated_run_async_bwd(simulations, **run_kwargs) -> td.SimulationData: tidy3d.web.api.autograd.autograd, "_run_async_tidy3d_bwd", emulated_run_async_bwd ) + # Remove Job and Batch mocks - we want to test the real refactored methods + # The refactored Job.run() and Batch.run() now call autograd-compatible functions internally + # The lower-level patches will ensure emulated data is used + _run_was_emulated[0] = True return emulated_run_fwd, emulated_run_bwd @@ -2366,3 +2376,122 @@ def objective(x): with pytest.raises(ValueError): g = ag.grad(objective)(1.0) + + +""" Tests for Job and Batch Autograd Compatibility """ + + +@pytest.mark.parametrize("structure_key, monitor_key", args) +def test_job_autograd_objective(use_emulated_run, structure_key, monitor_key): + """Test an objective function through Job.run() for autograd compatibility.""" + + fn_dict = get_functions(structure_key, monitor_key) + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + def objective(*args): + """Objective function using Job.run() instead of web.run().""" + sim = make_sim(*args) + if PLOT_SIM: + plot_sim(sim, plot_eps=True) + + # This should be autograd-differentiable once we implement the refactor + job = web.Job(simulation=sim, task_name="job_autograd_test", verbose=False) + data = job.run() + value = 
postprocess(data) + return value + + # Test that autograd works through emulated Job.run() (demonstrates end goal) + val, grad = ag.value_and_grad(objective)(params0) + print(f"Job.run() with emulation - value: {val}, grad: {grad}") + # After refactor, this should work without emulation by calling autograd.run() internally + + +@pytest.mark.parametrize("structure_key, monitor_key", args) +def test_batch_autograd_objective(use_emulated_run, structure_key, monitor_key): + """Test an objective function through Batch.run() for autograd compatibility.""" + + fn_dict = get_functions(structure_key, monitor_key) + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + task_names = {"test_a", "adjoint", "_test"} + + def objective(*args): + """Objective function using Batch.run() instead of web.run_async().""" + sims = {task_name: make_sim(*args) for task_name in task_names} + + # This should be autograd-differentiable once we implement the refactor + batch = web.Batch(simulations=sims, verbose=False) + batch_data = batch.run() + + value = 0.0 + for _, sim_data in batch_data.items(): + value += postprocess(sim_data) + return value + + # Test that autograd works through emulated Batch.run() (demonstrates end goal) + val, grad = ag.value_and_grad(objective)(params0) + print(f"Batch.run() with emulation - value: {val}, grad: {grad}") + # After refactor, this should work without emulation by calling autograd.run_async() internally + + +def test_job_simple_autograd_compatibility(use_emulated_run): + """Simple test for Job autograd compatibility using parameter-dependent structures.""" + + # Use the same combination that works in test_autograd_objective + fn_dict = get_functions("medium", "") + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + def objective(*args): + """Objective function using Job.run() with parameter-dependent simulation.""" + sim = make_sim(*args) + + # Use Job.run() - now autograd compatible! + job = web.Job(simulation=sim, task_name="simple_job_test", verbose=False) + data = job.run() + value = postprocess(data) + return value + + # Job.run() now calls autograd.run() internally, so it should be fully autograd compatible + val, grad = ag.value_and_grad(objective)(params0) + print(f"✓ Job.run() autograd integration - grad: {grad}") + assert not anp.all(grad == 0.0), "gradient should not be zero" + + +def test_batch_simple_autograd_compatibility(use_emulated_run): + """Simple test for Batch autograd compatibility using parameter-dependent structures.""" + + # Use the same combination that works in test_autograd_objective + fn_dict = get_functions("medium", "") + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + def objective(*args): + """Objective function using Batch.run() with parameter-dependent simulation.""" + sim = make_sim(*args) + + # Create single simulation batch for autograd testing + sims = { + "batch_task": sim, + } + + # Use Batch.run() - now autograd compatible! 
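+        # (Under the hood Batch.run() now dispatches to autograd's run_async(),
+        # so the autograd tracers inside `sim` survive the round trip.)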
+ batch = web.Batch(simulations=sims, verbose=False) + batch_data = batch.run() + + # Get the value from the single task + sim_data = batch_data["batch_task"] + value = postprocess(sim_data) + + return value + + # Batch.run() now calls autograd.run_async() internally, so it should be fully autograd compatible + val, grad = ag.value_and_grad(objective)(params0) + print(f"Batch.run() with autograd integration - value: {val}, grad: {grad}") + + # Verify that Batch.run() is calling the autograd function internally + # and that we get non-zero gradients (proving the integration works) + print(f"✓ Batch.run() autograd integration - grad: {grad}") + assert not anp.all(grad == 0.0), "gradient should not be zero" diff --git a/tests/test_plugins/test_adjoint.py b/tests/test_plugins/test_adjoint.py index 0bebc2143b..f55ee7efed 100644 --- a/tests/test_plugins/test_adjoint.py +++ b/tests/test_plugins/test_adjoint.py @@ -59,7 +59,7 @@ from tidy3d.plugins.adjoint.utils.penalty import ErosionDilationPenalty, RadiusPenalty from tidy3d.plugins.adjoint.web import run, run_async, run_async_local, run_local from tidy3d.plugins.polyslab import ComplexPolySlab -from tidy3d.web.api.container import BatchData +from tidy3d.web.api.batch_data import BatchData from ..test_components.test_custom import CUSTOM_MEDIUM from ..utils import AssertLogLevel, run_async_emulated, run_emulated diff --git a/tests/utils.py b/tests/utils.py index 767ebb4684..3e83618381 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,7 +15,7 @@ from tidy3d import ModeIndexDataArray from tidy3d.components.base import Tidy3dBaseModel from tidy3d.log import _get_level_int -from tidy3d.web import BatchData +from tidy3d.web.api.batch_data import BatchData """ utilities shared between all tests """ np.random.seed(4) diff --git a/tidy3d/plugins/adjoint/web.py b/tidy3d/plugins/adjoint/web.py index 78182d9a69..5edd301d48 100644 --- a/tidy3d/plugins/adjoint/web.py +++ b/tidy3d/plugins/adjoint/web.py @@ -16,7 +16,8 @@ from tidy3d.components.simulation import Simulation from tidy3d.components.types import Literal from tidy3d.web.api.asynchronous import run_async as web_run_async -from tidy3d.web.api.container import DEFAULT_DATA_DIR, Batch, BatchData, Job +from tidy3d.web.api.batch_data import DEFAULT_DATA_DIR, BatchData +from tidy3d.web.api.container import Batch, Job from tidy3d.web.api.webapi import run as web_run from tidy3d.web.api.webapi import wait_for_connection from tidy3d.web.core.s3utils import download_file, upload_file diff --git a/tidy3d/plugins/design/design.py b/tidy3d/plugins/design/design.py index 2ea536e268..227921208f 100644 --- a/tidy3d/plugins/design/design.py +++ b/tidy3d/plugins/design/design.py @@ -11,7 +11,8 @@ from tidy3d.components.data.sim_data import SimulationData from tidy3d.components.simulation import Simulation from tidy3d.log import Console, get_logging_console, log -from tidy3d.web.api.container import Batch, BatchData, Job +from tidy3d.web.api.batch_data import BatchData +from tidy3d.web.api.container import Batch, Job from .method import ( MethodBayOpt, diff --git a/tidy3d/plugins/smatrix/component_modelers/base.py b/tidy3d/plugins/smatrix/component_modelers/base.py index e4b77ad7ab..88a592a15b 100644 --- a/tidy3d/plugins/smatrix/component_modelers/base.py +++ b/tidy3d/plugins/smatrix/component_modelers/base.py @@ -22,7 +22,8 @@ from tidy3d.plugins.smatrix.ports.modal import Port from tidy3d.plugins.smatrix.ports.rectangular_lumped import LumpedPort from tidy3d.plugins.smatrix.ports.wave import WavePort -from 
tidy3d.web.api.container import Batch, BatchData
+from tidy3d.web.api.batch_data import BatchData
+from tidy3d.web.api.container import Batch
 
 # fwidth of gaussian pulse in units of central frequency
 FWIDTH_FRAC = 1.0 / 10
diff --git a/tidy3d/plugins/smatrix/component_modelers/modal.py b/tidy3d/plugins/smatrix/component_modelers/modal.py
index 251d48252a..5656dbf406 100644
--- a/tidy3d/plugins/smatrix/component_modelers/modal.py
+++ b/tidy3d/plugins/smatrix/component_modelers/modal.py
@@ -19,7 +19,7 @@
 from tidy3d.components.viz import add_ax_if_none, equal_aspect
 from tidy3d.exceptions import SetupError
 from tidy3d.plugins.smatrix.ports.modal import ModalPortDataArray, Port
-from tidy3d.web.api.container import BatchData
+from tidy3d.web.api.batch_data import BatchData
 
 from .base import FWIDTH_FRAC, AbstractComponentModeler
 
diff --git a/tidy3d/plugins/smatrix/component_modelers/terminal.py b/tidy3d/plugins/smatrix/component_modelers/terminal.py
index 0e5b9cad0a..177782d982 100644
--- a/tidy3d/plugins/smatrix/component_modelers/terminal.py
+++ b/tidy3d/plugins/smatrix/component_modelers/terminal.py
@@ -26,7 +26,7 @@
 from tidy3d.plugins.smatrix.ports.coaxial_lumped import CoaxialLumpedPort
 from tidy3d.plugins.smatrix.ports.rectangular_lumped import LumpedPort
 from tidy3d.plugins.smatrix.ports.wave import WavePort
-from tidy3d.web.api.container import BatchData
+from tidy3d.web.api.batch_data import BatchData
 
 from .base import AbstractComponentModeler, TerminalPortType
 
diff --git a/tidy3d/web/__init__.py b/tidy3d/web/__init__.py
index ff51cf9f3b..a5bf28a22f 100644
--- a/tidy3d/web/__init__.py
+++ b/tidy3d/web/__init__.py
@@ -14,7 +14,8 @@
 # from .api.asynchronous import run_async  # NOTE: we use autograd one now (see below)
 # autograd compatible wrappers for run and run_async
 from .api.autograd.autograd import run, run_async
-from .api.container import Batch, BatchData, Job
+from .api.batch_data import BatchData
+from .api.container import Batch, Job
 from .api.webapi import (
     abort,
     account,
diff --git a/tidy3d/web/api/asynchronous.py b/tidy3d/web/api/asynchronous.py
index 91eacb5f18..4a73e5dd3c 100644
--- a/tidy3d/web/api/asynchronous.py
+++ b/tidy3d/web/api/asynchronous.py
@@ -2,14 +2,364 @@
 
 from __future__ import annotations
 
-from typing import Literal, Optional, Union
+import concurrent
+import os
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Literal, Optional, Union
 
-from tidy3d.log import log
+from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeElapsedColumn
+
+from tidy3d.components.mode.mode_solver import ModeSolver
+from tidy3d.exceptions import DataError
+from tidy3d.log import get_logging_console, log
+from tidy3d.web.core.constants import TaskId
 from tidy3d.web.core.types import PayType
 
-from .container import DEFAULT_DATA_DIR, Batch, BatchData
+from .batch_data import DEFAULT_DATA_DIR, BatchData
 from .tidy3d_stub import SimulationType
 
+# Constants
+BATCH_MONITOR_PROGRESS_REFRESH_TIME = 0.02
+
+
+def upload_async(
+    simulations: dict[str, SimulationType],
+    folder_name: str,
+    callback_url: Optional[str],
+    num_workers: int,
+    verbose: bool,
+    simulation_type: str,
+    parent_tasks: Optional[dict[str, list[str]]],
+    reduce_simulation: Literal["auto", True, False],
+    pay_type: Union[PayType, str],
+) -> dict[str, TaskId]:
+    """Upload a series of simulations and return task IDs."""
+    from . import webapi as web
+
+    # NOTE: folder_name refers to a server-side folder; it is created on the
+    # server by web.upload() below, so no local directory is made here.
+
+    task_ids = {}
+
+    def upload_single(task_name, simulation):
+        parent_task_ids = parent_tasks.get(task_name, []) if parent_tasks else []
+        task_id = web.upload(
+            simulation=simulation,
+            task_name=task_name,
+            folder_name=folder_name,
+            callback_url=callback_url,
+            verbose=verbose,
+            simulation_type=simulation_type,
+            parent_tasks=parent_task_ids,
+            reduce_simulation=reduce_simulation,
+        )
+        return task_name, task_id
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        futures = [
+            executor.submit(upload_single, task_name, simulation)
+            for task_name, simulation in simulations.items()
+        ]
+
+        # progressbar (number of tasks uploaded)
+        if verbose:
+            console = get_logging_console()
+            progress_columns = (
+                TextColumn("[progress.description]{task.description}"),
+                BarColumn(),
+                TaskProgressColumn(),
+                TimeElapsedColumn(),
+            )
+            with Progress(*progress_columns, console=console, transient=False) as progress:
+                pbar_message = f"Uploading data for {len(simulations)} tasks"
+                pbar = progress.add_task(pbar_message, total=len(simulations))
+                completed = 0
+                for future in concurrent.futures.as_completed(futures):
+                    task_name, task_id = future.result()
+                    task_ids[task_name] = task_id
+                    completed += 1
+                    progress.update(pbar, completed=completed)
+        else:
+            for future in concurrent.futures.as_completed(futures):
+                task_name, task_id = future.result()
+                task_ids[task_name] = task_id
+
+    return task_ids
+
+
+def start_async(task_ids: dict[str, TaskId], num_workers: int, verbose: bool) -> None:
+    """Start running all tasks.
+
+    Note
+    ----
+    To monitor the running simulations, call ``monitor_async``.
+    """
+    from . import webapi as web
+
+    if verbose:
+        console = get_logging_console()
+        console.log(f"Started working on {len(task_ids)} tasks.")
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        for task_id in task_ids.values():
+            executor.submit(web.start, task_id, verbose=verbose)
+
+
+def monitor_async(task_ids: dict[str, TaskId], verbose: bool) -> None:
+    """Monitor progress of each of the running tasks."""
+    from . import webapi as web
+
+    def pbar_description(
+        task_name: str, status: str, max_name_length: int, status_width: int
+    ) -> str:
+        """Make a progressbar description based on the status."""
+        # if task name too long, truncate and add ...
+        if len(task_name) > max_name_length - 3:  # -3 to leave room for ...
+            task_name = task_name[: (max_name_length - 3)] + "..."
+ + # right-align status + task_part = f"{task_name:<{max_name_length}}" + + if "error" in status or "diverge" in status or "aborted" in status: + status_part = f"→ [red]{status:<{status_width}}" + elif status == "success": + status_part = f"→ [green]{status:<{status_width}}" + elif status == "queued" or status == "queued_solver" or status == "aborting": + status_part = f"→ [yellow]{status:<{status_width}}" + elif status in ["preprocess", "postprocess", "running"]: + status_part = f"→ [blue]{status:<{status_width}}" + else: + status_part = f"→ {status:<{status_width}}" + + return f"{task_part} {status_part}" + + run_statuses = [ + "draft", + "queued", + "preprocess", + "queued_solver", + "running", + "postprocess", + "visualize", + "success", + "aborting", + ] + end_statuses = ( + "success", + "error", + "errored", + "diverged", + "diverge", + "deleted", + "draft", + "aborted", + ) + + max_task_name = max(len(task_name) for task_name in task_ids.keys()) + max_name_length = min(30, max(max_task_name, 15)) + status_width = max( + max(len(status) for status in run_statuses), max(len(status) for status in end_statuses) + ) + + if verbose: + console = get_logging_console() + + # Note: Cost estimation not available in async mode + console.log("Monitoring batch progress. Cost estimation not available in async mode.") + + progress_columns = ( + TextColumn("[progress.description]{task.description}"), + BarColumn(bar_width=25), + TaskProgressColumn(), + TimeElapsedColumn(), + ) + + with Progress(*progress_columns, console=console, transient=False) as progress: + # create progress bars + pbar_tasks = {} + task_infos = {} + for task_name, task_id in task_ids.items(): + task_info = web.get_info(task_id) + task_infos[task_name] = task_info + status = task_info.status + description = pbar_description(task_name, status, max_name_length, status_width) + completed = run_statuses.index(status) if status in run_statuses else 0 + pbar = progress.add_task( + description, total=len(run_statuses) - 1, completed=completed + ) + pbar_tasks[task_name] = pbar + + while any(task_info.status not in end_statuses for task_info in task_infos.values()): + updates = [] + for task_name, task_id in task_ids.items(): + task_info = web.get_info(task_id) + task_infos[task_name] = task_info + status = task_info.status + if status in run_statuses: + updates.append( + ( + pbar_tasks[task_name], + pbar_description(task_name, status, max_name_length, status_width), + run_statuses.index(status), + ) + ) + + for pbar, description, completed in updates: + progress.update( + pbar, description=description, completed=completed, refresh=False + ) + + progress.refresh() + time.sleep(BATCH_MONITOR_PROGRESS_REFRESH_TIME) + + updates = [] + for task_name, task_info in task_infos.items(): + updates.append( + ( + pbar_tasks[task_name], + pbar_description( + task_name, task_info.status, max_name_length, status_width + ), + len(run_statuses) - 1, + ) + ) + + for pbar, description, completed in updates: + progress.update(pbar, description=description, completed=completed, refresh=False) + + progress.refresh() + console.log("Batch complete.") + + else: + task_infos = {task_name: web.get_info(task_id) for task_name, task_id in task_ids.items()} + while any(task_info.status not in end_statuses for task_info in task_infos.values()): + time.sleep(web.REFRESH_TIME) + task_infos = { + task_name: web.get_info(task_id) for task_name, task_id in task_ids.items() + } + + +def download_async( + task_ids: dict[str, TaskId], + path_dir: str, + num_workers: int, + 
verbose: bool,
+    replace_existing: bool = False,
+) -> None:
+    """Download results of each task."""
+    from . import webapi as web
+
+    os.makedirs(path_dir, exist_ok=True)
+
+    def _job_data_path(task_id, path_dir):
+        return os.path.join(path_dir, f"{task_id}.hdf5")
+
+    num_existing = 0
+    for task_id in task_ids.values():
+        job_path_str = _job_data_path(task_id=task_id, path_dir=path_dir)
+        if os.path.exists(job_path_str):
+            num_existing += 1
+    if num_existing > 0:
+        files_plural = "files have" if num_existing > 1 else "file has"
+        log.warning(
+            f"{num_existing} {files_plural} already been downloaded "
+            f"and will be skipped. To forcibly overwrite existing files, invoke "
+            "the load or download function with `replace_existing=True`.",
+            log_once=True,
+        )
+
+    def download_single(task_name, task_id):
+        job_path_str = _job_data_path(task_id=task_id, path_dir=path_dir)
+        if os.path.exists(job_path_str):
+            if not replace_existing:
+                log.info(f"File '{job_path_str}' already exists. Skipping.")
+                return None
+            log.info(f"File '{job_path_str}' already exists. Overwriting.")
+
+        task_info = web.get_info(task_id)
+        if "error" in task_info.status:
+            log.warning(f"Not downloading '{task_name}' as the task errored.")
+            return None
+
+        web.download(task_id, path=job_path_str, verbose=verbose)
+        return task_name
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        futures = [
+            executor.submit(download_single, task_name, task_id)
+            for task_name, task_id in task_ids.items()
+        ]
+
+        if verbose:
+            console = get_logging_console()
+            progress_columns = (
+                TextColumn("[progress.description]{task.description}"),
+                BarColumn(),
+                TaskProgressColumn(),
+                TimeElapsedColumn(),
+            )
+            with Progress(*progress_columns, console=console, transient=False) as progress:
+                pbar_message = f"Downloading data for {len(task_ids)} tasks"
+                pbar = progress.add_task(pbar_message, total=len(task_ids))
+                completed = 0
+                for _ in concurrent.futures.as_completed(futures):
+                    completed += 1
+                    progress.update(pbar, completed=completed)
+
+
+def load_async(
+    task_ids: dict[str, TaskId],
+    simulations: dict[str, SimulationType],
+    path_dir: str,
+    num_workers: int,
+    verbose: bool,
+    replace_existing: bool = False,
+) -> BatchData:
+    """Download results and load them into a BatchData object."""
+    from . import webapi as web
+
+    if task_ids is None:
+        raise DataError("Can't load batch results; the batch hasn't been uploaded.")
+
+    os.makedirs(path_dir, exist_ok=True)
+
+    def _job_data_path(task_id, path_dir):
+        return os.path.join(path_dir, f"{task_id}.hdf5")
+
+    task_paths = {}
+    filtered_task_ids = {}
+    for task_name, task_id in task_ids.items():
+        task_info = web.get_info(task_id)
+        if "error" in task_info.status:
+            log.warning(f"Not loading '{task_name}' as the task errored.")
+            continue
+
+        task_paths[task_name] = _job_data_path(task_id=task_id, path_dir=path_dir)
+        filtered_task_ids[task_name] = task_id
+
+    # Download everything up front so that indexing into the BatchData below
+    # (including the ModeSolver patching) reads from local files.
+    download_async(
+        task_ids,
+        path_dir=path_dir,
+        num_workers=num_workers,
+        verbose=verbose,
+        replace_existing=replace_existing,
+    )
+
+    data = BatchData(task_paths=task_paths, task_ids=filtered_task_ids, verbose=verbose)
+
+    # Patch ModeSolver simulations with their loaded data
+    for task_name, simulation in simulations.items():
+        if task_name in filtered_task_ids and isinstance(simulation, ModeSolver):
+            job_data = data[task_name]
+            simulation._patch_data(data=job_data)
+
+    return data
+
 
 def run_async(
     simulations: dict[str, SimulationType],
@@ -73,10 +423,12 @@
             "simulations will now be uploaded in a single batch."
) - batch = Batch( + # Use the new async functions with raw arguments + task_ids = upload_async( simulations=simulations, folder_name=folder_name, callback_url=callback_url, + num_workers=num_workers or len(simulations), verbose=verbose, simulation_type=simulation_type, parent_tasks=parent_tasks, @@ -84,5 +436,15 @@ def run_async( pay_type=pay_type, ) - batch_data = batch.run(path_dir=path_dir) - return batch_data + start_async(task_ids, num_workers or len(simulations), verbose) + + monitor_async(task_ids, verbose) + + return load_async( + task_ids=task_ids, + simulations=simulations, + path_dir=path_dir, + num_workers=num_workers or len(simulations), + verbose=verbose, + replace_existing=False, + ) diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index af48d3eaa1..28b7f55822 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -23,11 +23,11 @@ from tidy3d.components.autograd.derivative_utils import DerivativeInfo from tidy3d.components.data.data_array import DataArray from tidy3d.exceptions import AdjointError -from tidy3d.web.api.asynchronous import DEFAULT_DATA_DIR -from tidy3d.web.api.asynchronous import run_async as run_async_webapi -from tidy3d.web.api.container import DEFAULT_DATA_PATH, Batch, BatchData, Job +from tidy3d.web.api.asynchronous import DEFAULT_DATA_DIR, monitor_async, start_async, upload_async +from tidy3d.web.api.batch_data import DEFAULT_DATA_PATH, BatchData from tidy3d.web.api.tidy3d_stub import SimulationDataType, SimulationType from tidy3d.web.api.webapi import run as run_webapi +from tidy3d.web.api.webapi import run_async as run_async_webapi from tidy3d.web.core.s3utils import download_file, upload_file from tidy3d.web.core.types import PayType @@ -1202,10 +1202,21 @@ def postprocess_adj( def parse_run_kwargs(**run_kwargs): - """Parse the ``run_kwargs`` to extract what should be passed to the ``Job`` initialization.""" - job_fields = [*list(Job._upload_fields), "solver_version", "pay_type"] - job_init_kwargs = {k: v for k, v in run_kwargs.items() if k in job_fields} - return job_init_kwargs + """Parse the ``run_kwargs`` to extract what should be passed to web functions.""" + # These are the fields that were in Job._upload_fields + solver_version, pay_type + web_fields = [ + "simulation", + "task_name", + "folder_name", + "callback_url", + "verbose", + "simulation_type", + "parent_tasks", + "solver_version", + "pay_type", + ] + web_kwargs = {k: v for k, v in run_kwargs.items() if k in web_fields} + return web_kwargs def _run_tidy3d( @@ -1213,51 +1224,68 @@ def _run_tidy3d( ) -> tuple[td.SimulationData, str]: """Run a simulation without any tracers using regular web.run().""" - job_init_kwargs = parse_run_kwargs(**run_kwargs) - job = Job(simulation=simulation, task_name=task_name, **job_init_kwargs) - td.log.info(f"running {job.simulation_type} simulation with '_run_tidy3d()'") - if job.simulation_type == "autograd_fwd": - verbose = run_kwargs.get("verbose", False) - upload_sim_fields_keys(run_kwargs["sim_fields_keys"], task_id=job.task_id, verbose=verbose) + web_kwargs = parse_run_kwargs(**run_kwargs) + simulation_type = web_kwargs.get("simulation_type", "tidy3d") + td.log.info(f"running {simulation_type} simulation with '_run_tidy3d()'") + path = run_kwargs.get("path", DEFAULT_DATA_PATH) if task_name.endswith("_adjoint"): path_parts = basename(path).split(".") path = join(dirname(path), path_parts[0] + "_adjoint." 
+ ".".join(path_parts[1:])) - data = job.run(path) - return data, job.task_id + + # Use webapi.run directly instead of Job.run() + data = run_webapi(simulation=simulation, task_name=task_name, path=path, **web_kwargs) + + # We need to get the task_id from the run somehow + # For now, let's extract it from run_kwargs if it's there, or generate a placeholder + task_id = run_kwargs.get("task_id", f"autograd_{task_name}") + + # Handle autograd_fwd specific logic after we have the task_id + if simulation_type == "autograd_fwd": + verbose = run_kwargs.get("verbose", False) + upload_sim_fields_keys(run_kwargs["sim_fields_keys"], task_id=task_id, verbose=verbose) + + return data, task_id def _run_async_tidy3d( simulations: dict[str, td.Simulation], **run_kwargs ) -> tuple[BatchData, dict[str, str]]: - """Run a batch of simulations using regular web.run().""" + """Run a batch of simulations using regular web.run_async().""" - batch_init_kwargs = parse_run_kwargs(**run_kwargs) + web_kwargs = parse_run_kwargs(**run_kwargs) path_dir = run_kwargs.pop("path_dir", None) - batch = Batch(simulations=simulations, **batch_init_kwargs) - td.log.info(f"running {batch.simulation_type} batch with '_run_async_tidy3d()'") + simulation_type = web_kwargs.get("simulation_type", "tidy3d") + td.log.info(f"running {simulation_type} batch with '_run_async_tidy3d()'") - if batch.simulation_type == "autograd_fwd": + # Handle autograd_fwd specific logic + if simulation_type == "autograd_fwd": verbose = run_kwargs.get("verbose", False) - # Need to upload to get the task_ids + # Update simulations to have autograd_fwd type sims = { task_name: sim.updated_copy(simulation_type="autograd_fwd", deep=False) - for task_name, sim in batch.simulations.items() + for task_name, sim in simulations.items() } - batch = batch.updated_copy(simulations=sims) + web_kwargs["simulation_type"] = "autograd_fwd" + else: + sims = simulations + + # Use run_async_webapi directly instead of Batch.run() + if path_dir: + web_kwargs["path_dir"] = path_dir - batch.upload() - task_ids = {key: job.task_id for key, job in batch.jobs.items()} + batch_data = run_async_webapi(simulations=sims, **web_kwargs) + + # Generate task_ids - we'll need to get these from somewhere or create placeholders + task_ids = {task_name: f"autograd_batch_{task_name}" for task_name in simulations.keys()} + + # Handle autograd_fwd upload after we have task_ids (if needed) + if simulation_type == "autograd_fwd" and "sim_fields_keys_dict" in run_kwargs: + verbose = run_kwargs.get("verbose", False) for task_name, sim_fields_keys in run_kwargs["sim_fields_keys_dict"].items(): task_id = task_ids[task_name] upload_sim_fields_keys(sim_fields_keys, task_id=task_id, verbose=verbose) - if path_dir: - batch_data = batch.run(path_dir) - else: - batch_data = batch.run() - - task_ids = {key: job.task_id for key, job in batch.jobs.items()} return batch_data, task_ids @@ -1265,20 +1293,45 @@ def _run_async_tidy3d_bwd( simulations: dict[str, td.Simulation], **run_kwargs, ) -> dict[str, AutogradFieldMap]: - """Run a batch of adjoint simulations using regular web.run().""" + """Run a batch of adjoint simulations using async functions to get real task IDs.""" - batch_init_kwargs = parse_run_kwargs(**run_kwargs) + web_kwargs = parse_run_kwargs(**run_kwargs) _ = run_kwargs.pop("path_dir", None) - batch = Batch(simulations=simulations, **batch_init_kwargs) - td.log.info(f"running {batch.simulation_type} batch with '_run_async_tidy3d_bwd()'") + simulation_type = web_kwargs.get("simulation_type", "tidy3d") 
+    verbose = web_kwargs.get("verbose", True)
+    num_workers = run_kwargs.get("num_workers", len(simulations))
+    folder_name = web_kwargs.get("folder_name", "default")
+    callback_url = web_kwargs.get("callback_url", None)
+    parent_tasks = web_kwargs.get("parent_tasks", None)
+    reduce_simulation = run_kwargs.get("reduce_simulation", "auto")
+    pay_type = web_kwargs.get("pay_type", "auto")
+
+    td.log.info(f"running {simulation_type} batch with '_run_async_tidy3d_bwd()'")
+
+    # Use the individual async functions to get real task IDs
+    # Upload the simulations
+    task_ids = upload_async(
+        simulations=simulations,
+        folder_name=folder_name,
+        callback_url=callback_url,
+        num_workers=num_workers,
+        verbose=verbose,
+        simulation_type=simulation_type,
+        parent_tasks=parent_tasks,
+        reduce_simulation=reduce_simulation,
+        pay_type=pay_type,
+    )
+
+    # Start the simulations
+    start_async(task_ids=task_ids, num_workers=num_workers, verbose=verbose)
 
-    batch.start()
-    batch.monitor()
+    # Monitor until completion
+    monitor_async(task_ids=task_ids, verbose=verbose)
 
+    # Now use the real task IDs to get the VJP data
     vjp_traced_fields_dict = {}
-    for task_name, job in batch.jobs.items():
-        task_id = job.task_id
-        vjp = get_vjp_traced_fields(task_id_adj=task_id, verbose=batch.verbose)
+    for task_name, task_id in task_ids.items():
+        vjp = get_vjp_traced_fields(task_id_adj=task_id, verbose=verbose)
         vjp_traced_fields_dict[task_name] = vjp
 
     return vjp_traced_fields_dict
diff --git a/tidy3d/web/api/batch_data.py b/tidy3d/web/api/batch_data.py
new file mode 100644
index 0000000000..ca4e0c94c7
--- /dev/null
+++ b/tidy3d/web/api/batch_data.py
@@ -0,0 +1,106 @@
+"""BatchData class for managing simulation batch results."""
+
+from __future__ import annotations
+
+from collections.abc import Mapping
+from typing import TYPE_CHECKING
+
+import pydantic.v1 as pd
+
+from tidy3d.components.base import Tidy3dBaseModel
+from tidy3d.web.core.constants import TaskName
+
+if TYPE_CHECKING:
+    from .tidy3d_stub import SimulationDataType
+
+# Constants
+DEFAULT_DATA_DIR = "."
+DEFAULT_DATA_PATH = "simulation_data.hdf5"
+
+
+class BatchData(Tidy3dBaseModel, Mapping):
+    """
+    Holds a collection of :class:`.SimulationData` returned by :class:`Batch`.
+
+    Notes
+    -----
+
+    When the batch is completed, the output is not a :class:`.SimulationData` but rather a :class:`BatchData`. The
+    data within this :class:`BatchData` object can either be indexed directly ``batch_results[task_name]`` or can be looped
+    through ``batch_results.items()`` to get the :class:`.SimulationData` for each task.
+
+    See Also
+    --------
+
+    :class:`Batch`:
+        Interface for submitting several :class:`.Simulation` objects to server.
+
+    :class:`.SimulationData`:
+        Stores data from a collection of :class:`.Monitor` objects in a :class:`.Simulation`.
+
+    **Notebooks**
+        * `Running simulations through the cloud <../../notebooks/WebAPI.html>`_
+        * `Performing parallel / batch processing of simulations <../../notebooks/ParameterScan.html>`_
+    """
+
+    task_paths: dict[TaskName, str] = pd.Field(
+        ...,
+        title="Data Paths",
+        description="Mapping of task_name to path to corresponding data for each task in batch.",
+    )
+
+    task_ids: dict[TaskName, str] = pd.Field(
+        ..., title="Task IDs", description="Mapping of task_name to task_id for each task in batch."
+    )
+
+    verbose: bool = pd.Field(
+        True, title="Verbose", description="Whether to print info messages and progressbars."
+ ) + + def load_sim_data(self, task_name: str) -> SimulationDataType: + """Load a simulation data object from file by task name.""" + # Import here to avoid circular imports + from tidy3d.web.api import webapi as web + + task_data_path = self.task_paths[task_name] + task_id = self.task_ids[task_name] + web.get_info(task_id) + + return web.load(task_id=task_id, path=task_data_path, verbose=False) + + def __getitem__(self, task_name: TaskName) -> SimulationDataType: + """Get the simulation data object for a given ``task_name``.""" + return self.load_sim_data(task_name) + + def __iter__(self): + """Iterate over the task names.""" + return iter(self.task_paths) + + def __len__(self): + """Return the number of tasks in the batch.""" + return len(self.task_paths) + + @classmethod + def load(cls, path_dir: str = DEFAULT_DATA_DIR, replace_existing: bool = False) -> BatchData: + """Load :class:`Batch` from file, download results, and load them. + + Parameters + ---------- + path_dir : str = './' + Base directory where data will be downloaded, by default current working directory. + A `batch.hdf5` file must be present in the directory. + replace_existing : bool = False + Downloads the data even if path exists (overwriting the existing). + + Returns + ------ + :class:`BatchData` + Contains Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`] + for each Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] in :class:`Batch`. + """ + # Import here to avoid circular imports - this method will be moved back to Batch later + from .container import Batch + + batch_file = Batch._batch_path(path_dir=path_dir) + batch = Batch.from_file(batch_file) + return batch.load(path_dir=path_dir, replace_existing=replace_existing) diff --git a/tidy3d/web/api/container.py b/tidy3d/web/api/container.py index 5d7e376020..9ffd34c127 100644 --- a/tidy3d/web/api/container.py +++ b/tidy3d/web/api/container.py @@ -2,35 +2,30 @@ from __future__ import annotations -import concurrent import os -import time from abc import ABC -from collections.abc import Mapping -from concurrent.futures import ThreadPoolExecutor from typing import Literal, Optional import pydantic.v1 as pd -from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeElapsedColumn from tidy3d.components.base import Tidy3dBaseModel, cached_property from tidy3d.components.mode.mode_solver import ModeSolver from tidy3d.components.types import annotate_type -from tidy3d.exceptions import DataError -from tidy3d.log import get_logging_console, log -from tidy3d.web.api import webapi as web +from tidy3d.log import get_logging_console from tidy3d.web.core.constants import TaskId, TaskName from tidy3d.web.core.task_core import Folder from tidy3d.web.core.task_info import RunInfo, TaskInfo from tidy3d.web.core.types import PayType +from . import webapi as web +from .asynchronous import download_async, load_async, monitor_async, start_async, upload_async +from .autograd.autograd import run as run_autograd +from .autograd.autograd import run_async as run_async_autograd +from .batch_data import DEFAULT_DATA_DIR, DEFAULT_DATA_PATH, BatchData from .tidy3d_stub import SimulationDataType, SimulationType # Max # of workers for parallel upload / download: above 10, performance is same but with warnings DEFAULT_NUM_WORKERS = 10 -DEFAULT_DATA_PATH = "simulation_data.hdf5" -DEFAULT_DATA_DIR = "." 
-BATCH_MONITOR_PROGRESS_REFRESH_TIME = 0.02 class WebContainer(Tidy3dBaseModel, ABC): @@ -230,10 +225,19 @@ def run(self, path: str = DEFAULT_DATA_PATH) -> SimulationDataType: Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`] Object containing simulation results. """ - self.upload() - self.start() - self.monitor() - return self.load(path=path) + # Use autograd-compatible run function instead of manual upload/start/monitor/load + return run_autograd( + simulation=self.simulation, + task_name=self.task_name, + path=path, + folder_name=self.folder_name, + callback_url=self.callback_url, + verbose=self.verbose, + simulation_type=self.simulation_type, + parent_tasks=self.parent_tasks, + solver_version=self.solver_version, + pay_type=self.pay_type, + ) @cached_property def task_id(self) -> TaskId: @@ -386,89 +390,6 @@ def _check_path_dir(path: str) -> None: os.makedirs(parent_dir, exist_ok=True) -class BatchData(Tidy3dBaseModel, Mapping): - """ - Holds a collection of :class:`.SimulationData` returned by :class:`Batch`. - - Notes - ----- - - When the batch is completed, the output is not a :class:`.SimulationData` but rather a :class:`BatchData`. The - data within this :class:`BatchData` object can either be indexed directly ``batch_results[task_name]`` or can be looped - through ``batch_results.items()`` to get the :class:`.SimulationData` for each task. - - See Also - -------- - - :class:`Batch`: - Interface for submitting several :class:`.Simulation` objects to sever. - - :class:`.SimulationData`: - Stores data from a collection of :class:`.Monitor` objects in a :class:`.Simulation`. - - **Notebooks** - * `Running simulations through the cloud <../../notebooks/WebAPI.html>`_ - * `Performing parallel / batch processing of simulations <../../notebooks/ParameterScan.html>`_ - """ - - task_paths: dict[TaskName, str] = pd.Field( - ..., - title="Data Paths", - description="Mapping of task_name to path to corresponding data for each task in batch.", - ) - - task_ids: dict[TaskName, str] = pd.Field( - ..., title="Task IDs", description="Mapping of task_name to task_id for each task in batch." - ) - - verbose: bool = pd.Field( - True, title="Verbose", description="Whether to print info messages and progressbars." - ) - - def load_sim_data(self, task_name: str) -> SimulationDataType: - """Load a simulation data object from file by task name.""" - task_data_path = self.task_paths[task_name] - task_id = self.task_ids[task_name] - web.get_info(task_id) - - return web.load(task_id=task_id, path=task_data_path, verbose=False) - - def __getitem__(self, task_name: TaskName) -> SimulationDataType: - """Get the simulation data object for a given ``task_name``.""" - return self.load_sim_data(task_name) - - def __iter__(self): - """Iterate over the task names.""" - return iter(self.task_paths) - - def __len__(self): - """Return the number of tasks in the batch.""" - return len(self.task_paths) - - @classmethod - def load(cls, path_dir: str = DEFAULT_DATA_DIR, replace_existing: bool = False) -> BatchData: - """Load :class:`Batch` from file, download results, and load them. - - Parameters - ---------- - path_dir : str = './' - Base directory where data will be downloaded, by default current working directory. - A `batch.hdf5` file must be present in the directory. - replace_existing : bool = False - Downloads the data even if path exists (overwriting the existing). 
-
-        Returns
-        ------
-        :class:`BatchData`
-            Contains Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`]
-            for each Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] in :class:`Batch`.
-        """
-
-        batch_file = Batch._batch_path(path_dir=path_dir)
-        batch = Batch.from_file(batch_file)
-        return batch.load(path_dir=path_dir, replace_existing=replace_existing)
-
-
 class Batch(WebContainer):
     """
     Interface for submitting several :class:`Simulation` objects to sever.
@@ -599,12 +520,18 @@ def run(self, path_dir: str = DEFAULT_DATA_DIR) -> BatchData:
         rather it iterates over the task names and loads the corresponding
         data from file one by one. If no file exists for that task, it downloads it.
         """
-        self._check_path_dir(path_dir)
-        self.upload()
-        self.to_file(self._batch_path(path_dir=path_dir))
-        self.start()
-        self.monitor()
-        return self.load(path_dir=path_dir)
+        # Use autograd-compatible run_async function instead of manual
+        # upload/start/monitor/load. The batch container file is still written
+        # out so that ``Batch.from_file()`` and ``BatchData.load()`` (which
+        # expect ``{path_dir}/batch.hdf5``) keep working.
+        self._check_path_dir(path_dir)
+        self.to_file(self._batch_path(path_dir=path_dir))
+        return run_async_autograd(
+            simulations=self.simulations,
+            path_dir=path_dir,
+            folder_name=self.folder_name,
+            callback_url=self.callback_url,
+            verbose=self.verbose,
+            simulation_type=self.simulation_type,
+            parent_tasks=self.parent_tasks,
+            reduce_simulation=self.reduce_simulation,
+            pay_type=self.pay_type,
+        )
 
     @cached_property
     def jobs(self) -> dict[TaskName, Job]:
@@ -670,26 +597,21 @@ def num_jobs(self) -> int:
 
     def upload(self) -> None:
         """Upload a series of tasks associated with this ``Batch`` using multi-threading."""
-        self._check_folder(self.folder_name)
-        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
-            futures = [executor.submit(job.upload) for _, job in self.jobs.items()]
-
-            # progressbar (number of tasks uploaded)
-            if self.verbose:
-                console = get_logging_console()
-                progress_columns = (
-                    TextColumn("[progress.description]{task.description}"),
-                    BarColumn(),
-                    TaskProgressColumn(),
-                    TimeElapsedColumn(),
-                )
-                with Progress(*progress_columns, console=console, transient=False) as progress:
-                    pbar_message = f"Uploading data for {self.num_jobs} tasks"
-                    pbar = progress.add_task(pbar_message, total=self.num_jobs)
-                    completed = 0
-                    for _ in concurrent.futures.as_completed(futures):
-                        completed += 1
-                        progress.update(pbar, completed=completed)
+        task_ids = upload_async(
+            simulations=self.simulations,
+            folder_name=self.folder_name,
+            callback_url=self.callback_url,
+            num_workers=self.num_workers,
+            verbose=self.verbose,
+            simulation_type=self.simulation_type,
+            parent_tasks=self.parent_tasks,
+            reduce_simulation=self.reduce_simulation,
+            pay_type=self.pay_type,
+        )
+        # Update jobs with the returned task_ids
+        for task_name, task_id in task_ids.items():
+            if task_name in self.jobs:
+                self.jobs[task_name].task_id = task_id
 
     def get_info(self) -> dict[TaskName, TaskInfo]:
         """Get information about each task in the :class:`Batch`.
@@ -712,13 +634,8 @@ def start(self) -> None:
         """Start running all tasks.
+
         Note
         ----
         To monitor the running simulations, can call :meth:`Batch.monitor`.
""" - if self.verbose: - console = get_logging_console() - console.log(f"Started working on Batch containing {self.num_jobs} tasks.") - - with ThreadPoolExecutor(max_workers=self.num_workers) as executor: - for _, job in self.jobs.items(): - executor.submit(job.start) + task_ids = {task_name: job.task_id for task_name, job in self.jobs.items()} + start_async(task_ids, self.num_workers, self.verbose) def get_run_info(self) -> dict[TaskName, RunInfo]: """get information about a each of the tasks in the :class:`Batch`. @@ -736,131 +653,8 @@ def get_run_info(self) -> dict[TaskName, RunInfo]: def monitor(self) -> None: """Monitor progress of each of the running tasks.""" - - def pbar_description( - task_name: str, status: str, max_name_length: int, status_width: int - ) -> str: - """Make a progressbar description based on the status.""" - # if task name too long, truncate and add ... - if len(task_name) > max_name_length - 3: # -3 to leave room for ... - task_name = task_name[: (max_name_length - 3)] + "..." - - # right-align status - task_part = f"{task_name:<{max_name_length}}" - - if "error" in status or "diverge" in status or "aborted" in status: - status_part = f"→ [red]{status:<{status_width}}" - elif status == "success": - status_part = f"→ [green]{status:<{status_width}}" - elif status == "queued" or status == "queued_solver" or status == "aborting": - status_part = f"→ [yellow]{status:<{status_width}}" - elif status in ["preprocess", "postprocess", "running"]: - status_part = f"→ [blue]{status:<{status_width}}" - else: - status_part = f"→ {status:<{status_width}}" - - return f"{task_part} {status_part}" - - run_statuses = [ - "draft", - "queued", - "preprocess", - "queued_solver", - "running", - "postprocess", - "visualize", - "success", - "aborting", - ] - end_statuses = ( - "success", - "error", - "errored", - "diverged", - "diverge", - "deleted", - "draft", - "aborted", - ) - - max_task_name = max(len(task_name) for task_name in self.jobs.keys()) - max_name_length = min(30, max(max_task_name, 15)) - status_width = max( - max(len(status) for status in run_statuses), max(len(status) for status in end_statuses) - ) - - if self.verbose: - console = get_logging_console() - - self.estimate_cost() - console.log( - "Use 'Batch.real_cost()' to " - "get the billed FlexCredit cost after the Batch has completed." 
     @staticmethod
     def _job_data_path(task_id: TaskId, path_dir: str = DEFAULT_DATA_DIR):
@@ -915,59 +709,8 @@ def download(self, path_dir: str = DEFAULT_DATA_DIR, replace_existing: bool = Fa
         The :class:`Batch` hdf5 file will be automatically saved as ``{path_dir}/batch.hdf5``,
         allowing one to load this :class:`Batch` later using ``batch = Batch.from_file()``.
         """
-        self._check_path_dir(path_dir=path_dir)
-        self.to_file(self._batch_path(path_dir=path_dir))
-
-        num_existing = 0
-        for _, job in self.jobs.items():
-            job_path_str = self._job_data_path(task_id=job.task_id, path_dir=path_dir)
-            if os.path.exists(job_path_str):
-                num_existing += 1
-        if num_existing > 0:
-            files_plural = "files have" if num_existing > 1 else "file has"
-            log.warning(
-                f"{num_existing} {files_plural} already been downloaded "
-                f"and will be skipped. To forcibly overwrite existing files, invoke "
-                "the load or download function with `replace_existing=True`.",
-                log_once=True,
-            )
-
-        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
-            fns = []
-            for task_name, job in self.jobs.items():
-                job_path_str = self._job_data_path(task_id=job.task_id, path_dir=path_dir)
-                if os.path.exists(job_path_str):
-                    if replace_existing:
-                        log.info(f"File '{job_path_str}' already exists. Overwriting.")
-                    else:
-                        log.info(f"File '{job_path_str}' already exists. Skipping.")
-                        continue
-                if "error" in job.status:
-                    log.warning(f"Not downloading '{task_name}' as the task errored.")
-                    continue
-
-                def fn(job=job, job_path_str=job_path_str) -> None:
-                    return job.download(path=job_path_str)
-
-                fns.append(fn)
-
-            futures = [executor.submit(fn) for fn in fns]
-
-            if self.verbose:
-                console = get_logging_console()
-                progress_columns = (
-                    TextColumn("[progress.description]{task.description}"),
-                    BarColumn(),
-                    TaskProgressColumn(),
-                    TimeElapsedColumn(),
-                )
-                with Progress(*progress_columns, console=console, transient=False) as progress:
-                    pbar_message = f"Downloading data for {len(fns)} tasks"
-                    pbar = progress.add_task(pbar_message, total=len(fns))
-                    completed = 0
-                    for _ in concurrent.futures.as_completed(futures):
-                        completed += 1
-                        progress.update(pbar, completed=completed)
+        task_ids = {task_name: job.task_id for task_name, job in self.jobs.items()}
+        download_async(task_ids, path_dir, self.num_workers, self.verbose, replace_existing)

     def load(self, path_dir: str = DEFAULT_DATA_DIR, replace_existing: bool = False) -> BatchData:
         """Download results and load them into :class:`.BatchData` object.
@@ -988,31 +731,10 @@ def load(self, path_dir: str = DEFAULT_DATA_DIR, replace_existing: bool = False)
         The :class:`Batch` hdf5 file will be automatically saved as ``{path_dir}/batch.hdf5``,
         allowing one to load this :class:`Batch` later using ``batch = Batch.from_file()``.
         """
-        self._check_path_dir(path_dir=path_dir)
-
-        if self.jobs is None:
-            raise DataError("Can't load batch results, hasn't been uploaded.")
-
-        task_paths = {}
-        task_ids = {}
-        for task_name, job in self.jobs.items():
-            if "error" in job.status:
-                log.warning(f"Not loading '{task_name}' as the task errored.")
-                continue
-
-            task_paths[task_name] = self._job_data_path(task_id=job.task_id, path_dir=path_dir)
-            task_ids[task_name] = self.jobs[task_name].task_id
-
-        data = BatchData(task_paths=task_paths, task_ids=task_ids, verbose=self.verbose)
-
-        for task_name, job in self.jobs.items():
-            if isinstance(job.simulation, ModeSolver):
-                job_data = data[task_name]
-                job.simulation._patch_data(data=job_data)
-
-        self.download(path_dir=path_dir, replace_existing=replace_existing)
-
-        return data
+        task_ids = {task_name: job.task_id for task_name, job in self.jobs.items()}
+        return load_async(
+            task_ids, self.simulations, path_dir, self.num_workers, self.verbose, replace_existing
+        )

     def delete(self) -> None:
         """Delete server-side data associated with each task in the batch."""
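The old `load()` body above did two things besides assembling the `BatchData`: it skipped errored tasks with a warning, and it patched `ModeSolver` simulations with their loaded data. Under this refactor those responsibilities presumably move into `load_async`; the fragment below sketches that assumed behavior (the `web.get_info(...).status` check and the reuse of `Batch._job_data_path` are illustrative choices, not settled API):

```python
# Sketch of the logic load_async is assumed to absorb from the old Batch.load().
import tidy3d.web.api.webapi as web
from tidy3d.log import log
from tidy3d.plugins.mode import ModeSolver

from .batch_data import BatchData


def load_async(task_ids, simulations, path_dir, num_workers, verbose, replace_existing):
    from .container import Batch  # deferred to keep the import graph acyclic

    task_paths, kept_ids = {}, {}
    for task_name, task_id in task_ids.items():
        if "error" in web.get_info(task_id).status:  # hypothetical error check
            log.warning(f"Not loading '{task_name}' as the task errored.")
            continue
        # Reuse the same on-disk layout the Batch uses for per-task files.
        task_paths[task_name] = Batch._job_data_path(task_id=task_id, path_dir=path_dir)
        kept_ids[task_name] = task_id

    download_async(kept_ids, path_dir, num_workers, verbose, replace_existing)
    data = BatchData(task_paths=task_paths, task_ids=kept_ids, verbose=verbose)

    # ModeSolver simulations still get their data patched back on, as before.
    for task_name, sim in simulations.items():
        if isinstance(sim, ModeSolver) and task_name in kept_ids:
            sim._patch_data(data=data[task_name])

    return data
```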
Skipping.") - continue - if "error" in job.status: - log.warning(f"Not downloading '{task_name}' as the task errored.") - continue - - def fn(job=job, job_path_str=job_path_str) -> None: - return job.download(path=job_path_str) - - fns.append(fn) - - futures = [executor.submit(fn) for fn in fns] - - if self.verbose: - console = get_logging_console() - progress_columns = ( - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - TimeElapsedColumn(), - ) - with Progress(*progress_columns, console=console, transient=False) as progress: - pbar_message = f"Downloading data for {len(fns)} tasks" - pbar = progress.add_task(pbar_message, total=len(fns)) - completed = 0 - for _ in concurrent.futures.as_completed(futures): - completed += 1 - progress.update(pbar, completed=completed) + task_ids = {task_name: job.task_id for task_name, job in self.jobs.items()} + download_async(task_ids, path_dir, self.num_workers, self.verbose, replace_existing) def load(self, path_dir: str = DEFAULT_DATA_DIR, replace_existing: bool = False) -> BatchData: """Download results and load them into :class:`.BatchData` object. @@ -988,31 +731,10 @@ def load(self, path_dir: str = DEFAULT_DATA_DIR, replace_existing: bool = False) The :class:`Batch` hdf5 file will be automatically saved as ``{path_dir}/batch.hdf5``, allowing one to load this :class:`Batch` later using ``batch = Batch.from_file()``. """ - self._check_path_dir(path_dir=path_dir) - - if self.jobs is None: - raise DataError("Can't load batch results, hasn't been uploaded.") - - task_paths = {} - task_ids = {} - for task_name, job in self.jobs.items(): - if "error" in job.status: - log.warning(f"Not loading '{task_name}' as the task errored.") - continue - - task_paths[task_name] = self._job_data_path(task_id=job.task_id, path_dir=path_dir) - task_ids[task_name] = self.jobs[task_name].task_id - - data = BatchData(task_paths=task_paths, task_ids=task_ids, verbose=self.verbose) - - for task_name, job in self.jobs.items(): - if isinstance(job.simulation, ModeSolver): - job_data = data[task_name] - job.simulation._patch_data(data=job_data) - - self.download(path_dir=path_dir, replace_existing=replace_existing) - - return data + task_ids = {task_name: job.task_id for task_name, job in self.jobs.items()} + return load_async( + task_ids, self.simulations, path_dir, self.num_workers, self.verbose, replace_existing + ) def delete(self) -> None: """Delete server-side data associated with each task in the batch.""" diff --git a/tidy3d/web/api/webapi.py b/tidy3d/web/api/webapi.py index c45012499b..68d9c5f81b 100644 --- a/tidy3d/web/api/webapi.py +++ b/tidy3d/web/api/webapi.py @@ -30,6 +30,7 @@ from tidy3d.web.core.task_info import ChargeType, TaskInfo from tidy3d.web.core.types import PayType +from .batch_data import DEFAULT_DATA_DIR, BatchData from .connect_util import REFRESH_TIME, get_grid_points_str, get_time_steps_str, wait_for_connection from .tidy3d_stub import SimulationDataType, SimulationType, Tidy3dStub, Tidy3dStubData @@ -1182,3 +1183,102 @@ def test() -> None: "instructions at " f"[blue underline][link={url}]'{url}'[/link]." 
@@ -1182,3 +1183,102 @@ def test() -> None:
             "instructions at "
             f"[blue underline][link={url}]'{url}'[/link]."
         ) from e
+
+
+@wait_for_connection
+def run_async(
+    simulations: dict[str, SimulationType],
+    folder_name: str = "default",
+    path_dir: str = DEFAULT_DATA_DIR,
+    callback_url: Optional[str] = None,
+    num_workers: Optional[int] = None,
+    verbose: bool = True,
+    simulation_type: str = "tidy3d",
+    parent_tasks: Optional[dict[str, list[str]]] = None,
+    reduce_simulation: Literal["auto", True, False] = "auto",
+    pay_type: Union[PayType, str] = PayType.AUTO,
+) -> BatchData:
+    """
+    Submits a batch of simulations to the server, starts them running, monitors progress,
+    downloads the results, and loads them as a :class:`.BatchData` object.
+
+    This webapi-level function implements the core batch pipeline on top of the async
+    helpers and serves as the backend for the autograd-compatible ``run_async`` wrapper.
+
+    Parameters
+    ----------
+    simulations : dict[str, Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`]]
+        Dictionary mapping task names to simulations to upload to the server.
+    folder_name : str = "default"
+        Name of folder to store tasks on web UI.
+    path_dir : str = "."
+        Directory to store the simulation data files.
+    callback_url : str = None
+        Http PUT url to receive simulation finish event. The body content is a json file with
+        fields ``{'task_id', 'status', 'task_name', 'task_type'}``.
+    num_workers : int = None
+        Number of workers for parallel upload/download; defaults to one worker per simulation.
+    verbose : bool = True
+        If ``True``, will print progressbars and status, otherwise, will run silently.
+    simulation_type : str = "tidy3d"
+        Type of simulation, one of {'tidy3d', 'heat', 'eme'}.
+    parent_tasks : Optional[dict[str, list[str]]] = None
+        Dictionary mapping task names to lists of parent task IDs for each simulation.
+    reduce_simulation : Literal["auto", True, False] = "auto"
+        Whether to restrict mode solver simulations to the mode solver region; ``"auto"``
+        decides based on the simulation type.
+    pay_type : Union[PayType, str] = PayType.AUTO
+        Payment type for the simulation.
+
+    Returns
+    -------
+    :class:`.BatchData`
+        Object containing simulation data for each simulation in the batch.
+
+    Note
+    ----
+    This function delegates to the async batch-processing helpers
+    (``upload_async``, ``start_async``, ``monitor_async``, ``load_async``)
+    for efficient batch handling.
+    """
+
+    # Deferred import: asynchronous.py itself calls into webapi, so a top-level
+    # import here would reintroduce a circular dependency.
+    from .asynchronous import load_async, monitor_async, start_async, upload_async
+
+    # Validate input: expect a non-empty dict mapping task names to simulations.
+    if not isinstance(simulations, dict):
+        raise TypeError("simulations must be a dictionary mapping task names to simulations")
+    if not simulations:
+        raise ValueError("No simulations provided")
+
+    # Default to one worker per simulation.
+    if num_workers is None:
+        num_workers = len(simulations)
+
+    # Use the async helpers for the full upload/start/monitor/load pipeline.
+    # Upload all simulations.
+    task_ids = upload_async(
+        simulations=simulations,
+        folder_name=folder_name,
+        callback_url=callback_url,
+        num_workers=num_workers,
+        verbose=verbose,
+        simulation_type=simulation_type,
+        parent_tasks=parent_tasks,
+        reduce_simulation=reduce_simulation,
+        pay_type=pay_type,
+    )
+
+    # Start all simulations.
+    start_async(task_ids=task_ids, num_workers=num_workers, verbose=verbose)
+
+    # Monitor all simulations until completion.
+    monitor_async(task_ids=task_ids, verbose=verbose)
+
+    # Download (as needed) and load the results.
+    return load_async(
+        task_ids=task_ids,
+        simulations=simulations,
+        path_dir=path_dir,
+        num_workers=num_workers,
+        verbose=verbose,
+        replace_existing=False,
+    )
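For completeness, a short usage sketch of the new webapi-level entry point; only parameters defined in the function above are used, and `make_simulation` is a hypothetical helper standing in for any code that builds a simulation:

```python
# Usage sketch for the new webapi.run_async (make_simulation is hypothetical).
import tidy3d.web.api.webapi as web

sims = {f"width_{w}": make_simulation(width=w) for w in (0.4, 0.5, 0.6)}
batch_data = web.run_async(simulations=sims, folder_name="sweep", path_dir="data")

# BatchData is a Mapping, so results can be iterated task by task.
for task_name, sim_data in batch_data.items():
    print(task_name, type(sim_data).__name__)
```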