6464 RunSettings ,
6565 SrunSettings ,
6666)
67+ from collections .abc import Callable , Collection
6768
6869logger = get_logger (__name__ )
6970
7980test_alloc_specs_path = os .getenv ("SMARTSIM_TEST_ALLOC_SPEC_SHEET_PATH" , None )
8081test_ports = CONFIG .test_ports
8182test_account = CONFIG .test_account or ""
82- test_batch_resources : t . Dict [t .Any , t .Any ] = CONFIG .test_batch_resources
83+ test_batch_resources : dict [t .Any , t .Any ] = CONFIG .test_batch_resources
8384test_output_dirs = 0
8485mpi_app_exe = None
8586built_mpi_app = False
@@ -169,7 +170,7 @@ def pytest_sessionfinish(
169170 kill_all_test_spawned_processes ()
170171
171172
172- def build_mpi_app () -> t . Optional [ pathlib .Path ] :
173+ def build_mpi_app () -> pathlib .Path | None :
173174 global built_mpi_app
174175 built_mpi_app = True
175176 cc = shutil .which ("cc" )
@@ -190,7 +191,7 @@ def build_mpi_app() -> t.Optional[pathlib.Path]:
190191 return None
191192
192193@pytest .fixture (scope = "session" )
193- def mpi_app_path () -> t . Optional [ pathlib .Path ] :
194+ def mpi_app_path () -> pathlib .Path | None :
194195 """Return path to MPI app if it was built
195196
196197 return None if it could not or will not be built
@@ -223,7 +224,7 @@ def kill_all_test_spawned_processes() -> None:
223224
224225
225226
226- def get_hostlist () -> t . Optional [ t . List [ str ]] :
227+ def get_hostlist () -> list [ str ] | None :
227228 global test_hostlist
228229 if not test_hostlist :
229230 if "PBS_NODEFILE" in os .environ and test_launcher == "pals" :
@@ -251,14 +252,14 @@ def get_hostlist() -> t.Optional[t.List[str]]:
251252 return test_hostlist
252253
253254
254- def _parse_hostlist_file (path : str ) -> t . List [str ]:
255+ def _parse_hostlist_file (path : str ) -> list [str ]:
255256 with open (path , "r" , encoding = "utf-8" ) as nodefile :
256257 return list ({line .strip () for line in nodefile .readlines ()})
257258
258259
259260@pytest .fixture (scope = "session" )
260- def alloc_specs () -> t . Dict [str , t .Any ]:
261- specs : t . Dict [str , t .Any ] = {}
261+ def alloc_specs () -> dict [str , t .Any ]:
262+ specs : dict [str , t .Any ] = {}
262263 if test_alloc_specs_path :
263264 try :
264265 with open (test_alloc_specs_path , encoding = "utf-8" ) as spec_file :
@@ -293,7 +294,7 @@ def _reset():
293294)
294295
295296
296- def _find_free_port (ports : t . Collection [int ]) -> int :
297+ def _find_free_port (ports : Collection [int ]) -> int :
297298 with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as sock :
298299 for port in ports :
299300 try :
@@ -310,7 +311,7 @@ def _find_free_port(ports: t.Collection[int]) -> int:
310311
311312
312313@pytest .fixture (scope = "session" )
313- def wlmutils () -> t . Type [WLMUtils ]:
314+ def wlmutils () -> type [WLMUtils ]:
314315 return WLMUtils
315316
316317
@@ -335,22 +336,22 @@ def get_test_account() -> str:
335336 return get_account ()
336337
337338 @staticmethod
338- def get_test_interface () -> t . List [str ]:
339+ def get_test_interface () -> list [str ]:
339340 return test_nic
340341
341342 @staticmethod
342- def get_test_hostlist () -> t . Optional [ t . List [ str ]] :
343+ def get_test_hostlist () -> list [ str ] | None :
343344 return get_hostlist ()
344345
345346 @staticmethod
346- def get_batch_resources () -> t . Dict :
347+ def get_batch_resources () -> dict :
347348 return test_batch_resources
348349
349350 @staticmethod
350351 def get_base_run_settings (
351- exe : str , args : t . List [str ], nodes : int = 1 , ntasks : int = 1 , ** kwargs : t .Any
352+ exe : str , args : list [str ], nodes : int = 1 , ntasks : int = 1 , ** kwargs : t .Any
352353 ) -> RunSettings :
353- run_args : t . Dict [str , t . Union [ int , str , float , None ] ] = {}
354+ run_args : dict [str , int | str | float | None ] = {}
354355
355356 if test_launcher == "slurm" :
356357 run_args = {"--nodes" : nodes , "--ntasks" : ntasks , "--time" : "00:10:00" }
@@ -391,9 +392,9 @@ def get_base_run_settings(
391392
392393 @staticmethod
393394 def get_run_settings (
394- exe : str , args : t . List [str ], nodes : int = 1 , ntasks : int = 1 , ** kwargs : t .Any
395+ exe : str , args : list [str ], nodes : int = 1 , ntasks : int = 1 , ** kwargs : t .Any
395396 ) -> RunSettings :
396- run_args : t . Dict [str , t . Union [ int , str , float , None ] ] = {}
397+ run_args : dict [str , int | str | float | None ] = {}
397398
398399 if test_launcher == "slurm" :
399400 run_args = {"nodes" : nodes , "ntasks" : ntasks , "time" : "00:10:00" }
@@ -423,7 +424,7 @@ def get_run_settings(
423424 return RunSettings (exe , args )
424425
425426 @staticmethod
426- def choose_host (rs : RunSettings ) -> t . Optional [ str ] :
427+ def choose_host (rs : RunSettings ) -> str | None :
427428 if isinstance (rs , (MpirunSettings , MpiexecSettings )):
428429 hl = get_hostlist ()
429430 if hl is not None :
@@ -450,13 +451,13 @@ def check_output_dir() -> None:
450451
451452
452453@pytest .fixture
453- def dbutils () -> t . Type [DBUtils ]:
454+ def dbutils () -> type [DBUtils ]:
454455 return DBUtils
455456
456457
457458class DBUtils :
458459 @staticmethod
459- def get_db_configs () -> t . Dict [str , t .Any ]:
460+ def get_db_configs () -> dict [str , t .Any ]:
460461 config_settings = {
461462 "enable_checkpoints" : 1 ,
462463 "set_max_memory" : "3gb" ,
@@ -470,7 +471,7 @@ def get_db_configs() -> t.Dict[str, t.Any]:
470471 return config_settings
471472
472473 @staticmethod
473- def get_smartsim_error_db_configs () -> t . Dict [str , t .Any ]:
474+ def get_smartsim_error_db_configs () -> dict [str , t .Any ]:
474475 bad_configs = {
475476 "save" : [
476477 "-1" , # frequency must be positive
@@ -497,8 +498,8 @@ def get_smartsim_error_db_configs() -> t.Dict[str, t.Any]:
497498 return bad_configs
498499
499500 @staticmethod
500- def get_type_error_db_configs () -> t . Dict [ t . Union [ int , str ] , t .Any ]:
501- bad_configs : t . Dict [ t . Union [ int , str ] , t .Any ] = {
501+ def get_type_error_db_configs () -> dict [ int | str , t .Any ]:
502+ bad_configs : dict [ int | str , t .Any ] = {
502503 "save" : [2 , True , ["2" ]], # frequency must be specified as a string
503504 "maxmemory" : [99 , True , ["99" ]], # memory form must be a string
504505 "maxclients" : [3 , True , ["3" ]], # number of clients must be a string
@@ -519,9 +520,9 @@ def get_type_error_db_configs() -> t.Dict[t.Union[int, str], t.Any]:
519520 @staticmethod
520521 def get_config_edit_method (
521522 db : Orchestrator , config_setting : str
522- ) -> t . Optional [ t . Callable [..., None ]] :
523+ ) -> Callable [..., None ] | None :
523524 """Get a db configuration file edit method from a str"""
524- config_edit_methods : t . Dict [str , t . Callable [..., None ]] = {
525+ config_edit_methods : dict [str , Callable [..., None ]] = {
525526 "enable_checkpoints" : db .enable_checkpoints ,
526527 "set_max_memory" : db .set_max_memory ,
527528 "set_eviction_strategy" : db .set_eviction_strategy ,
@@ -564,7 +565,7 @@ def test_dir(request: pytest.FixtureRequest) -> str:
564565
565566
566567@pytest .fixture
567- def fileutils () -> t . Type [FileUtils ]:
568+ def fileutils () -> type [FileUtils ]:
568569 return FileUtils
569570
570571
@@ -589,7 +590,7 @@ def get_test_dir_path(dirname: str) -> str:
589590
590591 @staticmethod
591592 def make_test_file (
592- file_name : str , file_dir : str , file_content : t . Optional [ str ] = None
593+ file_name : str , file_dir : str , file_content : str | None = None
593594 ) -> str :
594595 """Create a dummy file in the test output directory.
595596
@@ -609,7 +610,7 @@ def make_test_file(
609610
610611
611612@pytest .fixture
612- def mlutils () -> t . Type [MLUtils ]:
613+ def mlutils () -> type [MLUtils ]:
613614 return MLUtils
614615
615616
@@ -624,21 +625,21 @@ def get_test_num_gpus() -> int:
624625
625626
626627@pytest .fixture
627- def coloutils () -> t . Type [ColoUtils ]:
628+ def coloutils () -> type [ColoUtils ]:
628629 return ColoUtils
629630
630631
631632class ColoUtils :
632633 @staticmethod
633634 def setup_test_colo (
634- fileutils : t . Type [FileUtils ],
635+ fileutils : type [FileUtils ],
635636 db_type : str ,
636637 exp : Experiment ,
637638 application_file : str ,
638- db_args : t . Dict [str , t .Any ],
639- colo_settings : t . Optional [ RunSettings ] = None ,
639+ db_args : dict [str , t .Any ],
640+ colo_settings : RunSettings | None = None ,
640641 colo_model_name : str = "colocated_model" ,
641- port : t . Optional [ int ] = None ,
642+ port : int | None = None ,
642643 on_wlm : bool = False ,
643644 ) -> Model :
644645 """Setup database needed for the colo pinning tests"""
@@ -666,7 +667,7 @@ def setup_test_colo(
666667 socket_name = f"{ colo_model_name } _{ socket_suffix } .socket"
667668 db_args ["unix_socket" ] = os .path .join (tmp_dir , socket_name )
668669
669- colocate_fun : t . Dict [str , t . Callable [..., None ]] = {
670+ colocate_fun : dict [str , Callable [..., None ]] = {
670671 "tcp" : colo_model .colocate_db_tcp ,
671672 "deprecated" : colo_model .colocate_db ,
672673 "uds" : colo_model .colocate_db_uds ,
@@ -708,7 +709,7 @@ def config() -> Config:
708709class CountingCallable :
709710 def __init__ (self ) -> None :
710711 self ._num : int = 0
711- self ._details : t . List [ t . Tuple [ t . Tuple [t .Any , ...], t . Dict [str , t .Any ]]] = []
712+ self ._details : list [ tuple [ tuple [t .Any , ...], dict [str , t .Any ]]] = []
712713
713714 def __call__ (self , * args : t .Any , ** kwargs : t .Any ) -> t .Any :
714715 self ._num += 1
@@ -719,12 +720,12 @@ def num_calls(self) -> int:
719720 return self ._num
720721
721722 @property
722- def details (self ) -> t . List [ t . Tuple [ t . Tuple [t .Any , ...], t . Dict [str , t .Any ]]]:
723+ def details (self ) -> list [ tuple [ tuple [t .Any , ...], dict [str , t .Any ]]]:
723724 return self ._details
724725
725726## Reuse database across tests
726727
727- database_registry : t . DefaultDict [str , t . Optional [ Orchestrator ] ] = defaultdict (lambda : None )
728+ database_registry : defaultdict [str , Orchestrator | None ] = defaultdict (lambda : None )
728729
729730@pytest .fixture (scope = "function" )
730731def local_experiment (test_dir : str ) -> smartsim .Experiment :
@@ -758,13 +759,13 @@ class DBConfiguration:
758759 name : str
759760 launcher : str
760761 num_nodes : int
761- interface : t . Union [ str , t . List [str ] ]
762- hostlist : t . Optional [ t . List [ str ]]
762+ interface : str | list [str ]
763+ hostlist : list [ str ] | None
763764 port : int
764765
765766@dataclass
766767class PrepareDatabaseOutput :
767- orchestrator : t . Optional [ Orchestrator ] # The actual orchestrator object
768+ orchestrator : Orchestrator | None # The actual orchestrator object
768769 new_db : bool # True if a new database was created when calling prepare_db
769770
770771# Reuse databases
@@ -817,7 +818,7 @@ def clustered_db(wlmutils: WLMUtils) -> t.Generator[DBConfiguration, None, None]
817818
818819
819820@pytest .fixture
820- def register_new_db () -> t . Callable [[DBConfiguration ], Orchestrator ]:
821+ def register_new_db () -> Callable [[DBConfiguration ], Orchestrator ]:
821822 def _register_new_db (
822823 config : DBConfiguration
823824 ) -> Orchestrator :
@@ -845,11 +846,11 @@ def _register_new_db(
845846
846847@pytest .fixture (scope = "function" )
847848def prepare_db (
848- register_new_db : t . Callable [
849+ register_new_db : Callable [
849850 [DBConfiguration ],
850851 Orchestrator
851852 ]
852- ) -> t . Callable [
853+ ) -> Callable [
853854 [DBConfiguration ],
854855 PrepareDatabaseOutput
855856]:
0 commit comments