Skip to content

Commit 6da3bc5

Browse files
committed
more typing
1 parent 1044c30 commit 6da3bc5

File tree

6 files changed

+94
-71
lines changed

6 files changed

+94
-71
lines changed

bin/constraints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def math_eval(text: str) -> Optional[int | float]:
7676
# eval is dangerous, but on the other hand we run submission code so this is fine
7777
text = text.replace("^", "**")
7878
value = eval(text, {"__builtin__": None})
79-
return value if value is isinstance(value, (int, float)) else None
79+
return value if isinstance(value, (int, float)) else None
8080
except (SyntaxError, NameError, TypeError, ZeroDivisionError):
8181
return None
8282

bin/contest.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import config
22

33
from pathlib import Path
4+
from typing import cast, Any, Optional
45

56
from util import *
67

78
# Read the contest.yaml, if available
8-
_contest_yaml = None
9+
_contest_yaml: Optional[dict[str, Any]] = None
910

1011

11-
def contest_yaml():
12+
def contest_yaml() -> dict[str, Any]:
1213
global _contest_yaml
1314
if _contest_yaml is not None:
1415
return _contest_yaml
@@ -25,22 +26,22 @@ def contest_yaml():
2526
_problems_yaml = None
2627

2728

28-
def problems_yaml():
29+
def problems_yaml() -> Optional[list[dict[str, Any]]]:
2930
global _problems_yaml
30-
if _problems_yaml:
31-
return _problems_yaml
3231
if _problems_yaml is False:
3332
return None
33+
if _problems_yaml:
34+
return _problems_yaml
3435

3536
problemsyaml_path = Path("problems.yaml")
3637
if not problemsyaml_path.is_file():
3738
_problems_yaml = False
3839
return None
3940
_problems_yaml = read_yaml(problemsyaml_path)
40-
return _problems_yaml
41+
return cast(list[dict[str, Any]], _problems_yaml)
4142

4243

43-
def get_api():
44+
def get_api() -> str:
4445
api = config.args.api or contest_yaml().get("api")
4546
if not api:
4647
fatal(
@@ -105,7 +106,7 @@ def call_api(method, endpoint, **kwargs):
105106
return r
106107

107108

108-
def call_api_get_json(url):
109+
def call_api_get_json(url: str):
109110
r = call_api("GET", url)
110111
r.raise_for_status()
111112
try:

bin/fuzz.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import time
88
import threading
99
from pathlib import Path
10+
from typing import Any, Optional
1011

1112
import parallel
1213
from util import *
@@ -24,9 +25,11 @@
2425

2526

2627
class GeneratorTask:
27-
def __init__(self, fuzz: "Fuzz", t, i, tmp_id):
28+
def __init__(self, fuzz: "Fuzz", t: generate.TestcaseRule, i: int, tmp_id: int):
2829
self.fuzz = fuzz
29-
self.generator = t.generator
30+
generator = t.generator
31+
assert generator is not None
32+
self.generator = generator
3033
self.solution = t.config.solution
3134
self.i = i
3235
self.tmp_id = tmp_id
@@ -39,13 +42,13 @@ def __init__(self, fuzz: "Fuzz", t, i, tmp_id):
3942
self.save_mutex = threading.Lock()
4043
self.saved = False
4144

42-
def run(self, bar):
45+
def run(self, bar: ProgressBar) -> None:
4346
if self._run(bar):
4447
self.fuzz.finish_task(self.tmp_id)
4548
else:
4649
self.fuzz.finish_task(self.tmp_id, 1 + len(self.fuzz.submissions))
4750

48-
def _run(self, bar):
51+
def _run(self, bar: ProgressBar) -> bool:
4952
# GENERATE THE TEST DATA
5053
dir = Path("fuzz") / f"tmp_id_{str(self.tmp_id)}"
5154
cwd = self.fuzz.problem.tmpdir / "tool_runs" / dir
@@ -104,7 +107,7 @@ def _run(self, bar):
104107
self.fuzz.queue.put(SubmissionTask(self, submission, testcase, self.tmp_id))
105108
return True
106109

107-
def save_test(self, bar):
110+
def save_test(self, bar: ProgressBar) -> None:
108111
if self.saved:
109112
return
110113
save = False
@@ -122,17 +125,23 @@ def save_test(self, bar):
122125

123126

124127
class SubmissionTask:
125-
def __init__(self, generator_task, submission, testcase, tmp_id):
128+
def __init__(
129+
self,
130+
generator_task: GeneratorTask,
131+
submission: run.Submission,
132+
testcase: Testcase,
133+
tmp_id: int,
134+
):
126135
self.generator_task = generator_task
127136
self.submission = submission
128137
self.testcase = testcase
129138
self.tmp_id = tmp_id
130139

131-
def run(self, bar):
140+
def run(self, bar: ProgressBar) -> None:
132141
self._run(bar)
133142
self.generator_task.fuzz.finish_task(self.tmp_id)
134143

135-
def _run(self, bar):
144+
def _run(self, bar: ProgressBar) -> None:
136145
r = run.Run(self.generator_task.fuzz.problem, self.submission, self.testcase)
137146
localbar = bar.start(f"{self.generator_task.i}: {self.submission.name}")
138147
result = r.run(localbar)
@@ -155,10 +164,11 @@ def __init__(self, problem: problem.Problem):
155164
# Filter to only keep valid rules depending on seed without duplicates from count
156165
added_testcase_rules = set()
157166

158-
def add_testcase(t):
167+
def add_testcase(t: generate.TestcaseRule) -> None:
159168
if (
160169
t.in_is_generated
161170
and t.parse_error is None
171+
and t.generator is not None
162172
and t.generator.uses_seed
163173
and t.generator.command_string.strip() not in added_testcase_rules
164174
):
@@ -177,7 +187,7 @@ def add_testcase(t):
177187
# SUBMISSIONS
178188
self.submissions = self.problem.selected_or_accepted_submissions()
179189

180-
def run(self):
190+
def run(self) -> bool:
181191
if not has_ryaml:
182192
error("Fuzzing needs the ruamel.yaml python3 library. Install python[3]-ruamel.yaml.")
183193
return False
@@ -192,7 +202,7 @@ def run(self):
192202

193203
message("Press CTRL+C to stop\n", "Fuzz", color_type=MessageType.LOG)
194204

195-
def runner(task: GeneratorTask):
205+
def runner(task: GeneratorTask | SubmissionTask) -> None:
196206
task.run(bar)
197207

198208
# config.args.no_bar = True
@@ -203,7 +213,7 @@ def runner(task: GeneratorTask):
203213
self.tasks = 0
204214
self.queue = parallel.new_queue(runner, pin=True)
205215

206-
def soft_exit(sig, frame):
216+
def soft_exit(sig: Any, frame: Any) -> None:
207217
if self.queue.aborted:
208218
fatal("Running interrupted", force=True)
209219
else:
@@ -240,7 +250,7 @@ def soft_exit(sig, frame):
240250

241251
# finish task from generator with tmp_id
242252
# also add new tasks if queue becomes too empty
243-
def finish_task(self, tmp_id=None, count=1):
253+
def finish_task(self, tmp_id: Optional[int] = None, count: int = 1) -> None:
244254
with self.queue:
245255
# return tmp_id (and reuse it if all submissions are finished)
246256
if tmp_id is not None:
@@ -259,18 +269,18 @@ def finish_task(self, tmp_id=None, count=1):
259269
self.iteration += 1
260270
# 1 new generator tasks which will also create one task per submission
261271
new_tasks = 1 + len(self.submissions)
262-
tmp_id = min(self.free_tmp_id)
263-
self.free_tmp_id.remove(tmp_id)
264-
self.tmp_id_count[tmp_id] = new_tasks
272+
new_tmp_id = min(self.free_tmp_id)
273+
self.free_tmp_id.remove(new_tmp_id)
274+
self.tmp_id_count[new_tmp_id] = new_tasks
265275
self.tasks += new_tasks
266276
self.queue.put(
267-
GeneratorTask(self, testcase_rule, self.iteration, tmp_id),
277+
GeneratorTask(self, testcase_rule, self.iteration, new_tmp_id),
268278
priority=1,
269279
)
270280

271281
# Write new rule to yaml
272282
# lock between read and write to ensure that no rule gets lost
273-
def save_test(self, command):
283+
def save_test(self, command: str) -> None:
274284
with self.generators_yaml_mutex:
275285
generators_yaml = self.problem.path / "generators/generators.yaml"
276286
data = None

bin/parallel.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, task: T, priority: int, index: int):
1919
self.index = index
2020

2121
# Note: heapq uses a min heap, so higher priorities are 'smaller'.
22-
def __lt__(self, other):
22+
def __lt__(self, other: "QueueItem[T]") -> bool:
2323
if self.priority != other.priority:
2424
# python priority queue is a min heap but larger priority
2525
# items should come first => reverse compare
@@ -45,24 +45,24 @@ def __init__(self, f: Callable[[T], Any], pin: bool):
4545
# mutex to lock parallel access
4646
self.mutex = threading.RLock()
4747

48-
def __enter__(self):
48+
def __enter__(self) -> None:
4949
self.mutex.__enter__()
5050

51-
def __exit__(self, *args):
51+
def __exit__(self, *args: Any) -> None:
5252
self.mutex.__exit__(*args)
5353

5454
# Add one task. Higher priority => done first
55-
def put(self, task: T, priority=0):
55+
def put(self, task: T, priority: int = 0) -> None:
5656
raise Exception("Abstract method")
5757

5858
# By default, do nothing on .join(). This is overridden in ParallelQueue.
59-
def join(self):
59+
def join(self) -> None:
6060
return
6161

62-
def done(self):
62+
def done(self) -> None:
6363
raise Exception("Abstract method")
6464

65-
def abort(self):
65+
def abort(self) -> None:
6666
self.aborted = True
6767

6868

@@ -71,7 +71,7 @@ def __init__(self, f: Callable[[T], Any], pin: bool):
7171
super().__init__(f, pin)
7272

7373
# Add one task. Higher priority => done first
74-
def put(self, task: T, priority: int = 0):
74+
def put(self, task: T, priority: int = 0) -> None:
7575
# no task will be handled after self.abort() so skip adding
7676
if self.aborted:
7777
return
@@ -80,7 +80,7 @@ def put(self, task: T, priority: int = 0):
8080
heapq.heappush(self.tasks, QueueItem(task, priority, self.total_tasks))
8181

8282
# Execute all tasks.
83-
def done(self):
83+
def done(self) -> None:
8484
if self.pin:
8585
cores = list(os.sched_getaffinity(0))
8686
os.sched_setaffinity(0, {cores[0]})
@@ -127,7 +127,7 @@ def __init__(self, f: Callable[[T], Any], pin: bool, num_threads: int):
127127

128128
signal.signal(signal.SIGINT, self._interrupt_handler)
129129

130-
def _worker(self, cores: Literal[False] | list[int] = False):
130+
def _worker(self, cores: Literal[False] | list[int] = False) -> None:
131131
if cores is not False:
132132
os.sched_setaffinity(0, cores)
133133
while True:
@@ -164,10 +164,10 @@ def _worker(self, cores: Literal[False] | list[int] = False):
164164
if self.missing == 0:
165165
self.all_done.notify_all()
166166

167-
def _interrupt_handler(self, sig, frame):
167+
def _interrupt_handler(self, sig: Any, frame: Any) -> None:
168168
util.fatal("Running interrupted", force=True)
169169

170-
def _handle_first_error(self):
170+
def _handle_first_error(self) -> None:
171171
if self.first_error is not None:
172172
first_error = self.first_error
173173
self.first_error = None
@@ -177,7 +177,7 @@ def _handle_first_error(self):
177177
raise first_error
178178

179179
# Add one task. Higher priority => done first
180-
def put(self, task: T, priority: int = 0):
180+
def put(self, task: T, priority: int = 0) -> None:
181181
with self.mutex:
182182
# no task should be added after .done() was called
183183
assert not self.finish
@@ -189,14 +189,14 @@ def put(self, task: T, priority: int = 0):
189189
heapq.heappush(self.tasks, QueueItem(task, priority, self.total_tasks))
190190
self.todo.notify()
191191

192-
def join(self):
192+
def join(self) -> None:
193193
# wait for all current task to be completed
194194
with self.all_done:
195195
self.all_done.wait_for(lambda: self.missing == 0)
196196
self._handle_first_error()
197197

198198
# Wait for all tasks to be done and stop all threads
199-
def done(self):
199+
def done(self) -> None:
200200
self.finish = True
201201

202202
# notify all workers with permission to leave main loop
@@ -213,7 +213,7 @@ def done(self):
213213

214214
# Discard all remaining work in the queue and stop all workers.
215215
# Call done() to join the threads.
216-
def abort(self):
216+
def abort(self) -> None:
217217
super().abort()
218218

219219
with self.mutex:
@@ -227,7 +227,7 @@ def abort(self):
227227
self.all_done.notify_all()
228228

229229

230-
def new_queue(f: Callable[[T], Any], pin: bool = False):
230+
def new_queue(f: Callable[[T], Any], pin: bool = False) -> AbstractQueue[T]:
231231
"""
232232
f(task): the function to run on each queue item.
233233
@@ -242,7 +242,7 @@ def new_queue(f: Callable[[T], Any], pin: bool = False):
242242
return SequentialQueue(f, pin)
243243

244244

245-
def run_tasks(f: Callable[[T], Any], tasks: Sequence[T], pin: bool = False):
245+
def run_tasks(f: Callable[[T], Any], tasks: Sequence[T], pin: bool = False) -> None:
246246
queue = new_queue(f, pin)
247247
for task in tasks:
248248
queue.put(task)

bin/problem.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ def download_samples(p) -> list[tuple[Path, Path]]:
769769
return [t for t in testcases if isinstance(t, tuple)]
770770

771771
# Returns the list of submissions passed as command-line arguments, or the list of accepted submissions by default.
772-
def selected_or_accepted_submissions(problem) -> list["run.Submission"]:
772+
def selected_or_accepted_submissions(problem) -> list[run.Submission]:
773773
submissions = problem.submissions()
774774
if not submissions:
775775
return []
@@ -778,7 +778,7 @@ def selected_or_accepted_submissions(problem) -> list["run.Submission"]:
778778
else:
779779
return [s for s in submissions if s.expected_verdicts == [verdicts.Verdict.ACCEPTED]]
780780

781-
def submissions(problem) -> list["run.Submission"] | Literal[False]:
781+
def submissions(problem) -> list[run.Submission] | Literal[False]:
782782
if problem._submissions is not None:
783783
if problem._submissions is False:
784784
return False

0 commit comments

Comments
 (0)