Skip to content

Commit a292806

Browse files
authored
[Fix] Add custom mock for sys.stdin that supports buffer attribute in LCBV6 (#2393)
* fix * fix * fix lint
1 parent bb15146 commit a292806

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

opencompass/datasets/livecodebench/testing_util.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# flake8: noqa
12
# Copyright LiveCodeBench @ 2024,
23

34
import ast
@@ -657,18 +658,55 @@ def stripped_string_compare(s1, s2):
657658
return s1 == s2
658659

659660

661+
class MockStdinWithBuffer:
662+
663+
def __init__(self, inputs: str):
664+
self.inputs = inputs
665+
self._stringio = StringIO(inputs)
666+
self.buffer = MockBuffer(inputs)
667+
668+
def read(self, *args):
669+
return self.inputs
670+
671+
def readline(self, *args):
672+
return self._stringio.readline(*args)
673+
674+
def readlines(self, *args):
675+
return self.inputs.split('\n')
676+
677+
def __getattr__(self, name):
678+
# Delegate other attributes to StringIO
679+
return getattr(self._stringio, name)
680+
681+
682+
class MockBuffer:
683+
684+
def __init__(self, inputs: str):
685+
self.inputs = inputs.encode('utf-8') # Convert to bytes
686+
687+
def read(self, *args):
688+
# Return as byte strings that can be split
689+
return self.inputs
690+
691+
def readline(self, *args):
692+
return self.inputs.split(b'\n')[0] + b'\n'
693+
694+
660695
def call_method(method, inputs):
661696

662697
if isinstance(inputs, list):
663698
inputs = '\n'.join(inputs)
664699

665700
inputs_line_iterator = iter(inputs.split('\n'))
666701

702+
# Create custom stdin mock with buffer support
703+
mock_stdin = MockStdinWithBuffer(inputs)
704+
667705
# sys.setrecursionlimit(10000)
668706

669707
# @patch('builtins.input', side_effect=inputs.split("\n"))
670708
@patch('builtins.open', mock_open(read_data=inputs))
671-
@patch('sys.stdin', StringIO(inputs))
709+
@patch('sys.stdin', mock_stdin) # Use our custom mock instead of StringIO
672710
@patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator))
673711
@patch('sys.stdin.readlines', lambda *args: inputs.split('\n'))
674712
@patch('sys.stdin.read', lambda *args: inputs)

0 commit comments

Comments
 (0)