20 changes: 20 additions & 0 deletions servers/fai-lambda-deploy/scripts/fai-scribe-stack.ts
@@ -4,6 +4,7 @@ import * as apigateway from "aws-cdk-lib/aws-apigateway";
import { Certificate } from "aws-cdk-lib/aws-certificatemanager";
import * as ec2 from "aws-cdk-lib/aws-ec2";
import * as efs from "aws-cdk-lib/aws-efs";
import * as iam from "aws-cdk-lib/aws-iam";
import * as lambda from "aws-cdk-lib/aws-lambda";
import { LogGroup, RetentionDays } from "aws-cdk-lib/aws-logs";
import { ARecord, HostedZone, RecordTarget } from "aws-cdk-lib/aws-route53";
@@ -125,6 +126,25 @@ export class FaiScribeStack extends Stack {
filesystem: lambda.FileSystem.fromEfsAccessPoint(accessPoint, "/mnt/efs")
});

// Grant SQS permissions for editing session queues
// These queues are created dynamically by the FAI server per editing session
lambdaFunction.addToRolePolicy(
new iam.PolicyStatement({
effect: iam.Effect.ALLOW,
actions: [
"sqs:ReceiveMessage",
"sqs:DeleteMessage",
"sqs:GetQueueAttributes",
"sqs:GetQueueUrl",
"sqs:ChangeMessageVisibility"
],
resources: [
// Allow access to all editing-session queues in this region
`arn:aws:sqs:${this.region}:${this.account}:editing-session-*.fifo`
]
})
);

const apiName = `${lambdaName}-${environmentType.toLowerCase()}`;

const api = new apigateway.RestApi(this, `${lambdaName}-api`, {
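The policy above is scoped to queues named editing-session-*.fifo in the stack's account and region. As a rough illustration only (this code is not part of the PR, and the FAI server's actual queue-creation code is not shown here), a per-session FIFO queue matching that ARN pattern could be created with boto3 along these lines:

import boto3

def create_editing_session_queue(editing_id: str) -> str:
    # Hypothetical sketch: create a per-session FIFO queue whose name matches the
    # editing-session-*.fifo pattern that the policy above grants the Lambda access to.
    sqs = boto3.client("sqs")
    response = sqs.create_queue(
        QueueName=f"editing-session-{editing_id}.fifo",
        Attributes={
            "FifoQueue": "true",
            # Content-based deduplication so senders do not have to pass a MessageDeduplicationId
            "ContentBasedDeduplication": "true",
        },
    )
    return response["QueueUrl"]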
17 changes: 17 additions & 0 deletions servers/fai-lambda/fai-scribe/src/handler.py
@@ -75,13 +75,19 @@ async def handle_editing_request(
)
LOGGER.info(f"Repository ready at: {repo_path}")

# Ensure queue_url is present
queue_url = session.get("queue_url")
if not queue_url:
raise RuntimeError(f"No queue_url found for editing session: {editing_id}")

try:
session_id, pr_url = await run_editing_session(
repo_path=repo_path,
user_prompt=prompt,
base_branch=session["base_branch"],
working_branch=session["working_branch"],
editing_id=editing_id,
queue_url=queue_url,
resume_session_id=session.get("session_id"),
existing_pr_url=session.get("pr_url"),
)
@@ -101,6 +107,17 @@
else:
LOGGER.info(f"Successfully updated editing session: {editing_id}")

# Clean up the SQS queue after successful completion
LOGGER.info(f"Cleaning up SQS queue for session: {editing_id}")
cleanup_response = await client.delete(f"{SETTINGS.FAI_API_URL}/editing-sessions/{editing_id}/queue")
if cleanup_response.status_code == 200:
LOGGER.info(f"Successfully cleaned up queue for session: {editing_id}")
else:
LOGGER.warning(
f"Failed to cleanup queue for session {editing_id}: "
f"{cleanup_response.status_code} - {cleanup_response.text}"
)

return {
"editing_id": editing_id,
"session_id": session_id,
4 changes: 2 additions & 2 deletions servers/fai-lambda/fai-scribe/src/settings.py
@@ -1,5 +1,4 @@
import logging
import os
from typing import Any

from dotenv import load_dotenv
@@ -10,7 +9,8 @@


class Settings:
FAI_API_URL: str = os.environ.get("FAI_API_URL") or "https://fai.buildwithfern.com"
# FAI_API_URL: str = os.environ.get("FAI_API_URL") or "https://fai.buildwithfern.com"
FAI_API_URL: str = "https://wen-metagnathous-deridingly.ngrok-free.app"


class SingletonFactory:
33 changes: 17 additions & 16 deletions servers/fai-lambda/fai-scribe/src/utils/agent.py
@@ -16,14 +16,13 @@
LOGGER,
SETTINGS,
)
from .sqs_client import SQSClient
from .system_prompts import EDITING_SYSTEM_PROMPT

GITHUB_PR_URL_PATTERN = re.compile(r"(?:https?://)?(?:www\.)?github\.com/([^/]+)/([^/]+)/pull/(\d+)", re.IGNORECASE)


class SessionInterruptedError(Exception):
"""Raised when an editing session is interrupted."""

pass


@@ -45,7 +44,6 @@ async def update_session_status(editing_id: str, status: str) -> bool:


async def update_session_metadata(editing_id: str, session_id: str | None = None, pr_url: str | None = None) -> bool:
"""Update session metadata (session_id and/or pr_url) immediately when available."""
try:
payload = {}
if session_id is not None:
@@ -69,19 +67,15 @@ async def update_session_metadata(editing_id: str, session_id: str | None = None
return False


async def check_if_interrupted(editing_id: str) -> bool:
def check_if_interrupted_via_sqs(sqs_client: SQSClient) -> tuple[bool, str | None]:
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(f"{SETTINGS.FAI_API_URL}/editing-sessions/{editing_id}")
if response.status_code == 200:
session_data = response.json()
return session_data["editing_session"]["status"] == "interrupted"
else:
LOGGER.warning(f"Failed to check interrupted status: {response.status_code}")
return False
is_interrupted, receipt_handle = sqs_client.has_interrupt_message()
if is_interrupted:
LOGGER.info("Received INTERRUPT message from SQS queue")
return is_interrupted, receipt_handle
except Exception as e:
LOGGER.warning(f"Error checking interrupted status: {e}")
return False
LOGGER.warning(f"Error checking SQS for interruption: {e}")
return False, None


async def run_editing_session(
@@ -90,6 +84,7 @@
base_branch: str,
working_branch: str,
editing_id: str,
queue_url: str,
resume_session_id: str | None = None,
existing_pr_url: str | None = None,
) -> tuple[str, str | None]:
@@ -134,6 +129,9 @@
session_id = resume_session_id
pr_url = existing_pr_url

sqs_client = SQSClient(queue_url)
LOGGER.info(f"Initialized SQS client for queue: {queue_url}")

async with ClaudeSDKClient(options=options) as client:
await client.query(full_prompt)

@@ -152,8 +150,11 @@
await update_session_status(editing_id, "active")
LOGGER.info(f"Transitioned new session {editing_id} from STARTUP to ACTIVE")

if await check_if_interrupted(editing_id):
LOGGER.warning(f"Session interrupted: {editing_id}")
is_interrupted, receipt_handle = check_if_interrupted_via_sqs(sqs_client)
if is_interrupted:
LOGGER.warning(f"Session interrupted via SQS: {editing_id}")
if receipt_handle:
sqs_client.delete_message(receipt_handle)
if session_id is not None:
await update_session_status(editing_id, "waiting")
raise SessionInterruptedError(f"Editing session {editing_id} was interrupted")
92 changes: 92 additions & 0 deletions servers/fai-lambda/fai-scribe/src/utils/sqs_client.py
@@ -0,0 +1,92 @@
import json
import logging
from typing import Any

import boto3
from botocore.exceptions import ClientError

from ..settings import LOGGER


class SQSClient:
def __init__(self, queue_url: str):
self.queue_url = queue_url
self.sqs_client = boto3.client("sqs")
self._last_checked_count = 0

def receive_messages(self, max_messages: int = 10) -> list[dict[str, Any]]:
try:
response = self.sqs_client.receive_message(
QueueUrl=self.queue_url,
MaxNumberOfMessages=min(max_messages, 10),
WaitTimeSeconds=0, # Short polling for Lambda
AttributeNames=["All"],
)

messages = response.get("Messages", [])
parsed_messages = []

for msg in messages:
try:
body = json.loads(msg["Body"])
parsed_messages.append(
{
"body": body,
"receipt_handle": msg["ReceiptHandle"],
"message_id": msg["MessageId"],
}
)
except json.JSONDecodeError:
LOGGER.warning(f"Failed to parse message body as JSON: {msg['Body']}")
continue

if parsed_messages:
self._last_checked_count += len(parsed_messages)
LOGGER.info(f"Received {len(parsed_messages)} messages from queue")

return parsed_messages

except ClientError as e:
LOGGER.error(f"Failed to receive messages from queue: {e.response['Error']['Message']}")
return []
except Exception as e:
LOGGER.error(f"Unexpected error receiving messages: {str(e)}", exc_info=True)
return []

def delete_message(self, receipt_handle: str) -> bool:
try:
self.sqs_client.delete_message(QueueUrl=self.queue_url, ReceiptHandle=receipt_handle)
LOGGER.debug("Deleted message from queue")
return True

except ClientError as e:
LOGGER.error(f"Failed to delete message: {e.response['Error']['Message']}")
return False
except Exception as e:
LOGGER.error(f"Unexpected error deleting message: {str(e)}", exc_info=True)
return False

def has_interrupt_message(self) -> tuple[bool, str | None]:
messages = self.receive_messages(max_messages=10)

for msg in messages:
body = msg["body"]
if body.get("type") == "INTERRUPT":
LOGGER.info(f"Found INTERRUPT message in queue")
return True, msg["receipt_handle"]

return False, None
Comment on lines 69 to 80

@vercel vercel bot Oct 31, 2025


The has_interrupt_message() method fetches messages but doesn't acknowledge non-INTERRUPT messages, causing queue message accumulation and duplicates on retry.

📝 Patch Details
diff --git a/servers/fai-lambda/fai-scribe/src/utils/sqs_client.py b/servers/fai-lambda/fai-scribe/src/utils/sqs_client.py
index 7df5e881d..f5703da45 100644
--- a/servers/fai-lambda/fai-scribe/src/utils/sqs_client.py
+++ b/servers/fai-lambda/fai-scribe/src/utils/sqs_client.py
@@ -68,14 +68,21 @@ class SQSClient:
 
     def has_interrupt_message(self) -> tuple[bool, str | None]:
         messages = self.receive_messages(max_messages=10)
+        interrupt_receipt = None
 
         for msg in messages:
             body = msg["body"]
+            receipt_handle = msg["receipt_handle"]
+            
             if body.get("type") == "INTERRUPT":
                 LOGGER.info(f"Found INTERRUPT message in queue")
-                return True, msg["receipt_handle"]
+                interrupt_receipt = receipt_handle
+            else:
+                # Delete non-INTERRUPT messages to prevent queue accumulation
+                self.delete_message(receipt_handle)
+                LOGGER.debug(f"Deleted processed {body.get('type', 'UNKNOWN')} message")
 
-        return False, None
+        return interrupt_receipt is not None, interrupt_receipt
 
     def get_resume_messages(self) -> list[dict[str, Any]]:
         messages = self.receive_messages(max_messages=10)
@@ -83,8 +90,14 @@ class SQSClient:
 
         for msg in messages:
             body = msg["body"]
+            receipt_handle = msg["receipt_handle"]
+            
             if body.get("type") == "RESUME":
-                resume_messages.append({"body": body, "receipt_handle": msg["receipt_handle"]})
+                resume_messages.append({"body": body, "receipt_handle": receipt_handle})
+            else:
+                # Delete non-RESUME messages to prevent queue accumulation
+                self.delete_message(receipt_handle)
+                LOGGER.debug(f"Deleted processed {body.get('type', 'UNKNOWN')} message")
 
         if resume_messages:
             LOGGER.info(f"Found {len(resume_messages)} RESUME messages in queue")

Analysis

SQS message accumulation in has_interrupt_message() and get_resume_messages()

What fails: SQSClient.has_interrupt_message() and SQSClient.get_resume_messages() fetch messages but only delete specific message types (INTERRUPT/RESUME), leaving other messages in the queue to accumulate

How to reproduce:

# Simulate mixed message types (INTERRUPT, RESUME, other) already sitting in the SQS queue,
# then call has_interrupt_message() repeatedly
client = SQSClient(queue_url)  # queue_url points at the per-session queue
client.has_interrupt_message()  # fetches all messages; only the INTERRUPT receipt is ever deleted (by the caller)
client.has_interrupt_message()  # re-fetches the same non-INTERRUPT messages once their visibility timeout expires

Result: Non-INTERRUPT messages (RESUME, OTHER types) remain in queue after each poll and get re-processed on every subsequent call, causing message accumulation and duplicate processing

Expected: All processed messages should be deleted from queue to prevent redelivery after SQS visibility timeout expires

Root cause: Methods call receive_messages() but only delete messages of specific types, violating SQS best practices for message cleanup
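For illustration (not part of the patch above), the pattern the fix moves toward can be sketched with boto3 directly: delete every message you receive, whatever its type, so nothing is redelivered once the visibility timeout expires. The function name and structure below are placeholders; only the message "type" convention (INTERRUPT/RESUME) comes from this PR.

import json
import boto3

def drain_queue_for_interrupt(queue_url: str) -> bool:
    # Illustrative sketch: poll once, delete everything received, report whether an INTERRUPT was seen.
    sqs = boto3.client("sqs")
    response = sqs.receive_message(
        QueueUrl=queue_url,
        MaxNumberOfMessages=10,
        WaitTimeSeconds=0,
    )
    interrupted = False
    for msg in response.get("Messages", []):
        try:
            body = json.loads(msg["Body"])
        except json.JSONDecodeError:
            body = {}
        if body.get("type") == "INTERRUPT":
            interrupted = True
        # Delete regardless of type so non-INTERRUPT messages do not accumulate
        sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=msg["ReceiptHandle"])
    return interrupted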


def get_resume_messages(self) -> list[dict[str, Any]]:
messages = self.receive_messages(max_messages=10)
resume_messages = []

for msg in messages:
body = msg["body"]
if body.get("type") == "RESUME":
resume_messages.append({"body": body, "receipt_handle": msg["receipt_handle"]})

if resume_messages:
LOGGER.info(f"Found {len(resume_messages)} RESUME messages in queue")

return resume_messages
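For reference, a rough usage sketch of the SQSClient added above, mirroring how agent.py wires it in. The helper name below is hypothetical and not part of this PR; the queue URL comes from the editing session record (see handler.py), and the import path matches the one used in agent.py.

from .sqs_client import SQSClient  # import path as used in agent.py

def poll_for_interrupt(queue_url: str) -> bool:
    # Hypothetical helper showing the intended call pattern, not part of this diff.
    sqs_client = SQSClient(queue_url)
    is_interrupted, receipt_handle = sqs_client.has_interrupt_message()
    if is_interrupted and receipt_handle:
        # Acknowledge the INTERRUPT so it is not redelivered, as agent.py does
        sqs_client.delete_message(receipt_handle)
    return is_interrupted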