18 changes: 17 additions & 1 deletion servers/fai-lambda-deploy/scripts/fai-scribe-stack.ts
@@ -1,9 +1,10 @@
import { type EnvironmentInfo, EnvironmentType } from "@fern-fern/fern-cloud-sdk/api";
import { CfnOutput, Duration, RemovalPolicy, Stack, type StackProps } from "aws-cdk-lib";
import { CfnOutput, Duration, RemovalPolicy, Size, Stack, type StackProps } from "aws-cdk-lib";
import * as apigateway from "aws-cdk-lib/aws-apigateway";
import { Certificate } from "aws-cdk-lib/aws-certificatemanager";
import * as ec2 from "aws-cdk-lib/aws-ec2";
import * as efs from "aws-cdk-lib/aws-efs";
import * as iam from "aws-cdk-lib/aws-iam";
import * as lambda from "aws-cdk-lib/aws-lambda";
import { LogGroup, RetentionDays } from "aws-cdk-lib/aws-logs";
import { ARecord, HostedZone, RecordTarget } from "aws-cdk-lib/aws-route53";
@@ -109,6 +110,7 @@ export class FaiScribeStack extends Stack {
}),
timeout: Duration.minutes(15),
memorySize: 512,
ephemeralStorageSize: Size.mebibytes(2048),
logGroup,
vpc: this.vpc,
vpcSubnets: {
@@ -125,6 +127,20 @@
filesystem: lambda.FileSystem.fromEfsAccessPoint(accessPoint, "/mnt/efs")
});

lambdaFunction.addToRolePolicy(
new iam.PolicyStatement({
effect: iam.Effect.ALLOW,
actions: [
"sqs:ReceiveMessage",
"sqs:DeleteMessage",
"sqs:GetQueueAttributes",
"sqs:GetQueueUrl",
"sqs:ChangeMessageVisibility"
],
resources: [`arn:aws:sqs:${this.region}:${this.account}:editing-session-*.fifo`]
})
);

const apiName = `${lambdaName}-${environmentType.toLowerCase()}`;

const api = new apigateway.RestApi(this, `${lambdaName}-api`, {
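Note that the policy above grants the function only the consume side of the SQS contract (`ReceiveMessage`, `DeleteMessage`, and friends) on queues matching `editing-session-*.fifo`; whatever enqueues signals for a session must run elsewhere with its own credentials. A minimal producer-side sketch, assuming the `{"type": "INTERRUPT"}` envelope the agent code below expects; the `send_interrupt` helper and the deduplication strategy are illustrative:

```python
import json
import uuid

import boto3

sqs = boto3.client("sqs")


def send_interrupt(editing_id: str) -> None:
    # Per-session FIFO queues follow the editing-session-*.fifo pattern
    # that the stack's IAM policy scopes the Lambda's permissions to.
    queue_name = f"editing-session-{editing_id}.fifo"
    queue_url = sqs.get_queue_url(QueueName=queue_name)["QueueUrl"]
    sqs.send_message(
        QueueUrl=queue_url,
        MessageBody=json.dumps({"type": "INTERRUPT"}),
        MessageGroupId=editing_id,  # FIFO queues require a message group id
        # Explicit dedup id in case content-based deduplication is disabled.
        MessageDeduplicationId=str(uuid.uuid4()),
    )
```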
28 changes: 28 additions & 0 deletions servers/fai-lambda/fai-scribe/src/handler.py
@@ -1,9 +1,11 @@
import asyncio
import json
import shutil
from datetime import (
UTC,
datetime,
)
from pathlib import Path
from typing import Any

import httpx
@@ -75,13 +77,21 @@ async def handle_editing_request(
)
LOGGER.info(f"Repository ready at: {repo_path}")

queue_url = session.get("queue_url")
if not queue_url:
raise RuntimeError(
f"No queue_url found for editing session {editing_id}. "
"Session is in an invalid state; the queue should have been created during initialization."
)

try:
session_id, pr_url = await run_editing_session(
repo_path=repo_path,
user_prompt=prompt,
base_branch=session["base_branch"],
working_branch=session["working_branch"],
editing_id=editing_id,
queue_url=queue_url,
resume_session_id=session.get("session_id"),
existing_pr_url=session.get("pr_url"),
)
@@ -101,6 +111,24 @@
else:
LOGGER.info(f"Successfully updated editing session: {editing_id}")

LOGGER.info(f"Cleaning up SQS queue for session: {editing_id}")
cleanup_response = await client.delete(f"{SETTINGS.FAI_API_URL}/editing-sessions/{editing_id}/queue")
if cleanup_response.status_code == 200:
LOGGER.info(f"Successfully cleaned up queue for session: {editing_id}")
else:
LOGGER.warning(
f"Failed to clean up queue for session {editing_id}: "
f"{cleanup_response.status_code} - {cleanup_response.text}"
)

try:
session_dir = Path(repo_path).parent
if session_dir.exists() and session_dir.name.startswith("editing-"):
shutil.rmtree(session_dir)
LOGGER.info(f"Cleaned up session directory: {session_dir}")
except Exception as cleanup_error:
LOGGER.warning(f"Failed to clean up session directory: {cleanup_error}")

return {
"editing_id": editing_id,
"session_id": session_id,
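For reference, the session record this handler consumes can be summarized from the fields it touches. A sketch only: the field names come from this diff, optionality is inferred from `.get()` versus `[]` access, and the real schema lives in the FAI API and may carry more fields:

```python
from typing import NotRequired, TypedDict


class EditingSession(TypedDict):
    queue_url: str                 # handler raises RuntimeError when missing
    base_branch: str               # read with [], so assumed always present
    working_branch: str            # read with [], so assumed always present
    session_id: NotRequired[str]   # optional: resumes a prior Claude session
    pr_url: NotRequired[str]       # optional: an existing PR to keep updating
```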
80 changes: 62 additions & 18 deletions servers/fai-lambda/fai-scribe/src/utils/agent.py
@@ -16,14 +16,13 @@
LOGGER,
SETTINGS,
)
from .sqs_client import SQSClient
from .system_prompts import EDITING_SYSTEM_PROMPT

GITHUB_PR_URL_PATTERN = re.compile(r"(?:https?://)?(?:www\.)?github\.com/([^/]+)/([^/]+)/pull/(\d+)", re.IGNORECASE)


class SessionInterruptedError(Exception):
"""Raised when an editing session is interrupted."""

pass


@@ -45,7 +44,6 @@ async def update_session_status(editing_id: str, status: str) -> bool:


async def update_session_metadata(editing_id: str, session_id: str | None = None, pr_url: str | None = None) -> bool:
"""Update session metadata (session_id and/or pr_url) immediately when available."""
try:
payload = {}
if session_id is not None:
@@ -69,19 +67,15 @@ async def update_session_metadata(editing_id: str, session_id: str | None = None
return False


async def check_if_interrupted(editing_id: str) -> bool:
def check_if_interrupted_via_sqs(sqs_client: SQSClient) -> tuple[bool, str | None]:
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(f"{SETTINGS.FAI_API_URL}/editing-sessions/{editing_id}")
if response.status_code == 200:
session_data = response.json()
return session_data["editing_session"]["status"] == "interrupted"
else:
LOGGER.warning(f"Failed to check interrupted status: {response.status_code}")
return False
is_interrupted, receipt_handle = sqs_client.has_interrupt_message()
if is_interrupted:
LOGGER.info("Received INTERRUPT message from SQS queue")
return is_interrupted, receipt_handle
except Exception as e:
LOGGER.warning(f"Error checking interrupted status: {e}")
return False
LOGGER.warning(f"Error checking SQS for interruption: {e}")
return False, None


async def run_editing_session(
@@ -90,6 +84,7 @@
base_branch: str,
working_branch: str,
editing_id: str,
queue_url: str,
resume_session_id: str | None = None,
existing_pr_url: str | None = None,
) -> tuple[str, str | None]:
@@ -134,6 +129,9 @@
session_id = resume_session_id
pr_url = existing_pr_url

sqs_client = SQSClient(queue_url)
LOGGER.info(f"Initialized SQS client for queue: {queue_url}")

async with ClaudeSDKClient(options=options) as client:
await client.query(full_prompt)

@@ -152,11 +150,14 @@
await update_session_status(editing_id, "active")
LOGGER.info(f"Transitioned new session {editing_id} from STARTUP to ACTIVE")

if await check_if_interrupted(editing_id):
LOGGER.warning(f"Session interrupted: {editing_id}")
if session_id is not None:
if session_id is not None:
is_interrupted, receipt_handle = check_if_interrupted_via_sqs(sqs_client)
if is_interrupted:
LOGGER.warning(f"Session interrupted via SQS: {editing_id}")
if receipt_handle:
sqs_client.delete_message(receipt_handle)
await update_session_status(editing_id, "waiting")
raise SessionInterruptedError(f"Editing session {editing_id} was interrupted")
raise SessionInterruptedError(f"Editing session {editing_id} was interrupted")

if isinstance(message, AssistantMessage):
for block in message.content:
@@ -185,6 +186,49 @@
if session_id is None:
raise RuntimeError("Failed to obtain session ID from Claude")

LOGGER.info(f"Checking for RESUME messages in queue for session {editing_id}")
resume_messages = sqs_client.get_resume_messages()

if resume_messages:
LOGGER.info(f"Found {len(resume_messages)} RESUME messages, batching prompts")

all_prompts = []
for resume_msg in resume_messages:
body = resume_msg["body"]
prompts = body.get("prompts", [])
if prompts:
all_prompts.extend(prompts)
sqs_client.delete_message(resume_msg["receipt_handle"])

if all_prompts:
batched_prompt = "\n\n".join([f"Request {i+1}: {p}" for i, p in enumerate(all_prompts)])
LOGGER.info(f"Processing {len(all_prompts)} batched prompts: {batched_prompt[:200]}...")

async with ClaudeSDKClient(options=options) as client:
await client.query(batched_prompt)

async for msg in client.receive_response():
is_interrupted, receipt_handle = check_if_interrupted_via_sqs(sqs_client)
if is_interrupted:
LOGGER.warning(f"Session interrupted while processing batched prompts: {editing_id}")
if receipt_handle:
sqs_client.delete_message(receipt_handle)
await update_session_status(editing_id, "waiting")
raise SessionInterruptedError(f"Editing session {editing_id} was interrupted")

if isinstance(msg, AssistantMessage):
for block in msg.content:
if isinstance(block, TextBlock):
for line in block.text.split("\n"):
if "PR_URL:" in line:
extracted_text = line.split("PR_URL:", 1)[1].strip()
if extracted_text:
match = GITHUB_PR_URL_PATTERN.search(extracted_text)
if match:
pr_url = match.group(0)
LOGGER.info(f"Updated PR URL: {pr_url}")
await update_session_metadata(editing_id, pr_url=pr_url)

await update_session_status(editing_id, "waiting")

return session_id, pr_url
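The RESUME path above flattens every queued `{"type": "RESUME", "prompts": [...]}` body into one numbered prompt before re-querying Claude. A self-contained sketch of just that transformation, mirroring the join in this diff (the helper name is ours):

```python
def batch_resume_prompts(bodies: list[dict]) -> str:
    # Flatten the prompts from every RESUME body, preserving queue order.
    all_prompts = [p for body in bodies for p in body.get("prompts", [])]
    return "\n\n".join(f"Request {i + 1}: {p}" for i, p in enumerate(all_prompts))


bodies = [
    {"type": "RESUME", "prompts": ["Fix the typo in README.md"]},
    {"type": "RESUME", "prompts": ["Bump the SDK version", "Regenerate the docs"]},
]
print(batch_resume_prompts(bodies))
# Request 1: Fix the typo in README.md
#
# Request 2: Bump the SDK version
#
# Request 3: Regenerate the docs
```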
104 changes: 104 additions & 0 deletions servers/fai-lambda/fai-scribe/src/utils/sqs_client.py
@@ -0,0 +1,104 @@
import json
from typing import Any

import boto3
from botocore.exceptions import ClientError

from ..settings import LOGGER


class SQSClient:
def __init__(self, queue_url: str):
self.queue_url = queue_url
self.sqs_client = boto3.client("sqs")
self._last_checked_count = 0
self._cached_messages: list[dict[str, Any]] = []

def receive_messages(self, max_messages: int = 10) -> list[dict[str, Any]]:
try:
response = self.sqs_client.receive_message(
QueueUrl=self.queue_url,
MaxNumberOfMessages=min(max_messages, 10),
WaitTimeSeconds=0, # Short polling for Lambda
AttributeNames=["All"],
)

messages = response.get("Messages", [])
parsed_messages = []

for msg in messages:
try:
body = json.loads(msg["Body"])
parsed_messages.append(
{
"body": body,
"receipt_handle": msg["ReceiptHandle"],
"message_id": msg["MessageId"],
}
)
except json.JSONDecodeError:
LOGGER.warning(f"Failed to parse message body as JSON: {msg['Body']}")
continue

if parsed_messages:
self._last_checked_count += len(parsed_messages)
LOGGER.info(f"Received {len(parsed_messages)} messages from queue")

return parsed_messages

except ClientError as e:
LOGGER.error(f"Failed to receive messages from queue: {e.response['Error']['Message']}")
return []
except Exception as e:
LOGGER.error(f"Unexpected error receiving messages: {str(e)}", exc_info=True)
return []

def delete_message(self, receipt_handle: str) -> bool:
try:
self.sqs_client.delete_message(QueueUrl=self.queue_url, ReceiptHandle=receipt_handle)
LOGGER.debug("Deleted message from queue")
return True

except ClientError as e:
LOGGER.error(f"Failed to delete message: {e.response['Error']['Message']}")
return False
except Exception as e:
LOGGER.error(f"Unexpected error deleting message: {str(e)}", exc_info=True)
return False

def has_interrupt_message(self) -> tuple[bool, str | None]:
messages = self.receive_messages(max_messages=10)

for msg in messages:
body = msg["body"]
if body.get("type") == "INTERRUPT":
LOGGER.info("Found INTERRUPT message in queue")
return True, msg["receipt_handle"]
else:
self._cached_messages.append(msg)

return False, None

def get_resume_messages(self) -> list[dict[str, Any]]:
resume_messages = []

for msg in self._cached_messages:
body = msg["body"]
if body.get("type") == "RESUME":
resume_messages.append({"body": body, "receipt_handle": msg["receipt_handle"]})

messages = self.receive_messages(max_messages=10)
for msg in messages:
body = msg["body"]
if body.get("type") == "RESUME":
resume_messages.append({"body": body, "receipt_handle": msg["receipt_handle"]})

if resume_messages:
num_resume = len(resume_messages)
num_cached = len([m for m in self._cached_messages if m["body"].get("type") == "RESUME"])
num_new = len([m for m in messages if m["body"].get("type") == "RESUME"])
LOGGER.info(f"Found {num_resume} RESUME messages ({num_cached} cached, {num_new} new)")

self._cached_messages.clear()

return resume_messages
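Taken together, a minimal consumer loop looks like the sketch below (the import path and queue URL are illustrative). The caching design matters here: a received SQS message stays invisible to later polls until its visibility timeout expires, so `has_interrupt_message` parks any non-INTERRUPT messages it happened to pull in `_cached_messages` rather than losing them, and `get_resume_messages` drains that cache before polling the queue again:

```python
from sqs_client import SQSClient  # illustrative import path

client = SQSClient("https://sqs.us-east-1.amazonaws.com/123456789012/editing-session-demo.fifo")

interrupted, receipt_handle = client.has_interrupt_message()
if interrupted and receipt_handle:
    client.delete_message(receipt_handle)  # acknowledge the INTERRUPT
else:
    for msg in client.get_resume_messages():
        print(msg["body"].get("prompts", []))
        client.delete_message(msg["receipt_handle"])  # ack after handling
```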