Skip to content

Commit 30bd56b

Browse files
yaugenst-flextylerflex
authored andcommitted
Make BatchData a mapping
1 parent 30d0f97 commit 30bd56b

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Changed
11+
- `BatchData` is now a mapping and can be accessed and iterated over like a Python dictionary (`.keys()`, `.values()`, `.items()`).
12+
1013
### Fixed
1114
- Gradient inaccuracy when a multi-frequency monitor is used but a single frequency is selected.
1215
- Revert single cell center approximation for custom medium gradient.

tidy3d/web/api/container.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import time
88
from abc import ABC
9+
from collections.abc import Mapping
910
from concurrent.futures import ThreadPoolExecutor
1011
from typing import Dict, Optional, Tuple
1112

@@ -339,7 +340,7 @@ def estimate_cost(self, verbose: bool = True) -> float:
339340
return web.estimate_cost(self.task_id, verbose=verbose, solver_version=self.solver_version)
340341

341342

342-
class BatchData(Tidy3dBaseModel):
343+
class BatchData(Tidy3dBaseModel, Mapping):
343344
"""
344345
Holds a collection of :class:`.SimulationData` returned by :class:`Batch`.
345346
@@ -391,16 +392,18 @@ def load_sim_data(self, task_name: str) -> SimulationDataType:
391392
verbose=False,
392393
)
393394

394-
def items(self) -> Tuple[TaskName, SimulationDataType]:
395-
"""Iterate through the simulations for each task_name."""
396-
397-
for task_name in self.task_paths.keys():
398-
yield task_name, self.load_sim_data(task_name)
399-
400395
def __getitem__(self, task_name: TaskName) -> SimulationDataType:
401396
"""Get the simulation data object for a given ``task_name``."""
402397
return self.load_sim_data(task_name)
403398

399+
def __iter__(self):
400+
"""Iterate over the task names."""
401+
return iter(self.task_paths)
402+
403+
def __len__(self):
404+
"""Return the number of tasks in the batch."""
405+
return len(self.task_paths)
406+
404407
@classmethod
405408
def load(cls, path_dir: str = DEFAULT_DATA_DIR) -> BatchData:
406409
"""Load :class:`Batch` from file, download results, and load them.

0 commit comments

Comments
 (0)