Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 111 additions & 45 deletions src/swerex/runtime/remote.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import asyncio
import logging
import random
import shutil
import sys
import tempfile
import traceback
import uuid
from pathlib import Path
from typing import Any

import requests
import aiohttp
from pydantic import BaseModel
from typing_extensions import Self

Expand Down Expand Up @@ -55,11 +58,18 @@ def __init__(
if not self._config.host.startswith("http"):
self.logger.warning("Host %s does not start with http, adding http://", self._config.host)
self._config.host = f"http://{self._config.host}"
self._session = None

@classmethod
def from_config(cls, config: RemoteRuntimeConfig) -> Self:
    """Alternate constructor: build a runtime from a `RemoteRuntimeConfig`.

    The config's fields are expanded into the regular keyword arguments of
    `__init__`.
    """
    kwargs = config.model_dump()
    return cls(**kwargs)

async def _ensure_session(self) -> aiohttp.ClientSession:
    """Return the cached aiohttp client session, creating a new one when
    no session exists yet or the previous one has been closed."""
    needs_new_session = self._session is None or self._session.closed
    if needs_new_session:
        self._session = aiohttp.ClientSession()
    return self._session

def _get_timeout(self, timeout: float | None = None) -> float:
if timeout is None:
return self._config.timeout
Expand Down Expand Up @@ -111,16 +121,16 @@ def _handle_transfer_exception(self, exc_transfer: _ExceptionTransfer) -> None:
exception.extra_info = exc_transfer.extra_info
raise exception from None

async def _handle_response_errors(self, response: aiohttp.ClientResponse) -> None:
    """Raise exceptions found in the request response.

    A 511 status carries a serialized swerex exception in the JSON body
    (key ``swerexception``), which is deserialized and re-raised client-side
    via `_handle_transfer_exception`. Any other status >= 400 is logged and
    re-raised through `raise_for_status`.

    Args:
        response: The aiohttp response to inspect.
    """
    if response.status == 511:
        data = await response.json()
        exc_transfer = _ExceptionTransfer(**data["swerexception"])
        self._handle_transfer_exception(exc_transfer)
    if response.status >= 400:
        try:
            # The body is usually JSON, but an intermediate proxy or load
            # balancer may return a non-JSON error page — fall back to raw
            # text so we still raise the HTTP error instead of a decode error.
            data = await response.json()
        except Exception:
            data = await response.text()
        self.logger.critical("Received error response: %s", data)
        response.raise_for_status()

async def is_alive(self, *, timeout: float | None = None) -> IsAliveResponse:
"""Checks if the runtime is alive.
Expand All @@ -129,20 +139,23 @@ async def is_alive(self, *, timeout: float | None = None) -> IsAliveResponse:
together with the message.
"""
try:
response = requests.get(
f"{self._api_url}/is_alive", headers=self._headers, timeout=self._get_timeout(timeout)
)
if response.status_code == 200:
return IsAliveResponse(**response.json())
elif response.status_code == 511:
exc_transfer = _ExceptionTransfer(**response.json()["swerexception"])
self._handle_transfer_exception(exc_transfer)
msg = (
f"Status code {response.status_code} from {self._api_url}/is_alive. "
f"Message: {response.json().get('detail')}"
)
return IsAliveResponse(is_alive=False, message=msg)
except requests.RequestException:
session = await self._ensure_session()
timeout_value = self._get_timeout(timeout)
async with session.get(
f"{self._api_url}/is_alive", headers=self._headers, timeout=aiohttp.ClientTimeout(total=timeout_value)
) as response:
if response.status == 200:
data = await response.json()
return IsAliveResponse(**data)
elif response.status == 511:
data = await response.json()
exc_transfer = _ExceptionTransfer(**data["swerexception"])
self._handle_transfer_exception(exc_transfer)

data = await response.json()
msg = f"Status code {response.status} from {self._api_url}/is_alive. Message: {data.get('detail')}"
return IsAliveResponse(is_alive=False, message=msg)
except aiohttp.ClientError:
msg = f"Failed to connect to {self._config.host}\n"
msg += traceback.format_exc()
return IsAliveResponse(is_alive=False, message=msg)
Expand All @@ -154,64 +167,117 @@ async def is_alive(self, *, timeout: float | None = None) -> IsAliveResponse:
async def wait_until_alive(self, *, timeout: float = 60.0):
    """Delegate to `_wait_until_alive`, waiting up to *timeout* seconds for
    `is_alive` to report success."""
    return await _wait_until_alive(self.is_alive, timeout=timeout)

def _request(self, endpoint: str, request: BaseModel | None, output_class: Any):
async def _request(self, endpoint: str, request: BaseModel | None, output_class: Any, num_retries: int = 10):
"""Small helper to make requests to the server and handle errors and output."""
request_url = f"{self._api_url}/{endpoint}"
return await self._request_with_retry(request_url, request, output_class, num_retries)

async def _request_with_retry(
    self,
    request_url: str,
    request: BaseModel | None,
    output_class: Any,
    num_retries: int,
):
    """POST *request* to *request_url* with retries and exponential backoff.

    Each logical request carries a single ``X-Request-ID`` header across all
    attempts so the server can treat retries idempotently (it replays the
    cached response for a request ID it already processed).

    Args:
        request_url: Full URL to POST to.
        request: Request body model (serialized with ``model_dump``), or None.
        output_class: Model class instantiated from the JSON response.
        num_retries: Maximum number of retries after the first attempt.

    Raises:
        The last exception encountered once all retries are exhausted.
    """
    request_id = str(uuid.uuid4())
    headers = self._headers.copy()
    headers["X-Request-ID"] = request_id  # idempotency key for the request

    retry_count = 0
    last_exception = None
    retry_delay = 0.1
    backoff_max = 5

    session = await self._ensure_session()
    while retry_count <= num_retries:
        try:
            async with session.post(
                request_url, json=request.model_dump() if request else None, headers=headers
            ) as response:
                await self._handle_response_errors(response)
                data = await response.json()
                return output_class(**data)
        except Exception as e:
            # NOTE(review): every exception is retried here, including swerex
            # exceptions re-raised from the server by _handle_response_errors.
            # The idempotency key makes the server replay the same response,
            # so genuine failures still surface — but only after paying the
            # full backoff delay. Consider excluding transferred exceptions.
            last_exception = e
            retry_count += 1
            if retry_count <= num_retries:
                # Exponential backoff with jitter, capped at backoff_max.
                await asyncio.sleep(retry_delay)
                retry_delay *= 2
                retry_delay += random.uniform(0, 0.5)
                retry_delay = min(retry_delay, backoff_max)
                continue
            self.logger.error("Error making request %s after %d retries: %s", request_id, num_retries, e)
            raise last_exception

async def create_session(self, request: CreateSessionRequest) -> CreateSessionResponse:
    """Creates a new session on the remote runtime."""
    return await self._request("create_session", request, CreateSessionResponse)

async def run_in_session(self, action: Action) -> Observation:
    """Runs a command in a session on the remote runtime."""
    return await self._request("run_in_session", action, Observation)

async def close_session(self, request: CloseSessionRequest) -> CloseSessionResponse:
    """Closes a shell session on the remote runtime."""
    return await self._request("close_session", request, CloseSessionResponse)

async def execute(self, command: Command) -> CommandResponse:
    """Executes a command on the remote runtime (independent of any shell session)."""
    return await self._request("execute", command, CommandResponse)

async def read_file(self, request: ReadFileRequest) -> ReadFileResponse:
    """Reads a file from the remote runtime."""
    return await self._request("read_file", request, ReadFileResponse)

async def write_file(self, request: WriteFileRequest) -> WriteFileResponse:
    """Writes a file on the remote runtime."""
    return await self._request("write_file", request, WriteFileResponse)

async def upload(self, request: UploadRequest) -> UploadResponse:
    """Uploads a file or directory to the remote runtime.

    Directories are zipped locally and uploaded with ``unzip=true`` so the
    server extracts them at the target path; single files are uploaded as-is.

    Args:
        request: Source path (local) and target path (remote).

    Raises:
        ValueError: If the source path is neither a file nor a directory.
    """
    source = Path(request.source_path).resolve()
    self.logger.debug("Uploading file from %s to %s", request.source_path, request.target_path)

    session = await self._ensure_session()

    if source.is_dir():
        # Ignore cleanup errors: See https://github.com/SWE-agent/SWE-agent/issues/1005
        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir:
            zip_path = Path(temp_dir) / "zipped_transfer.zip"
            shutil.make_archive(str(zip_path.with_suffix("")), "zip", source)
            self.logger.debug("Created zip file at %s", zip_path)

            form = aiohttp.FormData()
            # read_bytes() avoids leaking an open file handle (the previous
            # open(...) was never closed), and decouples the upload from the
            # temp dir's lifetime.
            form.add_field("file", zip_path.read_bytes(), filename=zip_path.name, content_type="application/zip")
            form.add_field("target_path", request.target_path)
            form.add_field("unzip", "true")

            async with session.post(f"{self._api_url}/upload", data=form, headers=self._headers) as response:
                await self._handle_response_errors(response)
                return UploadResponse(**(await response.json()))
    elif source.is_file():
        self.logger.debug("Uploading file from %s to %s", source, request.target_path)

        form = aiohttp.FormData()
        # read_bytes() instead of open(...): no dangling file handle.
        form.add_field("file", source.read_bytes(), filename=source.name)
        form.add_field("target_path", request.target_path)
        form.add_field("unzip", "false")

        async with session.post(f"{self._api_url}/upload", data=form, headers=self._headers) as response:
            await self._handle_response_errors(response)
            return UploadResponse(**(await response.json()))
    else:
        msg = f"Source path {source} is not a file or directory"
        raise ValueError(msg)

async def close(self) -> CloseResponse:
    """Closes the runtime and releases the HTTP client session.

    The aiohttp session is closed whether or not the close request itself
    succeeds. (The original duplicated the session-close logic in both the
    try body and the finally block; the finally block alone covers both
    paths.)
    """
    try:
        return await self._request("close", None, CloseResponse)
    finally:
        if self._session and not self._session.closed:
            await self._session.close()
54 changes: 54 additions & 0 deletions src/swerex/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.responses import Response

from swerex import __version__
from swerex.runtime.abstract import (
Expand All @@ -38,6 +39,31 @@ def serialize_model(model):
return model.model_dump() if hasattr(model, "model_dump") else model.dict()


class ResponseManager:
    """Caches the response of the most recently processed request so that a
    retried request (matched by its request ID) can be answered with the
    already-computed response.

    Only one entry is kept, so idempotency is not guaranteed when multiple
    clients issue requests concurrently.
    """

    def __init__(self):
        # Request ID of the last processed request (None until one is seen).
        self.last_processed_request_id = None
        # Response object stored for that request ID.
        self.last_processed_response = None

    def get_response(self, request_id):
        """Return the cached response for *request_id*, or None on a miss."""
        if request_id != self.last_processed_request_id:
            return None
        return self.last_processed_response

    def set_response(self, request_id, response):
        """Record *response* as the result of *request_id*, replacing any
        previously cached entry."""
        self.last_processed_request_id = request_id
        self.last_processed_response = response


response_manager = ResponseManager()


@app.middleware("http")
async def authenticate(request: Request, call_next):
"""Authenticate requests with an API key (if set)."""
Expand All @@ -48,6 +74,34 @@ async def authenticate(request: Request, call_next):
return await call_next(request)


@app.middleware("http")
async def handle_request_id(request: Request, call_next):
"""Handle request ID for idempotency."""
request_id = request.headers.get("X-Request-ID")
if request_id:
response = response_manager.get_response(request_id)
if response:
return response

response = await call_next(request)

async def body_stream():
async for chunk in response.body_iterator:
yield chunk

new_response = Response(
content=body_stream(),
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)

if request_id:
response_manager.set_response(request_id, new_response)

return new_response


@app.exception_handler(Exception)
async def exception_handler(request: Request, exc: Exception):
"""We catch exceptions that are thrown by the runtime, serialize them to JSON and
Expand Down
Loading