Skip to content

Commit 49e209a

Browse files
authored
Merge pull request #413 from mdboom/fix-checking-python-file
Fix checking python file
2 parents d68f94e + da80a9c commit 49e209a

File tree

2 files changed

+58
-26
lines changed

2 files changed

+58
-26
lines changed

bench_runner/scripts/install.py

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pathlib import Path
1010
import shutil
1111
import sys
12-
from typing import Any
12+
from typing import Any, Callable, TextIO
1313

1414

1515
import rich
@@ -41,35 +41,47 @@ def fail_check(dst: PathLike):
4141
sys.exit(1)
4242

4343

44-
def write_yaml(dst: PathLike, contents: Any, check: bool):
44+
def write_and_check(dst: PathLike, writer: Callable[[TextIO], None], check: bool):
4545
"""
46-
Write `contents` to `dst` as YAML.
46+
Call `writer` with a file descriptor to write the contents to `dst`.
4747
4848
If `check` is True, raise SystemExit if the file would change. This is used
4949
in CI to confirm that the file was regenerated after changes to the source
5050
file.
5151
"""
5252
dst = Path(dst)
5353

54-
def do_write(contents, fd):
55-
fd.write("# Generated file: !!! DO NOT EDIT !!!\n")
56-
fd.write("---\n")
57-
yaml = YAML()
58-
yaml.dump(contents, fd)
59-
6054
if check:
6155
if not dst.is_file():
6256
fail_check(dst)
6357

64-
orig_contents = dst.read_text()
6558
fd = io.StringIO()
66-
do_write(contents, fd)
59+
orig_contents = dst.read_text()
60+
writer(fd)
6761
new_contents = fd.getvalue()
6862
if orig_contents != new_contents:
6963
fail_check(dst)
7064
else:
7165
with dst.open("w") as fd:
72-
do_write(contents, fd)
66+
writer(fd)
67+
68+
69+
def write_yaml(dst: PathLike, contents: Any, check: bool):
70+
"""
71+
Write `contents` to `dst` as YAML.
72+
73+
If `check` is True, raise SystemExit if the file would change. This is used
74+
in CI to confirm that the file was regenerated after changes to the source
75+
file.
76+
"""
77+
78+
def do_write(fd):
79+
fd.write("# Generated file: !!! DO NOT EDIT !!!\n")
80+
fd.write("---\n")
81+
yaml = YAML()
82+
yaml.dump(contents, fd)
83+
84+
return write_and_check(dst, do_write, check)
7385

7486

7587
def load_yaml(src: PathLike) -> Any:
@@ -81,6 +93,22 @@ def load_yaml(src: PathLike) -> Any:
8193
return yaml.load(fd)
8294

8395

96+
def write_python(dst: PathLike, contents: str, check: bool):
97+
"""
98+
Write a string of Python code to a file, adding a header about it being generated.
99+
100+
If `check` is True, raise SystemExit if the file would change. This is used
101+
in CI to confirm that the file was regenerated after changes to the source
102+
file.
103+
"""
104+
105+
def do_write(fd):
106+
fd.write("# Generated file: !!! DO NOT EDIT !!!\n\n")
107+
fd.write(contents)
108+
109+
return write_and_check(dst, do_write, check)
110+
111+
84112
def add_flag_variables(dst: dict[str, Any]) -> None:
85113
for flag in flags.FLAGS:
86114
dst[flag.gha_variable] = {
@@ -280,22 +308,26 @@ def generate_generic(dst: Any) -> Any:
280308
def _main(check: bool) -> None:
281309
WORKFLOW_PATH.mkdir(parents=True, exist_ok=True)
282310

283-
for path in TEMPLATE_PATH.glob("*"):
284-
if path.name.endswith(".src.yml") or path.name == "env.yml":
311+
for src_path in TEMPLATE_PATH.glob("*"):
312+
if not src_path.is_file():
285313
continue
286314

287-
if not (ROOT_PATH / path.name).is_file() or path.suffix == ".py":
288-
if check:
289-
fail_check(ROOT_PATH / path.name)
290-
else:
291-
shutil.copyfile(path, ROOT_PATH / path.name)
292-
293-
for src_path in TEMPLATE_PATH.glob("*.src.yml"):
294-
dst_path = WORKFLOW_PATH / (src_path.name[:-8] + ".yml")
295-
generator = GENERATORS.get(src_path.name, generate_generic)
296-
src = load_yaml(src_path)
297-
dst = generator(src)
298-
write_yaml(dst_path, dst, check)
315+
if src_path.name.endswith(".src.yml"):
316+
dst_path = WORKFLOW_PATH / (src_path.name[:-8] + ".yml")
317+
generator = GENERATORS.get(src_path.name, generate_generic)
318+
src = load_yaml(src_path)
319+
dst = generator(src)
320+
write_yaml(dst_path, dst, check)
321+
elif src_path.name.endswith(".src.py"):
322+
dst_path = WORKFLOW_PATH / (src_path.name[:-7] + ".py")
323+
write_python(dst_path, src_path.read_text(), check)
324+
else:
325+
dst_path = ROOT_PATH / src_path.name
326+
if not dst_path.is_file():
327+
if check:
328+
fail_check(dst_path)
329+
else:
330+
shutil.copyfile(src_path, dst_path)
299331

300332

301333
def main():

0 commit comments

Comments
 (0)