Skip to content

Commit b5b84b1

Browse files
Support define grader by input function
1 parent 9fa38ff commit b5b84b1

File tree

4 files changed

+88
-19
lines changed

4 files changed

+88
-19
lines changed

cyaron/compare.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import List, Optional, Tuple, Union
1010

1111
from cyaron.consts import *
12-
from cyaron.graders import CYaRonGraders
12+
from cyaron.graders import CYaRonGraders, GraderType
1313
from cyaron.utils import *
1414

1515
from . import log
@@ -27,11 +27,14 @@ def __str__(self):
2727
return "In program: '{}'. {}".format(self.name, self.mismatch)
2828

2929

30+
PrgoramType = Optional[Union[str, Tuple[str, ...], List[str]]]
31+
32+
3033
class Compare:
3134

3235
@staticmethod
3336
def __compare_two(name, content, std, grader):
34-
(result, info) = CYaRonGraders.invoke(grader, content, std)
37+
result, info = CYaRonGraders.invoke(grader, content, std)
3538
status = "Correct" if result else "!!!INCORRECT!!!"
3639
info = info if info is not None else ""
3740
log.debug("{}: {} {}".format(name, status, info))
@@ -85,8 +88,6 @@ def output(cls, *files, **kwargs):
8588
if (max_workers is None or max_workers >= 0) and job_pool is None:
8689
max_workers = cls.__normal_max_workers(max_workers)
8790
try:
88-
from concurrent.futures import ThreadPoolExecutor
89-
9091
with ThreadPoolExecutor(max_workers=max_workers) as job_pool:
9192
return cls.output(*files,
9293
std=std,
@@ -115,12 +116,12 @@ def do(file):
115116

116117
@classmethod
117118
def program(cls,
118-
*programs: Optional[Union[str, Tuple[str, ...], List[str]]],
119+
*programs: Union[PrgoramType, Tuple[PrgoramType, float]],
119120
input: Union[IO, str],
120121
std: Optional[Union[str, IO]] = None,
121122
std_program: Optional[Union[str, Tuple[str, ...],
122123
List[str]]] = None,
123-
grader: Optional[str] = DEFAULT_GRADER,
124+
grader: Union[str, GraderType] = DEFAULT_GRADER,
124125
max_workers: int = -1,
125126
job_pool: Optional[ThreadPoolExecutor] = None,
126127
stop_on_incorrect=None):
@@ -212,13 +213,13 @@ def do(program_name):
212213
)
213214
else:
214215
content = subprocess.check_output(
215-
program_name,
216-
shell=(not list_like(program_name)),
217-
stdin=input_file,
218-
universal_newlines=True,
219-
timeout=timeout,
220-
encoding="utf-8",
221-
)
216+
program_name,
217+
shell=(not list_like(program_name)),
218+
stdin=input_file,
219+
universal_newlines=True,
220+
timeout=timeout,
221+
encoding="utf-8",
222+
)
222223
cls.__compare_two(program_name, content, std, grader)
223224

224225
if job_pool is not None:

cyaron/graders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .graderregistry import CYaRonGraders
1+
from .graderregistry import CYaRonGraders, GraderType
22

33
from .fulltext import fulltext
44
from .noipstyle import noipstyle

cyaron/graders/graderregistry.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,32 @@
1+
from typing import Callable, Tuple, Dict, Union
2+
3+
__all__ = ['CYaRonGraders', 'GraderType']
4+
5+
GraderType = Callable[[str, str], Tuple[bool, Union[str, None]]]
6+
7+
18
class GraderRegistry:
2-
_registry = dict()
9+
"""A registry for grader functions."""
10+
_registry: Dict[str, GraderType] = {}
311

4-
def grader(self, name):
12+
def grader(self, name: str):
13+
"""A decorator to register a grader function."""
514

6-
def wrapper(func):
15+
def wrapper(func: GraderType):
716
self._registry[name] = func
817
return func
918

1019
return wrapper
1120

12-
def invoke(self, name, content, std):
13-
return self._registry[name](content, std)
21+
def invoke(self, grader: Union[str, GraderType], content: str, std: str):
22+
"""Invoke a grader function by name or function object."""
23+
if isinstance(grader, str):
24+
return self._registry[grader](content, std)
25+
else:
26+
return grader(content, std)
1427

1528
def check(self, name):
29+
"""Check if a grader is registered."""
1630
return name in self._registry
1731

1832

cyaron/tests/compare_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from cyaron.output_capture import captured_output
99
from cyaron.graders.mismatch import *
1010
from cyaron.compare import CompareMismatch
11+
from cyaron.graders import CYaRonGraders
1112

1213
log.set_verbose()
1314

@@ -176,3 +177,56 @@ def test_timeout(self):
176177
pass
177178
else:
178179
self.assertTrue(False)
180+
181+
def test_custom_grader_by_name(self):
182+
183+
@CYaRonGraders.grader("CustomTestGrader")
184+
def custom_test_grader(content: str, std: str):
185+
if content == '1\n' and std == '2\n':
186+
return True, None
187+
return False, "CustomTestGrader failed"
188+
189+
io = IO()
190+
io.output_writeln("2")
191+
192+
Compare.program("echo 1",
193+
std=io,
194+
input=IO(),
195+
grader="CustomTestGrader")
196+
197+
try:
198+
Compare.program("echo 2",
199+
std=io,
200+
input=IO(),
201+
grader="CustomTestGrader")
202+
except CompareMismatch as e:
203+
self.assertEqual(e.name, 'echo 2')
204+
self.assertEqual(e.mismatch, "CustomTestGrader failed")
205+
else:
206+
self.fail("Should raise CompareMismatch")
207+
208+
def test_custom_grader_by_function(self):
209+
210+
def custom_test_grader(content: str, std: str):
211+
if content == '1\n' and std == '2\n':
212+
return True, None
213+
return False, "CustomTestGrader failed"
214+
215+
io = IO()
216+
io.output_writeln("2")
217+
218+
Compare.program("echo 1",
219+
std=io,
220+
input=IO(),
221+
grader=custom_test_grader)
222+
223+
try:
224+
Compare.program("echo 2",
225+
std=io,
226+
input=IO(),
227+
grader=custom_test_grader)
228+
except CompareMismatch as e:
229+
self.assertEqual(e.name, 'echo 2')
230+
self.assertEqual(e.mismatch, "CustomTestGrader failed")
231+
else:
232+
self.fail("Should raise CompareMismatch")

0 commit comments

Comments
 (0)