3
3
from __future__ import annotations
4
4
5
5
import os
6
- import typing
7
6
from abc import ABC , abstractmethod
8
7
from typing import Optional , Union , get_args
9
8
23
22
from tidy3d .plugins .smatrix .ports .modal import Port
24
23
from tidy3d .plugins .smatrix .ports .rectangular_lumped import LumpedPort
25
24
from tidy3d .plugins .smatrix .ports .wave import WavePort
25
+ from tidy3d .web import Batch , BatchData
26
26
27
27
# fwidth of gaussian pulse in units of central frequency
28
28
FWIDTH_FRAC = 1.0 / 10
@@ -105,7 +105,7 @@ class AbstractComponentModeler(ABC, Tidy3dBaseModel):
105
105
)
106
106
107
107
# TODO properly refactor, plugins should not have web methods.
108
- batch_cached : typing . Any = pd .Field (
108
+ batch_cached : BatchData = pd .Field (
109
109
None ,
110
110
title = "Batch (Cached)" ,
111
111
description = "DEPRECATED: Optional field to specify ``batch``. Only used as a workaround internally "
@@ -182,7 +182,7 @@ def to_file(self, fname: str) -> None:
182
182
super (AbstractComponentModeler , self ).to_file (fname = fname ) # noqa: UP008
183
183
184
184
@cached_property
185
- def batch (self ):
185
+ def batch (self ) -> Batch :
186
186
""":class:`.Batch` associated with this component modeler."""
187
187
# TODO properly refactor, plugins data types should not have web methods.
188
188
from tidy3d .web .api .container import Batch
@@ -210,7 +210,7 @@ def batch_path(self) -> str:
210
210
return self .batch ._batch_path (path_dir = self .path_dir )
211
211
212
212
@cached_property
213
- def batch_data (self ):
213
+ def batch_data (self ) -> BatchData :
214
214
"""The :class:`.BatchData` associated with the simulations run for this component modeler."""
215
215
return self .batch .run (path_dir = self .path_dir )
216
216
@@ -232,7 +232,7 @@ def _batch_path(self) -> str:
232
232
"""Where we store the batch for this :class:`AbstractComponentModeler` instance after the run."""
233
233
return os .path .join (self .path_dir , "batch" + str (hash (self )) + ".hdf5" )
234
234
235
- def _run_sims (self , path_dir : str = DEFAULT_DATA_DIR ):
235
+ def _run_sims (self , path_dir : str = DEFAULT_DATA_DIR ) -> BatchData :
236
236
"""Run :class:`.Simulation` for each port and return the batch after saving."""
237
237
_ = self .get_path_dir (path_dir )
238
238
self .batch .to_file (self ._batch_path )
@@ -247,11 +247,11 @@ def get_port_by_name(self, port_name: str) -> Port:
247
247
return ports [0 ]
248
248
249
249
@abstractmethod
250
- def _construct_smatrix (self , batch_data ) -> DataArray :
250
+ def _construct_smatrix (self , batch_data : BatchData ) -> DataArray :
251
251
"""Post process :class:`.BatchData` to generate scattering matrix."""
252
252
253
253
@abstractmethod
254
- def _internal_construct_smatrix (self , batch_data ) -> DataArray :
254
+ def _internal_construct_smatrix (self , batch_data : BatchData ) -> DataArray :
255
255
"""Post process :class:`.BatchData` to generate scattering matrix, for internal use only."""
256
256
257
257
def run (self , path_dir : str = DEFAULT_DATA_DIR ) -> DataArray :
@@ -324,3 +324,6 @@ def sim_data_by_task_name(self, task_name: str) -> SimulationData:
324
324
sim_data = self .batch_data [task_name ]
325
325
config .logging_level = log_level_cache
326
326
return sim_data
327
+
328
+
329
+ AbstractComponentModeler .update_forward_refs ()
0 commit comments