Skip to content

Commit a131374

Browse files
grader支持获取input
1 parent b5b84b1 commit a131374

File tree

4 files changed

+117
-69
lines changed

4 files changed

+117
-69
lines changed

cyaron/compare.py

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
from __future__ import absolute_import, print_function
22

33
import multiprocessing
4-
import os
54
import subprocess
65
import sys
76
from concurrent.futures import ThreadPoolExecutor
87
from io import open
9-
from typing import List, Optional, Tuple, Union
8+
from typing import List, Optional, Tuple, Union, cast
109

1110
from cyaron.consts import *
12-
from cyaron.graders import CYaRonGraders, GraderType
11+
from cyaron.graders import CYaRonGraders, GraderType3
1312
from cyaron.utils import *
1413

1514
from . import log
@@ -27,14 +26,16 @@ def __str__(self):
2726
return "In program: '{}'. {}".format(self.name, self.mismatch)
2827

2928

30-
PrgoramType = Optional[Union[str, Tuple[str, ...], List[str]]]
29+
PrgoramType = Union[str, Tuple[str, ...], List[str]]
3130

3231

3332
class Compare:
3433

3534
@staticmethod
36-
def __compare_two(name, content, std, grader):
37-
result, info = CYaRonGraders.invoke(grader, content, std)
35+
def __compare_two(name: PrgoramType, content: str, std: str,
36+
input_content: str, grader: Union[str, GraderType3]):
37+
result, info = CYaRonGraders.invoke(grader, content, std,
38+
input_content)
3839
status = "Correct" if result else "!!!INCORRECT!!!"
3940
info = info if info is not None else ""
4041
log.debug("{}: {} {}".format(name, status, info))
@@ -77,7 +78,7 @@ def output(cls, *files, **kwargs):
7778
("stop_on_incorrect", None),
7879
),
7980
)
80-
std = kwargs["std"]
81+
std: IO = kwargs["std"]
8182
grader = kwargs["grader"]
8283
max_workers = kwargs["max_workers"]
8384
job_pool = kwargs["job_pool"]
@@ -101,13 +102,18 @@ def get_std():
101102
return cls.__process_output_file(std)[1]
102103

103104
if job_pool is not None:
104-
std = job_pool.submit(get_std).result()
105+
std_answer = job_pool.submit(get_std).result()
105106
else:
106-
std = get_std()
107+
std_answer = get_std()
108+
109+
with open(std.input_filename, "r", newline="\n",
110+
encoding="utf-8") as input_file:
111+
input_text = input_file.read()
107112

108113
def do(file):
109114
(file_name, content) = cls.__process_output_file(file)
110-
cls.__compare_two(file_name, content, std, grader)
115+
cls.__compare_two(file_name, content, std_answer, input_text,
116+
grader)
111117

112118
if job_pool is not None:
113119
job_pool.map(do, files)
@@ -121,8 +127,8 @@ def program(cls,
121127
std: Optional[Union[str, IO]] = None,
122128
std_program: Optional[Union[str, Tuple[str, ...],
123129
List[str]]] = None,
124-
grader: Union[str, GraderType] = DEFAULT_GRADER,
125-
max_workers: int = -1,
130+
grader: Union[str, GraderType3] = DEFAULT_GRADER,
131+
max_workers: Optional[int] = -1,
126132
job_pool: Optional[ThreadPoolExecutor] = None,
127133
stop_on_incorrect=None):
128134
"""
@@ -182,7 +188,7 @@ def get_std_from_std_program():
182188
elif std is not None:
183189

184190
def get_std_from_std_file():
185-
return cls.__process_output_file(std)[1]
191+
return cls.__process_output_file(cast(Union[str, IO], std))[1]
186192

187193
if job_pool is not None:
188194
std = job_pool.submit(get_std_from_std_file).result()
@@ -197,33 +203,29 @@ def get_std_from_std_file():
197203
"r",
198204
newline="\n",
199205
encoding="utf-8") as input_file:
200-
201-
def do(program_name):
202-
timeout = None
203-
if (list_like(program_name) and len(program_name) == 2
204-
and int_like(program_name[-1])):
205-
program_name, timeout = program_name
206-
if timeout is None:
207-
content = subprocess.check_output(
208-
program_name,
209-
shell=(not list_like(program_name)),
210-
stdin=input_file,
211-
universal_newlines=True,
212-
encoding="utf-8",
213-
)
214-
else:
215-
content = subprocess.check_output(
216-
program_name,
217-
shell=(not list_like(program_name)),
218-
stdin=input_file,
219-
universal_newlines=True,
220-
timeout=timeout,
221-
encoding="utf-8",
222-
)
223-
cls.__compare_two(program_name, content, std, grader)
224-
225-
if job_pool is not None:
226-
job_pool.map(do, programs)
206+
input_text = input_file.read()
207+
208+
def do(program_name: Union[PrgoramType, Tuple[PrgoramType, float]]):
209+
timeout = None
210+
if isinstance(program_name, tuple) and len(program_name) == 2 and (
211+
isinstance(program_name[1], float)
212+
or isinstance(program_name[1], int)):
213+
program_name, timeout = cast(Tuple[PrgoramType, float],
214+
program_name)
227215
else:
228-
for program in programs:
229-
do(program)
216+
program_name = cast(PrgoramType, program_name)
217+
content = subprocess.check_output(
218+
list(program_name)
219+
if isinstance(program_name, tuple) else program_name,
220+
shell=(not list_like(program_name)),
221+
input=input_text,
222+
universal_newlines=True,
223+
encoding="utf-8",
224+
timeout=timeout)
225+
cls.__compare_two(program_name, content, std, input_text, grader)
226+
227+
if job_pool is not None:
228+
job_pool.map(do, programs)
229+
else:
230+
for program in programs:
231+
do(program)

cyaron/graders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .graderregistry import CYaRonGraders, GraderType
1+
from .graderregistry import CYaRonGraders, GraderType2, GraderType3
22

33
from .fulltext import fulltext
44
from .noipstyle import noipstyle

cyaron/graders/graderregistry.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,51 @@
1-
from typing import Callable, Tuple, Dict, Union
1+
from typing import Callable, Tuple, Dict, Union, Any
22

3-
__all__ = ['CYaRonGraders', 'GraderType']
3+
__all__ = ['CYaRonGraders', 'GraderType2', 'GraderType3']
44

5-
GraderType = Callable[[str, str], Tuple[bool, Union[str, None]]]
5+
GraderType2 = Callable[[str, str], Tuple[bool, Any]]
6+
GraderType3 = Callable[[str, str, str], Tuple[bool, Any]]
67

78

89
class GraderRegistry:
910
"""A registry for grader functions."""
10-
_registry: Dict[str, GraderType] = {}
11+
_registry: Dict[str, GraderType3] = {}
12+
13+
def grader2(self, name: str):
14+
"""
15+
This decorator registers a grader function under a specific name in the registry.
16+
17+
The function being decorated should accept exactly two parameters (excluding
18+
the content input).
19+
"""
20+
21+
def wrapper(func: GraderType2):
22+
self._registry[name] = lambda content, std, _: func(content, std)
23+
return func
24+
25+
return wrapper
26+
27+
grader = grader2
1128

12-
def grader(self, name: str):
13-
"""A decorator to register a grader function."""
29+
def grader3(self, name: str):
30+
"""
31+
This decorator registers a grader function under a specific name in the registry.
32+
33+
The function being decorated should accept exactly three parameters.
34+
"""
1435

15-
def wrapper(func: GraderType):
36+
def wrapper(func: GraderType3):
1637
self._registry[name] = func
1738
return func
1839

1940
return wrapper
2041

21-
def invoke(self, grader: Union[str, GraderType], content: str, std: str):
42+
def invoke(self, grader: Union[str, GraderType3], content: str, std: str,
43+
input_content: str):
2244
"""Invoke a grader function by name or function object."""
2345
if isinstance(grader, str):
24-
return self._registry[grader](content, std)
46+
return self._registry[grader](content, std, input_content)
2547
else:
26-
return grader(content, std)
48+
return grader(content, std, input_content)
2749

2850
def check(self, name):
2951
"""Check if a grader is registered."""

cyaron/tests/compare_test.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,15 @@ def test_file_input_success(self):
123123
grader="NOIPStyle")
124124

125125
def test_file_input_fail(self):
126-
with open("correct.py", "w") as f:
127-
f.write("print(input())")
128-
with open("std.py", "w") as f:
126+
with open("incorrect.py", "w") as f:
129127
f.write("print(input()+'154')")
128+
with open("std.py", "w") as f:
129+
f.write("print(input())")
130130
io = IO()
131131
io.input_writeln("233")
132132
try:
133133
with captured_output():
134-
Compare.program((sys.executable, "correct.py"),
134+
Compare.program((sys.executable, "incorrect.py"),
135135
std_program=(sys.executable, "std.py"),
136136
input=io,
137137
grader="NOIPStyle")
@@ -178,10 +178,11 @@ def test_timeout(self):
178178
else:
179179
self.assertTrue(False)
180180

181-
def test_custom_grader_by_name(self):
181+
def test_custom_grader2_by_name(self):
182+
self.assertEqual(CYaRonGraders.grader, CYaRonGraders.grader2)
182183

183-
@CYaRonGraders.grader("CustomTestGrader")
184-
def custom_test_grader(content: str, std: str):
184+
@CYaRonGraders.grader("CustomTestGrader2")
185+
def custom_test_grader2(content: str, std: str):
185186
if content == '1\n' and std == '2\n':
186187
return True, None
187188
return False, "CustomTestGrader failed"
@@ -192,13 +193,38 @@ def custom_test_grader(content: str, std: str):
192193
Compare.program("echo 1",
193194
std=io,
194195
input=IO(),
195-
grader="CustomTestGrader")
196+
grader="CustomTestGrader2")
196197

197198
try:
198199
Compare.program("echo 2",
199200
std=io,
200201
input=IO(),
201-
grader="CustomTestGrader")
202+
grader="CustomTestGrader2")
203+
except CompareMismatch as e:
204+
self.assertEqual(e.name, 'echo 2')
205+
self.assertEqual(e.mismatch, "CustomTestGrader failed")
206+
else:
207+
self.fail("Should raise CompareMismatch")
208+
209+
def test_custom_grader3_by_name(self):
210+
211+
@CYaRonGraders.grader3("CustomTestGrader3")
212+
def custom_test_grader3(content: str, std: str, input_content: str):
213+
if input_content == '0\n' and content == '1\n' and std == '2\n':
214+
return True, None
215+
return False, "CustomTestGrader failed"
216+
217+
io = IO()
218+
io.input_writeln("0")
219+
io.output_writeln("2")
220+
221+
Compare.program("echo 1", std=io, input=io, grader="CustomTestGrader3")
222+
223+
try:
224+
Compare.program("echo 2",
225+
std=io,
226+
input=io,
227+
grader='CustomTestGrader3')
202228
except CompareMismatch as e:
203229
self.assertEqual(e.name, 'echo 2')
204230
self.assertEqual(e.mismatch, "CustomTestGrader failed")
@@ -207,23 +233,21 @@ def custom_test_grader(content: str, std: str):
207233

208234
def test_custom_grader_by_function(self):
209235

210-
def custom_test_grader(content: str, std: str):
211-
if content == '1\n' and std == '2\n':
236+
def custom_test_grader(content: str, std: str, input_content: str):
237+
if input_content == '0\n' and content == '1\n' and std == '2\n':
212238
return True, None
213239
return False, "CustomTestGrader failed"
214240

215241
io = IO()
242+
io.input_writeln("0")
216243
io.output_writeln("2")
217244

218-
Compare.program("echo 1",
219-
std=io,
220-
input=IO(),
221-
grader=custom_test_grader)
245+
Compare.program("echo 1", std=io, input=io, grader=custom_test_grader)
222246

223247
try:
224248
Compare.program("echo 2",
225249
std=io,
226-
input=IO(),
250+
input=io,
227251
grader=custom_test_grader)
228252
except CompareMismatch as e:
229253
self.assertEqual(e.name, 'echo 2')

0 commit comments

Comments
 (0)