Skip to content

Commit 9fa38ff

Browse files
Add type hints for Compare.program
1 parent fc722ac commit 9fa38ff

File tree

3 files changed

+118
-97
lines changed

3 files changed

+118
-97
lines changed

cyaron/compare.py

Lines changed: 93 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
from __future__ import absolute_import, print_function
2-
from .io import IO
3-
from . import log
4-
from cyaron.utils import *
5-
from cyaron.consts import *
6-
from cyaron.graders import CYaRonGraders
7-
import subprocess
2+
83
import multiprocessing
4+
import os
5+
import subprocess
96
import sys
7+
from concurrent.futures import ThreadPoolExecutor
108
from io import open
11-
import os
9+
from typing import List, Optional, Tuple, Union
10+
11+
from cyaron.consts import *
12+
from cyaron.graders import CYaRonGraders
13+
from cyaron.utils import *
14+
15+
from . import log
16+
from .io import IO
1217

1318

1419
class CompareMismatch(ValueError):
@@ -34,13 +39,18 @@ def __compare_two(name, content, std, grader):
3439
raise CompareMismatch(name, info)
3540

3641
@staticmethod
37-
def __process_file(file):
42+
def __process_output_file(file: Union[str, IO]):
3843
if isinstance(file, IO):
44+
if file.output_filename is None:
45+
raise ValueError("IO object has no output file.")
3946
file.flush_buffer()
40-
file.output_file.seek(0)
41-
return file.output_filename, file.output_file.read()
47+
with open(file.output_filename,
48+
"r",
49+
newline="\n",
50+
encoding='utf-8') as f:
51+
return file.output_filename, f.read()
4252
else:
43-
with open(file, "r", newline="\n") as f:
53+
with open(file, "r", newline="\n", encoding="utf-8") as f:
4454
return file, f.read()
4555

4656
@staticmethod
@@ -87,15 +97,15 @@ def output(cls, *files, **kwargs):
8797
pass
8898

8999
def get_std():
90-
return cls.__process_file(std)[1]
100+
return cls.__process_output_file(std)[1]
91101

92102
if job_pool is not None:
93103
std = job_pool.submit(get_std).result()
94104
else:
95105
std = get_std()
96106

97107
def do(file):
98-
(file_name, content) = cls.__process_file(file)
108+
(file_name, content) = cls.__process_output_file(file)
99109
cls.__compare_two(file_name, content, std, grader)
100110

101111
if job_pool is not None:
@@ -104,35 +114,36 @@ def do(file):
104114
[x for x in map(do, files)]
105115

106116
@classmethod
107-
def program(cls, *programs, **kwargs):
108-
kwargs = unpack_kwargs(
109-
"program",
110-
kwargs,
111-
(
112-
"input",
113-
("std", None),
114-
("std_program", None),
115-
("grader", DEFAULT_GRADER),
116-
("max_workers", -1),
117-
("job_pool", None),
118-
("stop_on_incorrect", None),
119-
),
120-
)
121-
input = kwargs["input"]
122-
std = kwargs["std"]
123-
std_program = kwargs["std_program"]
124-
grader = kwargs["grader"]
125-
max_workers = kwargs["max_workers"]
126-
job_pool = kwargs["job_pool"]
127-
if kwargs["stop_on_incorrect"] is not None:
117+
def program(cls,
118+
*programs: Optional[Union[str, Tuple[str, ...], List[str]]],
119+
input: Union[IO, str],
120+
std: Optional[Union[str, IO]] = None,
121+
std_program: Optional[Union[str, Tuple[str, ...],
122+
List[str]]] = None,
123+
grader: Optional[str] = DEFAULT_GRADER,
124+
max_workers: int = -1,
125+
job_pool: Optional[ThreadPoolExecutor] = None,
126+
stop_on_incorrect=None):
127+
"""
128+
Compare the output of the programs with the standard output.
129+
130+
Args:
131+
programs: The programs to be compared.
132+
input: The input file.
133+
std: The standard output file.
134+
std_program: The program that generates the standard output.
135+
grader: The grader to be used.
136+
max_workers: The maximum number of workers.
137+
job_pool: The job pool.
138+
stop_on_incorrect: Deprecated and has no effect.
139+
"""
140+
if stop_on_incorrect is not None:
128141
log.warn(
129142
"parameter stop_on_incorrect is deprecated and has no effect.")
130143

131144
if (max_workers is None or max_workers >= 0) and job_pool is None:
132145
max_workers = cls.__normal_max_workers(max_workers)
133146
try:
134-
from concurrent.futures import ThreadPoolExecutor
135-
136147
with ThreadPoolExecutor(max_workers=max_workers) as job_pool:
137148
return cls.program(*programs,
138149
input=input,
@@ -144,74 +155,74 @@ def program(cls, *programs, **kwargs):
144155
except ImportError:
145156
pass
146157

147-
if not isinstance(input, IO):
148-
raise TypeError("expect {}, got {}".format(
149-
type(IO).__name__,
150-
type(input).__name__))
151-
input.flush_buffer()
152-
input.input_file.seek(0)
158+
if isinstance(input, IO):
159+
input.flush_buffer()
153160

154161
if std_program is not None:
155162

156-
def get_std():
157-
with open(os.dup(input.input_file.fileno()), "r",
158-
newline="\n") as input_file:
159-
content = make_unicode(
160-
subprocess.check_output(
161-
std_program,
162-
shell=(not list_like(std_program)),
163-
stdin=input.input_file,
164-
universal_newlines=True,
165-
))
166-
input_file.seek(0)
163+
def get_std_from_std_program():
164+
with open(input.input_filename
165+
if isinstance(input, IO) else input,
166+
"r",
167+
newline="\n",
168+
encoding="utf-8") as input_file:
169+
content = subprocess.check_output(
170+
std_program,
171+
shell=(not list_like(std_program)),
172+
stdin=input_file,
173+
universal_newlines=True,
174+
encoding="utf-8")
167175
return content
168176

169177
if job_pool is not None:
170-
std = job_pool.submit(get_std).result()
178+
std = job_pool.submit(get_std_from_std_program).result()
171179
else:
172-
std = get_std()
180+
std = get_std_from_std_program()
173181
elif std is not None:
174182

175-
def get_std():
176-
return cls.__process_file(std)[1]
183+
def get_std_from_std_file():
184+
return cls.__process_output_file(std)[1]
177185

178186
if job_pool is not None:
179-
std = job_pool.submit(get_std).result()
187+
std = job_pool.submit(get_std_from_std_file).result()
180188
else:
181-
std = get_std()
189+
std = get_std_from_std_file()
182190
else:
183191
raise TypeError(
184192
"program() missing 1 required non-None keyword-only argument: 'std' or 'std_program'"
185193
)
186194

187-
def do(program_name):
188-
timeout = None
189-
if (list_like(program_name) and len(program_name) == 2
190-
and int_like(program_name[-1])):
191-
program_name, timeout = program_name
192-
with open(os.dup(input.input_file.fileno()), "r",
193-
newline="\n") as input_file:
195+
with open(input.input_filename if isinstance(input, IO) else input,
196+
"r",
197+
newline="\n",
198+
encoding="utf-8") as input_file:
199+
200+
def do(program_name):
201+
timeout = None
202+
if (list_like(program_name) and len(program_name) == 2
203+
and int_like(program_name[-1])):
204+
program_name, timeout = program_name
194205
if timeout is None:
195-
content = make_unicode(
196-
subprocess.check_output(
197-
program_name,
198-
shell=(not list_like(program_name)),
199-
stdin=input_file,
200-
universal_newlines=True,
201-
))
206+
content = subprocess.check_output(
207+
program_name,
208+
shell=(not list_like(program_name)),
209+
stdin=input_file,
210+
universal_newlines=True,
211+
encoding="utf-8",
212+
)
202213
else:
203-
content = make_unicode(
204-
subprocess.check_output(
214+
content = subprocess.check_output(
205215
program_name,
206216
shell=(not list_like(program_name)),
207217
stdin=input_file,
208218
universal_newlines=True,
209219
timeout=timeout,
210-
))
211-
input_file.seek(0)
212-
cls.__compare_two(program_name, content, std, grader)
220+
encoding="utf-8",
221+
)
222+
cls.__compare_two(program_name, content, std, grader)
213223

214-
if job_pool is not None:
215-
job_pool.map(do, programs)
216-
else:
217-
[x for x in map(do, programs)]
224+
if job_pool is not None:
225+
job_pool.map(do, programs)
226+
else:
227+
for program in programs:
228+
do(program)

cyaron/io.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __init__( # type: ignore
100100
output_file = "{}{{}}{}".format(
101101
self.__escape_format(file_prefix),
102102
self.__escape_format(output_suffix))
103-
self.input_filename, self.output_filename = None, None
103+
self.input_filename, self.output_filename = cast(str, None), None
104104
self.__input_temp, self.__output_temp = False, False
105105
self.__init_file(input_file, data_id, "i", make_dirs)
106106
if not disable_output:
@@ -357,3 +357,5 @@ def output_clear_content(self, pos: int = 0):
357357
def flush_buffer(self):
358358
"""Flush the input file"""
359359
self.input_file.flush()
360+
if self.output_file:
361+
self.output_file.flush()

cyaron/tests/compare_test.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -108,28 +108,36 @@ def test_fulltext_program(self):
108108
correct_out = 'python correct.py: Correct \npython incorrect.py: !!!INCORRECT!!! Hash mismatch: read 53c234e5e8472b6ac51c1ae1cab3fe06fad053beb8ebfd8977b010655bfdd3c3, expected 4355a46b19d348dc2f57c046f8ef63d4538ebb936000f3c9ee954a27460dd865'
109109
self.assertEqual(result, correct_out)
110110

111-
def test_file_input(self):
111+
def test_file_input_success(self):
112112
with open("correct.py", "w") as f:
113113
f.write("print(input())")
114-
115114
with open("std.py", "w") as f:
116115
f.write("print(input())")
117-
118-
io = None
119-
with captured_output() as (out, err):
120-
io = IO()
121-
116+
io = IO()
122117
io.input_writeln("233")
123-
124-
with captured_output() as (out, err):
125-
Compare.program("python correct.py",
126-
std_program="python std.py",
118+
with captured_output():
119+
Compare.program((sys.executable, "correct.py"),
120+
std_program=(sys.executable, "std.py"),
127121
input=io,
128122
grader="NOIPStyle")
129123

130-
result = out.getvalue().strip()
131-
correct_out = 'python correct.py: Correct'
132-
self.assertEqual(result, correct_out)
124+
def test_file_input_fail(self):
125+
with open("correct.py", "w") as f:
126+
f.write("print(input())")
127+
with open("std.py", "w") as f:
128+
f.write("print(input()+'154')")
129+
io = IO()
130+
io.input_writeln("233")
131+
try:
132+
with captured_output():
133+
Compare.program((sys.executable, "correct.py"),
134+
std_program=(sys.executable, "std.py"),
135+
input=io,
136+
grader="NOIPStyle")
137+
except CompareMismatch:
138+
pass
139+
else:
140+
self.fail("Should raise CompareMismatch")
133141

134142
def test_concurrent(self):
135143
programs = ['test{}.py'.format(i) for i in range(16)]

0 commit comments

Comments
 (0)