Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/swerex/runtime/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def find_range(cmd: bashlex.ast.node) -> tuple[int, int]:

def _strip_control_chars(s: str) -> str:
ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]")
return ansi_escape.sub("", s)
return ansi_escape.sub("", s).replace("\r\n", "\n")


def _check_bash_command(command: str) -> None:
Expand Down Expand Up @@ -253,7 +253,7 @@ async def _run_interactive(self, action: BashAction) -> BashObservation:
except pexpect.TIMEOUT as e:
msg = f"timeout after {action.timeout} seconds while running command {action.command!r}"
raise CommandTimeoutError(msg) from e
output: str = _strip_control_chars(self.shell.before).strip() # type: ignore
output: str = _strip_control_chars(self.shell.before) # type: ignore
if action.is_interactive_quit:
assert not action.is_interactive_command
self.shell.setecho(False)
Expand Down Expand Up @@ -297,7 +297,7 @@ async def _run_normal(self, action: BashAction) -> BashObservation:
# Bashlex is very buggy and can throw a variety of errors, including
# ParsingErrors, NotImplementedErrors, TypeErrors, possibly more. So we catch them all
self.logger.error("Bashlex fail: %s", e)
action.command += f"\n TMPEXITCODE=$? ; sleep 0.1; echo '{self._UNIQUE_STRING}' ; (exit $TMPEXITCODE)"
action.command += f"\n TMPEXITCODE=$? ; sleep 0.1; echo -n '{self._UNIQUE_STRING}' ; (exit $TMPEXITCODE)"
fallback_terminator = True
else:
action.command = " ; ".join(individual_commands)
Expand All @@ -312,7 +312,7 @@ async def _run_normal(self, action: BashAction) -> BashObservation:
except pexpect.TIMEOUT as e:
msg = f"timeout after {action.timeout} seconds while running command {action.command!r}"
raise CommandTimeoutError(msg) from e
output: str = _strip_control_chars(self.shell.before).strip() # type: ignore
output: str = _strip_control_chars(self.shell.before) # type: ignore

# Part 3: Get the exit code
if action.check == "ignore":
Expand All @@ -327,7 +327,7 @@ async def _run_normal(self, action: BashAction) -> BashObservation:
except pexpect.TIMEOUT:
msg = "timeout while getting exit code"
raise NoExitCodeError(msg)
exit_code_raw: str = _strip_control_chars(self.shell.before).strip() # type: ignore
exit_code_raw: str = _strip_control_chars(self.shell.before) # type: ignore
exit_code = re.findall(f"{_exit_code_prefix}([0-9]+)", exit_code_raw)
if len(exit_code) != 1:
msg = f"failed to parse exit code from output {exit_code_raw!r} (command: {action.command!r}, matches: {exit_code})"
Expand Down
72 changes: 55 additions & 17 deletions tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,46 @@ async def test_execute_command(remote_runtime: RemoteRuntime):
assert (await remote_runtime.execute(C(command="echo 'hello world'", shell=True))).stdout == "hello world\n"


async def test_execute_command_with_empty_string(remote_runtime: RemoteRuntime):
    """`echo ''` run through `execute` must report exactly one newline on stdout."""
    response = await remote_runtime.execute(C(command="echo ''", shell=True))
    assert response.stdout == "\n"


async def test_execute_command_with_empty_string_in_session(runtime_with_default_session: RemoteRuntime):
    """`echo ''` run inside a session must produce exactly one newline as output."""
    observation = await runtime_with_default_session.run_in_session(A(command="echo ''", check="raise"))
    assert observation.output == "\n"


async def test_execute_command_with_leading_space_output(remote_runtime: RemoteRuntime):
    """Leading blank and whitespace-only lines in stdout must be returned unchanged."""
    command = "echo '\n \nhello world'"
    response = await remote_runtime.execute(C(command=command, shell=True))
    assert response.stdout == "\n \nhello world\n"


async def test_execute_command_with_echon(remote_runtime: RemoteRuntime):
    """`echo -n` output must come back verbatim, with no newline added or removed."""
    # (command, expected stdout) pairs, checked in order.
    cases = [
        ("echo -n 'hello world'", "hello world"),
        ("echo -n 'hello world\n'", "hello world\n"),
        ("echo -n '\nhello world'", "\nhello world"),
    ]
    for command, expected in cases:
        response = await remote_runtime.execute(C(command=command, shell=True))
        assert response.stdout == expected


async def test_execute_command_with_newline_in_session(runtime_with_default_session: RemoteRuntime):
    """A leading newline printed inside a session must be kept in the output."""
    action = A(command="printf '\nx'", check="raise")
    observation = await runtime_with_default_session.run_in_session(action)
    assert observation.output == "\nx"


async def test_execute_command_with_many_newlines_in_session(runtime_with_default_session: RemoteRuntime):
    """Runs of leading and trailing newlines must be returned exactly as printed."""
    action = A(command="printf '\n\nx\n\n\n'", check="raise")
    observation = await runtime_with_default_session.run_in_session(action)
    assert observation.output == "\n\nx\n\n\n"


async def test_execute_command_with_whitespace_in_session(runtime_with_default_session: RemoteRuntime):
    """A leading space printed inside a session must not be stripped from the output."""
    action = A(command="printf ' x'", check="raise")
    observation = await runtime_with_default_session.run_in_session(action)
    assert observation.output == " x"


async def test_execute_command_with_leading_space_in_session(runtime_with_default_session: RemoteRuntime):
    """Blank and whitespace-only leading lines must survive a session round trip."""
    action = A(command="echo '\n \nhello\nworld'", check="raise")
    observation = await runtime_with_default_session.run_in_session(action)
    assert observation.output == "\n \nhello\nworld\n"


async def test_execute_command_shell_false(remote_runtime: RemoteRuntime):
    """An argv-list command executed with `shell=False` returns its stdout verbatim."""
    argv = ["echo", "hello world"]
    response = await remote_runtime.execute(C(command=argv, shell=False))
    assert response.stdout == "hello world\n"

Expand Down Expand Up @@ -114,18 +154,18 @@ async def test_run_in_shell_multiple_interactive_and_normal_commands(runtime_wit
await run.run_in_session(A(command="python", is_interactive_command=True, expect=[">>> "]))

r = await run.run_in_session(A(command="print('hello world')", is_interactive_command=True, expect=[">>> "]))
assert "hello world" in r.output
assert r.output == "hello world\n"

await run.run_in_session(A(command="quit()\n", is_interactive_quit=True, check="raise"))

r = await run.run_in_session(A(command="echo 'hello world'", check="raise"))
assert "hello world" in r.output
assert r.output == "hello world\n"

await run.run_in_session(A(command="python", is_interactive_command=True, expect=[">>> "]))
await run.run_in_session(A(command="print('hello world')", is_interactive_command=True, expect=[">>> "]))
await run.run_in_session(A(command="quit()\n", is_interactive_quit=True, check="raise"))
r = await run.run_in_session(A(command="echo 'hello world'", check="raise"))
assert "hello world" in r.output
assert r.output == "hello world\n"


async def test_run_in_shell_interactive_command_timeout(runtime_with_default_session: RemoteRuntime):
Expand Down Expand Up @@ -161,14 +201,14 @@ async def test_multiple_isolated_shells(remote_runtime: RemoteRuntime):
response1 = await remote_runtime.run_in_session(A(command="echo $x", session="shell1", check="raise"))
response2 = await remote_runtime.run_in_session(A(command="echo $y", session="shell2", check="raise"))

assert response1.output.strip() == "42"
assert response2.output.strip() == "24"
assert response1.output == "42\n"
assert response2.output == "24\n"

response3 = await remote_runtime.run_in_session(A(command="echo $y", session="shell1", check="raise"))
response4 = await remote_runtime.run_in_session(A(command="echo $x", session="shell2", check="raise"))

assert response3.output.strip() == ""
assert response4.output.strip() == ""
assert response3.output == "\n"
assert response4.output == "\n"

await remote_runtime.close_session(CloseBashSessionRequest(session="shell1"))
await remote_runtime.close_session(CloseBashSessionRequest(session="shell2"))
Expand Down Expand Up @@ -198,14 +238,13 @@ async def test_multiple_commands_with_linebreaks_in_shell(runtime_with_default_s
r = await runtime_with_default_session.run_in_session(
A(command="\n\n\n echo 'test1' \n \n \n echo 'test2' \n\n\n", check="raise")
)
assert r.output.splitlines() == ["test1", "test2"]
assert r.output == "test1\ntest2\n"


async def test_bash_multiline_command_eof(runtime_with_default_session: RemoteRuntime):
    """A heredoc piped into `python` must return both printed lines exactly."""
    command = "\n".join(["python <<EOF", "print('hello world')", "print('hello world 2')", "EOF"])
    r = await runtime_with_default_session.run_in_session(A(command=command, check="raise"))
    # Exact comparison only: substring checks would still pass if the output
    # carried stray extra text or lost its trailing newline.
    assert r.output == "hello world\nhello world 2\n"


async def test_run_in_shell_subshell_command(runtime_with_default_session: RemoteRuntime):
Expand All @@ -221,18 +260,18 @@ async def test_run_in_shell_multiple_commands(runtime_with_default_session: Remo
r = await runtime_with_default_session.run_in_session(
A(command="echo 'hello world'; echo 'hello again'", check="raise")
)
assert r.output.splitlines() == ["hello world", "hello again"]
assert r.output == "hello world\nhello again\n"
r = await runtime_with_default_session.run_in_session(
A(command="echo 'hello world' && echo 'hello again'", check="raise")
)
assert r.output.splitlines() == ["hello world", "hello again"]
assert r.output == "hello world\nhello again\n"


async def test_run_in_shell_while_loop(runtime_with_default_session: RemoteRuntime):
    """A multi-line `for` loop run in a session must emit one line per iteration."""
    r = await runtime_with_default_session.run_in_session(
        A(command="for i in {1..3};\n do echo 'hello world';\n done", check="raise")
    )
    # Compare the raw output rather than `splitlines()`, which would hide a
    # missing trailing newline or `\r\n` line endings.
    assert r.output == "hello world\n" * 3


async def test_run_in_shell_bashlex_errors(runtime_with_default_session: RemoteRuntime):
Expand All @@ -249,8 +288,7 @@ async def test_run_shell_check_exit_code(runtime_with_default_session: RemoteRun

async def test_with_bashlex_errors(runtime_with_default_session: RemoteRuntime):
    """Both `echo` outputs must come back exactly, despite the `A=()` array assignment between them."""
    r = await runtime_with_default_session.run_in_session(A(command="echo 'hw';A=();echo 'asdf'", check="raise"))
    # Exact comparison only: substring checks would not notice extra text
    # leaking into the output.
    assert r.output == "hw\nasdf\n"


async def test_upload_file(runtime_with_default_session: RemoteRuntime, tmp_path: Path):
Expand Down Expand Up @@ -294,7 +332,7 @@ async def test_check_bash_command_invalid(runtime_with_default_session: RemoteRu

async def test_echo_new_lines(runtime_with_default_session: RemoteRuntime):
    """An embedded newline in an `echo` argument must be kept, followed by echo's own newline."""
    r = await runtime_with_default_session.run_in_session(A(command="echo 'hello\nworld'", check="raise"))
    # Compare the raw output rather than `splitlines()`, which would hide a
    # missing trailing newline or `\r\n` line endings.
    assert r.output == "hello\nworld\n"


async def test_interrupt_session(runtime_with_default_session: RemoteRuntime):
Expand All @@ -304,7 +342,7 @@ async def test_interrupt_session(runtime_with_default_session: RemoteRuntime):
pass
r = await runtime_with_default_session.run_in_session(BashInterruptAction())
r = await runtime_with_default_session.run_in_session(A(command="echo 'asdf'", check="raise"))
assert r.output == "asdf"
assert r.output == "asdf\n"


async def test_interrupt_pager(runtime_with_default_session: RemoteRuntime):
Expand Down