|
6 | 6 | import os
|
7 | 7 | import time
|
8 | 8 | from abc import ABC
|
| 9 | +from collections.abc import Mapping |
9 | 10 | from concurrent.futures import ThreadPoolExecutor
|
10 | 11 | from typing import Dict, Optional, Tuple
|
11 | 12 |
|
@@ -339,7 +340,7 @@ def estimate_cost(self, verbose: bool = True) -> float:
|
339 | 340 | return web.estimate_cost(self.task_id, verbose=verbose, solver_version=self.solver_version)
|
340 | 341 |
|
341 | 342 |
|
342 |
| -class BatchData(Tidy3dBaseModel): |
| 343 | +class BatchData(Tidy3dBaseModel, Mapping): |
343 | 344 | """
|
344 | 345 | Holds a collection of :class:`.SimulationData` returned by :class:`Batch`.
|
345 | 346 |
|
@@ -391,16 +392,18 @@ def load_sim_data(self, task_name: str) -> SimulationDataType:
|
391 | 392 | verbose=False,
|
392 | 393 | )
|
393 | 394 |
|
394 |
| - def items(self) -> Tuple[TaskName, SimulationDataType]: |
395 |
| - """Iterate through the simulations for each task_name.""" |
396 |
| - |
397 |
| - for task_name in self.task_paths.keys(): |
398 |
| - yield task_name, self.load_sim_data(task_name) |
399 |
| - |
400 | 395 | def __getitem__(self, task_name: TaskName) -> SimulationDataType:
|
401 | 396 | """Get the simulation data object for a given ``task_name``."""
|
402 | 397 | return self.load_sim_data(task_name)
|
403 | 398 |
|
| 399 | + def __iter__(self): |
| 400 | + """Iterate over the task names.""" |
| 401 | + return iter(self.task_paths) |
| 402 | + |
| 403 | + def __len__(self): |
| 404 | + """Return the number of tasks in the batch.""" |
| 405 | + return len(self.task_paths) |
| 406 | + |
404 | 407 | @classmethod
|
405 | 408 | def load(cls, path_dir: str = DEFAULT_DATA_DIR) -> BatchData:
|
406 | 409 | """Load :class:`Batch` from file, download results, and load them.
|
|
0 commit comments