|
| 1 | +# flake8: noqa |
1 | 2 | # Copyright LiveCodeBench @ 2024, |
2 | 3 |
|
3 | 4 | import ast |
@@ -657,18 +658,55 @@ def stripped_string_compare(s1, s2): |
657 | 658 | return s1 == s2 |
658 | 659 |
|
659 | 660 |
|
| 661 | +class MockStdinWithBuffer: |
| 662 | + |
| 663 | + def __init__(self, inputs: str): |
| 664 | + self.inputs = inputs |
| 665 | + self._stringio = StringIO(inputs) |
| 666 | + self.buffer = MockBuffer(inputs) |
| 667 | + |
| 668 | + def read(self, *args): |
| 669 | + return self.inputs |
| 670 | + |
| 671 | + def readline(self, *args): |
| 672 | + return self._stringio.readline(*args) |
| 673 | + |
| 674 | + def readlines(self, *args): |
| 675 | + return self.inputs.split('\n') |
| 676 | + |
| 677 | + def __getattr__(self, name): |
| 678 | + # Delegate other attributes to StringIO |
| 679 | + return getattr(self._stringio, name) |
| 680 | + |
| 681 | + |
| 682 | +class MockBuffer: |
| 683 | + |
| 684 | + def __init__(self, inputs: str): |
| 685 | + self.inputs = inputs.encode('utf-8') # Convert to bytes |
| 686 | + |
| 687 | + def read(self, *args): |
| 688 | + # Return as byte strings that can be split |
| 689 | + return self.inputs |
| 690 | + |
| 691 | + def readline(self, *args): |
| 692 | + return self.inputs.split(b'\n')[0] + b'\n' |
| 693 | + |
| 694 | + |
660 | 695 | def call_method(method, inputs): |
661 | 696 |
|
662 | 697 | if isinstance(inputs, list): |
663 | 698 | inputs = '\n'.join(inputs) |
664 | 699 |
|
665 | 700 | inputs_line_iterator = iter(inputs.split('\n')) |
666 | 701 |
|
| 702 | + # Create custom stdin mock with buffer support |
| 703 | + mock_stdin = MockStdinWithBuffer(inputs) |
| 704 | + |
667 | 705 | # sys.setrecursionlimit(10000) |
668 | 706 |
|
669 | 707 | # @patch('builtins.input', side_effect=inputs.split("\n")) |
670 | 708 | @patch('builtins.open', mock_open(read_data=inputs)) |
671 | | - @patch('sys.stdin', StringIO(inputs)) |
| 709 | + @patch('sys.stdin', mock_stdin) # Use our custom mock instead of StringIO |
672 | 710 | @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator)) |
673 | 711 | @patch('sys.stdin.readlines', lambda *args: inputs.split('\n')) |
674 | 712 | @patch('sys.stdin.read', lambda *args: inputs) |
|
0 commit comments