Skip to content

Commit 3c458fd

Browse files
Modernize typing syntax to Python 3.10+ standards (#791)
## Overview This PR modernizes the entire SmartSim codebase to use Python 3.10+ typing syntax, improving code readability and type safety while maintaining full backward compatibility. [ committed by @al-rigazzi ] [ reviewed by @MattToast ] --------- Co-authored-by: Matt Drozt <matthew.drozt@gmail.com>
1 parent 3f94fdd commit 3c458fd

File tree

109 files changed

+958
-1029
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

109 files changed

+958
-1029
lines changed

.github/workflows/run_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ jobs:
5555
fail-fast: false
5656
matrix:
5757
subset: [backends, slow_tests, group_a, group_b]
58-
os: [macos-13, macos-14, ubuntu-22.04] # Operating systems
58+
os: [macos-14, ubuntu-22.04] # Operating systems
5959
compiler: [8] # GNU compiler version
6060
rai: [1.2.7] # Redis AI versions
6161
py_v: ["3.10", "3.11", "3.12"] # Python versions

conftest.py

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
RunSettings,
6565
SrunSettings,
6666
)
67+
from collections.abc import Callable, Collection
6768

6869
logger = get_logger(__name__)
6970

@@ -79,7 +80,7 @@
7980
test_alloc_specs_path = os.getenv("SMARTSIM_TEST_ALLOC_SPEC_SHEET_PATH", None)
8081
test_ports = CONFIG.test_ports
8182
test_account = CONFIG.test_account or ""
82-
test_batch_resources: t.Dict[t.Any, t.Any] = CONFIG.test_batch_resources
83+
test_batch_resources: dict[t.Any, t.Any] = CONFIG.test_batch_resources
8384
test_output_dirs = 0
8485
mpi_app_exe = None
8586
built_mpi_app = False
@@ -169,7 +170,7 @@ def pytest_sessionfinish(
169170
kill_all_test_spawned_processes()
170171

171172

172-
def build_mpi_app() -> t.Optional[pathlib.Path]:
173+
def build_mpi_app() -> pathlib.Path | None:
173174
global built_mpi_app
174175
built_mpi_app = True
175176
cc = shutil.which("cc")
@@ -190,7 +191,7 @@ def build_mpi_app() -> t.Optional[pathlib.Path]:
190191
return None
191192

192193
@pytest.fixture(scope="session")
193-
def mpi_app_path() -> t.Optional[pathlib.Path]:
194+
def mpi_app_path() -> pathlib.Path | None:
194195
"""Return path to MPI app if it was built
195196
196197
return None if it could not or will not be built
@@ -223,7 +224,7 @@ def kill_all_test_spawned_processes() -> None:
223224

224225

225226

226-
def get_hostlist() -> t.Optional[t.List[str]]:
227+
def get_hostlist() -> list[str] | None:
227228
global test_hostlist
228229
if not test_hostlist:
229230
if "PBS_NODEFILE" in os.environ and test_launcher == "pals":
@@ -251,14 +252,14 @@ def get_hostlist() -> t.Optional[t.List[str]]:
251252
return test_hostlist
252253

253254

254-
def _parse_hostlist_file(path: str) -> t.List[str]:
255+
def _parse_hostlist_file(path: str) -> list[str]:
255256
with open(path, "r", encoding="utf-8") as nodefile:
256257
return list({line.strip() for line in nodefile.readlines()})
257258

258259

259260
@pytest.fixture(scope="session")
260-
def alloc_specs() -> t.Dict[str, t.Any]:
261-
specs: t.Dict[str, t.Any] = {}
261+
def alloc_specs() -> dict[str, t.Any]:
262+
specs: dict[str, t.Any] = {}
262263
if test_alloc_specs_path:
263264
try:
264265
with open(test_alloc_specs_path, encoding="utf-8") as spec_file:
@@ -293,7 +294,7 @@ def _reset():
293294
)
294295

295296

296-
def _find_free_port(ports: t.Collection[int]) -> int:
297+
def _find_free_port(ports: Collection[int]) -> int:
297298
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
298299
for port in ports:
299300
try:
@@ -310,7 +311,7 @@ def _find_free_port(ports: t.Collection[int]) -> int:
310311

311312

312313
@pytest.fixture(scope="session")
313-
def wlmutils() -> t.Type[WLMUtils]:
314+
def wlmutils() -> type[WLMUtils]:
314315
return WLMUtils
315316

316317

@@ -335,22 +336,22 @@ def get_test_account() -> str:
335336
return get_account()
336337

337338
@staticmethod
338-
def get_test_interface() -> t.List[str]:
339+
def get_test_interface() -> list[str]:
339340
return test_nic
340341

341342
@staticmethod
342-
def get_test_hostlist() -> t.Optional[t.List[str]]:
343+
def get_test_hostlist() -> list[str] | None:
343344
return get_hostlist()
344345

345346
@staticmethod
346-
def get_batch_resources() -> t.Dict:
347+
def get_batch_resources() -> dict:
347348
return test_batch_resources
348349

349350
@staticmethod
350351
def get_base_run_settings(
351-
exe: str, args: t.List[str], nodes: int = 1, ntasks: int = 1, **kwargs: t.Any
352+
exe: str, args: list[str], nodes: int = 1, ntasks: int = 1, **kwargs: t.Any
352353
) -> RunSettings:
353-
run_args: t.Dict[str, t.Union[int, str, float, None]] = {}
354+
run_args: dict[str, int | str | float | None] = {}
354355

355356
if test_launcher == "slurm":
356357
run_args = {"--nodes": nodes, "--ntasks": ntasks, "--time": "00:10:00"}
@@ -391,9 +392,9 @@ def get_base_run_settings(
391392

392393
@staticmethod
393394
def get_run_settings(
394-
exe: str, args: t.List[str], nodes: int = 1, ntasks: int = 1, **kwargs: t.Any
395+
exe: str, args: list[str], nodes: int = 1, ntasks: int = 1, **kwargs: t.Any
395396
) -> RunSettings:
396-
run_args: t.Dict[str, t.Union[int, str, float, None]] = {}
397+
run_args: dict[str, int | str | float | None] = {}
397398

398399
if test_launcher == "slurm":
399400
run_args = {"nodes": nodes, "ntasks": ntasks, "time": "00:10:00"}
@@ -423,7 +424,7 @@ def get_run_settings(
423424
return RunSettings(exe, args)
424425

425426
@staticmethod
426-
def choose_host(rs: RunSettings) -> t.Optional[str]:
427+
def choose_host(rs: RunSettings) -> str | None:
427428
if isinstance(rs, (MpirunSettings, MpiexecSettings)):
428429
hl = get_hostlist()
429430
if hl is not None:
@@ -450,13 +451,13 @@ def check_output_dir() -> None:
450451

451452

452453
@pytest.fixture
453-
def dbutils() -> t.Type[DBUtils]:
454+
def dbutils() -> type[DBUtils]:
454455
return DBUtils
455456

456457

457458
class DBUtils:
458459
@staticmethod
459-
def get_db_configs() -> t.Dict[str, t.Any]:
460+
def get_db_configs() -> dict[str, t.Any]:
460461
config_settings = {
461462
"enable_checkpoints": 1,
462463
"set_max_memory": "3gb",
@@ -470,7 +471,7 @@ def get_db_configs() -> t.Dict[str, t.Any]:
470471
return config_settings
471472

472473
@staticmethod
473-
def get_smartsim_error_db_configs() -> t.Dict[str, t.Any]:
474+
def get_smartsim_error_db_configs() -> dict[str, t.Any]:
474475
bad_configs = {
475476
"save": [
476477
"-1", # frequency must be positive
@@ -497,8 +498,8 @@ def get_smartsim_error_db_configs() -> t.Dict[str, t.Any]:
497498
return bad_configs
498499

499500
@staticmethod
500-
def get_type_error_db_configs() -> t.Dict[t.Union[int, str], t.Any]:
501-
bad_configs: t.Dict[t.Union[int, str], t.Any] = {
501+
def get_type_error_db_configs() -> dict[int | str, t.Any]:
502+
bad_configs: dict[int | str, t.Any] = {
502503
"save": [2, True, ["2"]], # frequency must be specified as a string
503504
"maxmemory": [99, True, ["99"]], # memory form must be a string
504505
"maxclients": [3, True, ["3"]], # number of clients must be a string
@@ -519,9 +520,9 @@ def get_type_error_db_configs() -> t.Dict[t.Union[int, str], t.Any]:
519520
@staticmethod
520521
def get_config_edit_method(
521522
db: Orchestrator, config_setting: str
522-
) -> t.Optional[t.Callable[..., None]]:
523+
) -> Callable[..., None] | None:
523524
"""Get a db configuration file edit method from a str"""
524-
config_edit_methods: t.Dict[str, t.Callable[..., None]] = {
525+
config_edit_methods: dict[str, Callable[..., None]] = {
525526
"enable_checkpoints": db.enable_checkpoints,
526527
"set_max_memory": db.set_max_memory,
527528
"set_eviction_strategy": db.set_eviction_strategy,
@@ -564,7 +565,7 @@ def test_dir(request: pytest.FixtureRequest) -> str:
564565

565566

566567
@pytest.fixture
567-
def fileutils() -> t.Type[FileUtils]:
568+
def fileutils() -> type[FileUtils]:
568569
return FileUtils
569570

570571

@@ -589,7 +590,7 @@ def get_test_dir_path(dirname: str) -> str:
589590

590591
@staticmethod
591592
def make_test_file(
592-
file_name: str, file_dir: str, file_content: t.Optional[str] = None
593+
file_name: str, file_dir: str, file_content: str | None = None
593594
) -> str:
594595
"""Create a dummy file in the test output directory.
595596
@@ -609,7 +610,7 @@ def make_test_file(
609610

610611

611612
@pytest.fixture
612-
def mlutils() -> t.Type[MLUtils]:
613+
def mlutils() -> type[MLUtils]:
613614
return MLUtils
614615

615616

@@ -624,21 +625,21 @@ def get_test_num_gpus() -> int:
624625

625626

626627
@pytest.fixture
627-
def coloutils() -> t.Type[ColoUtils]:
628+
def coloutils() -> type[ColoUtils]:
628629
return ColoUtils
629630

630631

631632
class ColoUtils:
632633
@staticmethod
633634
def setup_test_colo(
634-
fileutils: t.Type[FileUtils],
635+
fileutils: type[FileUtils],
635636
db_type: str,
636637
exp: Experiment,
637638
application_file: str,
638-
db_args: t.Dict[str, t.Any],
639-
colo_settings: t.Optional[RunSettings] = None,
639+
db_args: dict[str, t.Any],
640+
colo_settings: RunSettings | None = None,
640641
colo_model_name: str = "colocated_model",
641-
port: t.Optional[int] = None,
642+
port: int | None = None,
642643
on_wlm: bool = False,
643644
) -> Model:
644645
"""Setup database needed for the colo pinning tests"""
@@ -666,7 +667,7 @@ def setup_test_colo(
666667
socket_name = f"{colo_model_name}_{socket_suffix}.socket"
667668
db_args["unix_socket"] = os.path.join(tmp_dir, socket_name)
668669

669-
colocate_fun: t.Dict[str, t.Callable[..., None]] = {
670+
colocate_fun: dict[str, Callable[..., None]] = {
670671
"tcp": colo_model.colocate_db_tcp,
671672
"deprecated": colo_model.colocate_db,
672673
"uds": colo_model.colocate_db_uds,
@@ -708,7 +709,7 @@ def config() -> Config:
708709
class CountingCallable:
709710
def __init__(self) -> None:
710711
self._num: int = 0
711-
self._details: t.List[t.Tuple[t.Tuple[t.Any, ...], t.Dict[str, t.Any]]] = []
712+
self._details: list[tuple[tuple[t.Any, ...], dict[str, t.Any]]] = []
712713

713714
def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
714715
self._num += 1
@@ -719,12 +720,12 @@ def num_calls(self) -> int:
719720
return self._num
720721

721722
@property
722-
def details(self) -> t.List[t.Tuple[t.Tuple[t.Any, ...], t.Dict[str, t.Any]]]:
723+
def details(self) -> list[tuple[tuple[t.Any, ...], dict[str, t.Any]]]:
723724
return self._details
724725

725726
## Reuse database across tests
726727

727-
database_registry: t.DefaultDict[str, t.Optional[Orchestrator]] = defaultdict(lambda: None)
728+
database_registry: defaultdict[str, Orchestrator | None] = defaultdict(lambda: None)
728729

729730
@pytest.fixture(scope="function")
730731
def local_experiment(test_dir: str) -> smartsim.Experiment:
@@ -758,13 +759,13 @@ class DBConfiguration:
758759
name: str
759760
launcher: str
760761
num_nodes: int
761-
interface: t.Union[str,t.List[str]]
762-
hostlist: t.Optional[t.List[str]]
762+
interface: str | list[str]
763+
hostlist: list[str] | None
763764
port: int
764765

765766
@dataclass
766767
class PrepareDatabaseOutput:
767-
orchestrator: t.Optional[Orchestrator] # The actual orchestrator object
768+
orchestrator: Orchestrator | None # The actual orchestrator object
768769
new_db: bool # True if a new database was created when calling prepare_db
769770

770771
# Reuse databases
@@ -817,7 +818,7 @@ def clustered_db(wlmutils: WLMUtils) -> t.Generator[DBConfiguration, None, None]
817818

818819

819820
@pytest.fixture
820-
def register_new_db() -> t.Callable[[DBConfiguration], Orchestrator]:
821+
def register_new_db() -> Callable[[DBConfiguration], Orchestrator]:
821822
def _register_new_db(
822823
config: DBConfiguration
823824
) -> Orchestrator:
@@ -845,11 +846,11 @@ def _register_new_db(
845846

846847
@pytest.fixture(scope="function")
847848
def prepare_db(
848-
register_new_db: t.Callable[
849+
register_new_db: Callable[
849850
[DBConfiguration],
850851
Orchestrator
851852
]
852-
) -> t.Callable[
853+
) -> Callable[
853854
[DBConfiguration],
854855
PrepareDatabaseOutput
855856
]:

doc/changelog.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ To be released at some point in the future
1111

1212
Description
1313

14+
- Modernize typing syntax to Python 3.10+ standards
1415
- Removed telemetry functionality, LaunchedManifest tracking
1516
classes, and SmartDashboard integration
1617
- Update copyright headers from 2021-2024 to 2021-2025 across the entire codebase
@@ -24,6 +25,10 @@ Description
2425

2526
Detailed Notes
2627

28+
- Modernized typing syntax to use Python 3.10+ standards, replacing
29+
`Union[X, Y]` with `X | Y`, `Optional[X]` with `X | None`, and generic
30+
collections (`List[X]``list[X]`, `Dict[X, Y]``dict[X, Y]`, etc.).
31+
([SmartSim-PR791](https://github.com/CrayLabs/SmartSim/pull/791))
2732
- Removed telemetry functionality, LaunchedManifest tracking
2833
system, and SmartDashboard integration.
2934
This includes complete removal of the telemetry monitor and collection system,

smartsim/_core/_cli/build.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import re
3232
import shutil
3333
import textwrap
34-
import typing as t
34+
from collections.abc import Callable, Collection
3535
from pathlib import Path
3636

3737
from tabulate import tabulate
@@ -139,7 +139,7 @@ def build_redis_ai(
139139

140140
def parse_requirement(
141141
requirement: str,
142-
) -> t.Tuple[str, t.Optional[str], t.Callable[[Version_], bool]]:
142+
) -> tuple[str, str | None, Callable[[Version_], bool]]:
143143
operators = {
144144
"==": operator.eq,
145145
"<=": operator.le,
@@ -199,10 +199,10 @@ def check_ml_python_packages(packages: MLPackageCollection) -> None:
199199

200200

201201
def _format_incompatible_python_env_message(
202-
missing: t.Collection[str], conflicting: t.Collection[str]
202+
missing: Collection[str], conflicting: Collection[str]
203203
) -> str:
204204
indent = "\n\t"
205-
fmt_list: t.Callable[[str, t.Collection[str]], str] = lambda n, l: (
205+
fmt_list: Callable[[str, Collection[str]], str] = lambda n, l: (
206206
f"{n}:{indent}{indent.join(l)}" if l else ""
207207
)
208208
missing_str = fmt_list("Missing", missing)
@@ -237,7 +237,7 @@ def _configure_keydb_build(versions: Versioner) -> None:
237237

238238
# pylint: disable-next=too-many-statements
239239
def execute(
240-
args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, /
240+
args: argparse.Namespace, _unparsed_args: list[str] | None = None, /
241241
) -> int:
242242

243243
# Unpack various arguments

smartsim/_core/_cli/clean.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

2727
import argparse
28-
import typing as t
2928

3029
from smartsim._core._cli.utils import clean, get_install_path
3130

@@ -41,13 +40,13 @@ def configure_parser(parser: argparse.ArgumentParser) -> None:
4140

4241

4342
def execute(
44-
args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, /
43+
args: argparse.Namespace, _unparsed_args: list[str] | None = None, /
4544
) -> int:
4645
return clean(get_install_path() / "_core", _all=args.clobber)
4746

4847

4948
def execute_all(
50-
args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, /
49+
args: argparse.Namespace, _unparsed_args: list[str] | None = None, /
5150
) -> int:
5251
args.clobber = True
5352
return execute(args)

0 commit comments

Comments
 (0)