Skip to content

Commit b796d3c

Browse files
committed
refactor: [typing] add typing to the Scheduler
Types added to make it clearer what is going on with the job scheduling. This means the pyright/ruff LSP have more information to be able to jump to definitions and show documentation. Signed-off-by: James McCorrie <[email protected]>
1 parent 9c1bf17 commit b796d3c

File tree

5 files changed

+178
-133
lines changed

5 files changed

+178
-133
lines changed

src/dvsim/flow/base.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
import os
88
import pprint
99
import sys
10-
from collections.abc import Mapping
10+
from abc import ABC, abstractmethod
11+
from collections.abc import Mapping, Sequence
1112
from pathlib import Path
13+
from typing import ClassVar
1214

1315
import hjson
1416

@@ -27,7 +29,7 @@
2729

2830

2931
# Interface class for extensions.
30-
class FlowCfg:
32+
class FlowCfg(ABC):
3133
"""Base class for the different flows supported by dvsim.py.
3234
3335
The constructor expects some parsed hjson data. Create these objects with
@@ -41,9 +43,10 @@ class FlowCfg:
4143

4244
# Can be overridden in subclasses to configure which wildcards to ignore
4345
# when expanding hjson.
44-
ignored_wildcards = []
46+
ignored_wildcards: ClassVar = []
4547

4648
def __str__(self) -> str:
49+
"""Get string representation of the flow config."""
4750
return pprint.pformat(self.__dict__)
4851

4952
def __init__(self, flow_cfg_file, hjson_data, args, mk_config) -> None:
@@ -87,7 +90,7 @@ def __init__(self, flow_cfg_file, hjson_data, args, mk_config) -> None:
8790
# For a primary cfg, it is the aggregated list of all deploy objects
8891
# under self.cfgs. For a non-primary cfg, it is the list of items
8992
# slated for dispatch.
90-
self.deploy = []
93+
self.deploy: Sequence[Deploy] = []
9194

9295
# Timestamp
9396
self.timestamp_long = args.timestamp_long
@@ -98,7 +101,7 @@ def __init__(self, flow_cfg_file, hjson_data, args, mk_config) -> None:
98101
self.rel_path = ""
99102
self.results_title = ""
100103
self.revision = ""
101-
self.css_file = os.path.join(Path(os.path.realpath(__file__)).parent, "style.css")
104+
self.css_file = Path(__file__).resolve().parent / "style.css"
102105
# `self.results_*` below will be updated after `self.rel_path` and
103106
# `self.scratch_base_root` variables are updated.
104107
self.results_dir = ""
@@ -151,7 +154,7 @@ def __init__(self, flow_cfg_file, hjson_data, args, mk_config) -> None:
151154
# Run any final checks
152155
self._post_init()
153156

154-
def _merge_hjson(self, hjson_data) -> None:
157+
def _merge_hjson(self, hjson_data: Mapping) -> None:
155158
"""Take hjson data and merge it into self.__dict__.
156159
157160
Subclasses that need to do something just before the merge should
@@ -162,7 +165,7 @@ def _merge_hjson(self, hjson_data) -> None:
162165
set_target_attribute(self.flow_cfg_file, self.__dict__, key, value)
163166

164167
def _expand(self) -> None:
165-
"""Called to expand wildcards after merging hjson.
168+
"""Expand wildcards after merging hjson.
166169
167170
Subclasses can override this to do something just before expansion.
168171
@@ -237,8 +240,9 @@ def _load_child_cfg(self, entry, mk_config) -> None:
237240
)
238241
sys.exit(1)
239242

240-
def _conv_inline_cfg_to_hjson(self, idict):
243+
def _conv_inline_cfg_to_hjson(self, idict: Mapping) -> str | None:
241244
"""Dump a temp hjson file in the scratch space from input dict.
245+
242246
This method is to be called only by a primary cfg.
243247
"""
244248
if not self.is_primary_cfg:
@@ -259,8 +263,10 @@ def _conv_inline_cfg_to_hjson(self, idict):
259263

260264
# Create the file and dump the dict as hjson
261265
log.verbose('Dumping inline cfg "%s" in hjson to:\n%s', name, temp_cfg_file)
266+
262267
try:
263268
Path(temp_cfg_file).write_text(hjson.dumps(idict, for_json=True))
269+
264270
except Exception as e:
265271
log.exception(
266272
'Failed to hjson-dump temp cfg file"%s" for "%s"(will be skipped!) due to:\n%s',
@@ -332,6 +338,7 @@ def _do_override(self, ov_name: str, ov_value: object) -> None:
332338
log.error('Override key "%s" not found in the cfg!', ov_name)
333339
sys.exit(1)
334340

341+
@abstractmethod
335342
def _purge(self) -> None:
336343
"""Purge the existing scratch areas in preparation for the new run."""
337344

@@ -340,6 +347,7 @@ def purge(self) -> None:
340347
for item in self.cfgs:
341348
item._purge()
342349

350+
@abstractmethod
343351
def _print_list(self) -> None:
344352
"""Print the list of available items that can be kicked off."""
345353

@@ -370,12 +378,13 @@ def prune_selected_cfgs(self) -> None:
370378
# Filter configurations
371379
self.cfgs = [c for c in self.cfgs if c.name in self.select_cfgs]
372380

381+
@abstractmethod
373382
def _create_deploy_objects(self) -> None:
374383
"""Create deploy objects from items that were passed on for being run.
384+
375385
The deploy objects for build and run are created from the objects that
376386
were created from the create_objects() method.
377387
"""
378-
return
379388

380389
def create_deploy_objects(self) -> None:
381390
"""Public facing API for _create_deploy_objects()."""
@@ -389,7 +398,7 @@ def create_deploy_objects(self) -> None:
389398
for item in self.cfgs:
390399
item._create_deploy_objects()
391400

392-
def deploy_objects(self):
401+
def deploy_objects(self) -> Mapping[Deploy, str]:
393402
"""Public facing API for deploying all available objects.
394403
395404
Runs each job and returns a map from item to status.
@@ -402,21 +411,26 @@ def deploy_objects(self):
402411
log.error("Nothing to run!")
403412
sys.exit(1)
404413

405-
return Scheduler(deploy, get_launcher_cls(), self.interactive).run()
414+
return Scheduler(
415+
items=deploy,
416+
launcher_cls=get_launcher_cls(),
417+
interactive=self.interactive,
418+
).run()
406419

407-
def _gen_results(self, results: Mapping[Deploy, str]) -> None:
408-
"""Generate results.
420+
@abstractmethod
421+
def _gen_results(self, results: Mapping[Deploy, str]) -> str:
422+
"""Generate flow results.
409423
410-
The function is called after the flow has completed. It collates the
411-
status of all run targets and generates a dict. It parses the log
424+
The function is called after the flow has completed. It collates
425+
the status of all run targets and generates a dict. It parses the log
412426
to identify the errors, warnings and failures as applicable. It also
413427
prints the full list of failures for debug / triage to the final
414428
report, which is in markdown format.
415429
416430
results should be a dictionary mapping deployed item to result.
417431
"""
418432

419-
def gen_results(self, results) -> None:
433+
def gen_results(self, results: Mapping[Deploy, str]) -> None:
420434
"""Public facing API for _gen_results().
421435
422436
results should be a dictionary mapping deployed item to result.
@@ -437,6 +451,7 @@ def gen_results(self, results) -> None:
437451
self.gen_results_summary()
438452
self.write_results(self.results_html_name, self.results_summary_md)
439453

454+
@abstractmethod
440455
def gen_results_summary(self) -> None:
441456
"""Public facing API to generate summary results for each IP/cfg file."""
442457

@@ -468,4 +483,5 @@ def _get_results_page_link(self, relative_to: str, link_text: str = "") -> str:
468483
return f"[{link_text}]({relative_link})"
469484

470485
def has_errors(self) -> bool:
486+
"""Return error state."""
471487
return self.errors_seen

src/dvsim/flow/sim.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,21 @@
44

55
"""Class describing simulation configuration object."""
66

7-
import collections
87
import fnmatch
98
import json
109
import os
1110
import re
1211
import sys
13-
from collections import OrderedDict
12+
from collections import OrderedDict, defaultdict
13+
from collections.abc import Mapping
1414
from datetime import datetime, timezone
1515
from pathlib import Path
1616
from typing import ClassVar
1717

1818
from tabulate import tabulate
1919

2020
from dvsim.flow.base import FlowCfg
21-
from dvsim.job.deploy import CompileSim, CovAnalyze, CovMerge, CovReport, CovUnr, RunTest
21+
from dvsim.job.deploy import CompileSim, CovAnalyze, CovMerge, CovReport, CovUnr, Deploy, RunTest
2222
from dvsim.logging import log
2323
from dvsim.modes import BuildMode, Mode, RunMode, find_mode
2424
from dvsim.regression import Regression
@@ -327,7 +327,7 @@ def _print_list(self) -> None:
327327
log.info(mode_name)
328328

329329
def _create_build_and_run_list(self) -> None:
330-
"""Generates a list of deployable objects from the provided items.
330+
"""Generate a list of deployable objects from the provided items.
331331
332332
Tests to be run are provided with --items switch. These can be glob-
333333
style patterns. This method finds regressions and tests that match
@@ -562,20 +562,13 @@ def cov_unr(self) -> None:
562562
for item in self.cfgs:
563563
item._cov_unr()
564564

565-
def _gen_json_results(self, run_results):
566-
"""Returns the run results as json-formatted dictionary."""
567-
568-
def _empty_str_as_none(s: str) -> str | None:
569-
"""Map an empty string to None and retain the value of a non-empty
570-
string.
571-
572-
This is intended to clearly distinguish an empty string, which may
573-
or may not be an valid value, from an invalid value.
574-
"""
575-
return s if s != "" else None
565+
def _gen_json_results(self, run_results: Mapping[Deploy, str]) -> str:
566+
"""Return the run results as json-formatted dictionary."""
576567

577568
def _pct_str_to_float(s: str) -> float | None:
578-
"""Map a percentage value stored in a string with ` %` suffix to a
569+
"""Extract percent or None.
570+
571+
Map a percentage value stored in a string with ` %` suffix to a
579572
float or to None if the conversion to Float fails.
580573
"""
581574
try:
@@ -608,7 +601,7 @@ def _test_result_to_dict(tr) -> dict:
608601
# Describe name of hardware block targeted by this run and optionally
609602
# the variant of the hardware block.
610603
results["block_name"] = self.name.lower()
611-
results["block_variant"] = _empty_str_as_none(self.variant.lower())
604+
results["block_variant"] = self.variant.lower() or None
612605

613606
# The timestamp for this run has been taken with `utcnow()` and is
614607
# stored in a custom format. Store it in standard ISO format with
@@ -620,7 +613,7 @@ def _test_result_to_dict(tr) -> dict:
620613
# Extract Git properties.
621614
m = re.search(r"https://github.com/.+?/tree/([0-9a-fA-F]+)", self.revision)
622615
results["git_revision"] = m.group(1) if m else None
623-
results["git_branch_name"] = _empty_str_as_none(self.branch)
616+
results["git_branch_name"] = self.branch or None
624617

625618
# Describe type of report and tool used.
626619
results["report_type"] = "simulation"
@@ -704,7 +697,7 @@ def _test_result_to_dict(tr) -> dict:
704697
if sim_results.buckets:
705698
by_tests = sorted(sim_results.buckets.items(), key=lambda i: len(i[1]), reverse=True)
706699
for bucket, tests in by_tests:
707-
unique_tests = collections.defaultdict(list)
700+
unique_tests = defaultdict(list)
708701
for test, line, context in tests:
709702
if not isinstance(test, RunTest):
710703
continue
@@ -743,16 +736,18 @@ def _test_result_to_dict(tr) -> dict:
743736
# Return the `results` dictionary as json string.
744737
return json.dumps(self.results_dict)
745738

746-
def _gen_results(self, run_results):
747-
"""The function is called after the regression has completed. It collates the
739+
def _gen_results(self, results: Mapping[Deploy, str]) -> str:
740+
"""Generate simulation results.
741+
742+
The function is called after the regression has completed. It collates the
748743
status of all run targets and generates a dict. It parses the testplan and
749744
maps the generated result to the testplan entries to generate a final table
750745
(list). It also prints the full list of failures for debug / triage. If cov
751746
is enabled, then the summary coverage report is also generated. The final
752747
result is in markdown format.
753748
"""
754749

755-
def indent_by(level):
750+
def indent_by(level: int) -> str:
756751
return " " * (4 * level)
757752

758753
def create_failure_message(test, line, context):
@@ -769,7 +764,7 @@ def create_failure_message(test, line, context):
769764
return message
770765

771766
def create_bucket_report(buckets):
772-
"""Creates a report based on the given buckets.
767+
"""Create a report based on the given buckets.
773768
774769
The buckets are sorted by descending number of failures. Within
775770
buckets this also group tests by unqualified name, and just a few
@@ -787,7 +782,7 @@ def create_bucket_report(buckets):
787782
fail_msgs = ["\n## Failure Buckets", ""]
788783
for bucket, tests in by_tests:
789784
fail_msgs.append(f"* `{bucket}` has {len(tests)} failures:")
790-
unique_tests = collections.defaultdict(list)
785+
unique_tests = defaultdict(list)
791786
for test, line, context in tests:
792787
unique_tests[test.name].append((test, line, context))
793788
for name, test_reseeds in list(unique_tests.items())[:_MAX_UNIQUE_TESTS]:
@@ -812,7 +807,7 @@ def create_bucket_report(buckets):
812807
return fail_msgs
813808

814809
deployed_items = self.deploy
815-
results = SimResults(deployed_items, run_results)
810+
results = SimResults(deployed_items, results)
816811

817812
# Generate results table for runs.
818813
results_str = "## " + self.results_title + "\n"
@@ -881,7 +876,7 @@ def create_bucket_report(buckets):
881876

882877
# Append coverage results if coverage was enabled.
883878
if self.cov_report_deploy is not None:
884-
report_status = run_results[self.cov_report_deploy]
879+
report_status = results[self.cov_report_deploy]
885880
if report_status == "P":
886881
results_str += "\n## Coverage Results\n"
887882
# Link the dashboard page using "cov_report_page" value.

src/dvsim/job/deploy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
if TYPE_CHECKING:
2828
from dvsim.flow.sim import SimCfg
29+
from dvsim.launcher.base import Launcher
2930

3031

3132
class Deploy:
@@ -92,7 +93,7 @@ def __init__(self, sim_cfg: "SimCfg") -> None:
9293
self.cmd = self._construct_cmd()
9394

9495
# Launcher instance created later using create_launcher() method.
95-
self.launcher = None
96+
self.launcher: Launcher | None = None
9697

9798
# Job's wall clock time (a.k.a CPU time, or runtime).
9899
self.job_runtime = JobTime()
@@ -484,7 +485,7 @@ class RunTest(Deploy):
484485
fixed_seed = None
485486
cmds_list_vars = ["pre_run_cmds", "post_run_cmds"]
486487

487-
def __init__(self, index, test, build_job, sim_cfg) -> None:
488+
def __init__(self, index, test, build_job, sim_cfg: "SimCfg") -> None:
488489
self.test_obj = test
489490
self.index = index
490491
self.build_seed = sim_cfg.build_seed

src/dvsim/launcher/factory.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
EDACLOUD_LAUNCHER_EXISTS = False
2121

2222
# The chosen launcher class.
23-
_LAUNCHER_CLS = None
23+
_LAUNCHER_CLS: type[Launcher] | None = None
2424

2525

26-
def set_launcher_type(is_local=False) -> None:
27-
"""Sets the launcher type that will be used to launch the jobs.
26+
def set_launcher_type(is_local: bool = False) -> None:
27+
"""Set the launcher type that will be used to launch the jobs.
2828
2929
The env variable `DVSIM_LAUNCHER` is used to identify what launcher system
3030
to use. This variable is specific to the user's work site. It is meant to
@@ -66,7 +66,7 @@ def set_launcher_type(is_local=False) -> None:
6666
_LAUNCHER_CLS = LocalLauncher
6767

6868

69-
def get_launcher_cls():
69+
def get_launcher_cls() -> type[Launcher]:
7070
"""Returns the chosen launcher class."""
7171
assert _LAUNCHER_CLS is not None
7272
return _LAUNCHER_CLS

0 commit comments

Comments
 (0)