
Commit 608a4d4

feat(scribe): Implement SQS to eliminate state machine polling. (#4639)
1 parent e583e9d commit 608a4d4

14 files changed: +733 −202 lines changed

servers/fai-lambda-deploy/scripts/fai-scribe-stack.ts

Lines changed: 17 additions & 1 deletion

@@ -1,9 +1,10 @@
 import { type EnvironmentInfo, EnvironmentType } from "@fern-fern/fern-cloud-sdk/api";
-import { CfnOutput, Duration, RemovalPolicy, Stack, type StackProps } from "aws-cdk-lib";
+import { CfnOutput, Duration, RemovalPolicy, Size, Stack, type StackProps } from "aws-cdk-lib";
 import * as apigateway from "aws-cdk-lib/aws-apigateway";
 import { Certificate } from "aws-cdk-lib/aws-certificatemanager";
 import * as ec2 from "aws-cdk-lib/aws-ec2";
 import * as efs from "aws-cdk-lib/aws-efs";
+import * as iam from "aws-cdk-lib/aws-iam";
 import * as lambda from "aws-cdk-lib/aws-lambda";
 import { LogGroup, RetentionDays } from "aws-cdk-lib/aws-logs";
 import { ARecord, HostedZone, RecordTarget } from "aws-cdk-lib/aws-route53";
@@ -109,6 +110,7 @@ export class FaiScribeStack extends Stack {
             }),
             timeout: Duration.minutes(15),
             memorySize: 512,
+            ephemeralStorageSize: Size.mebibytes(2048),
             logGroup,
             vpc: this.vpc,
             vpcSubnets: {
@@ -125,6 +127,20 @@
             filesystem: lambda.FileSystem.fromEfsAccessPoint(accessPoint, "/mnt/efs")
         });
 
+        lambdaFunction.addToRolePolicy(
+            new iam.PolicyStatement({
+                effect: iam.Effect.ALLOW,
+                actions: [
+                    "sqs:ReceiveMessage",
+                    "sqs:DeleteMessage",
+                    "sqs:GetQueueAttributes",
+                    "sqs:GetQueueUrl",
+                    "sqs:ChangeMessageVisibility"
+                ],
+                resources: [`arn:aws:sqs:${this.region}:${this.account}:editing-session-*.fifo`]
+            })
+        );
+
         const apiName = `${lambdaName}-${environmentType.toLowerCase()}`;
 
         const api = new apigateway.RestApi(this, `${lambdaName}-api`, {
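
The policy above grants only consume-side permissions on queues matching editing-session-*.fifo; queue creation happens elsewhere (the handler below treats a missing queue_url as an invalid session state, so creation presumably belongs to the FAI API during session initialization). A minimal sketch of what that creation side might look like, assuming the naming convention implied by the ARN pattern; the function name and retention period are illustrative, not part of this commit:

# Hypothetical queue-creation sketch (not in this commit). Assumes the
# "editing-session-<id>.fifo" naming implied by the IAM resource pattern above.
import boto3

def create_session_queue(editing_id: str) -> str:
    sqs = boto3.client("sqs")
    response = sqs.create_queue(
        QueueName=f"editing-session-{editing_id}.fifo",
        Attributes={
            "FifoQueue": "true",                  # required for names ending in .fifo
            "ContentBasedDeduplication": "true",  # dedupe by body hash instead of explicit ids
            "MessageRetentionPeriod": "3600",     # assumption: one hour outlives any session
        },
    )
    return response["QueueUrl"]  # persisted on the session as queue_url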

servers/fai-lambda/fai-scribe/src/handler.py

Lines changed: 28 additions & 0 deletions

@@ -1,9 +1,11 @@
 import asyncio
 import json
+import shutil
 from datetime import (
     UTC,
     datetime,
 )
+from pathlib import Path
 from typing import Any
 
 import httpx
@@ -75,13 +77,21 @@ async def handle_editing_request(
     )
     LOGGER.info(f"Repository ready at: {repo_path}")
 
+    queue_url = session.get("queue_url")
+    if not queue_url:
+        raise RuntimeError(
+            f"No queue_url found for editing session {editing_id}. "
+            f"Session is in invalid state - queue should have been created during initialization."
+        )
+
     try:
         session_id, pr_url = await run_editing_session(
             repo_path=repo_path,
             user_prompt=prompt,
             base_branch=session["base_branch"],
             working_branch=session["working_branch"],
             editing_id=editing_id,
+            queue_url=queue_url,
             resume_session_id=session.get("session_id"),
             existing_pr_url=session.get("pr_url"),
         )
@@ -101,6 +111,24 @@ async def handle_editing_request(
     else:
         LOGGER.info(f"Successfully updated editing session: {editing_id}")
 
+    LOGGER.info(f"Cleaning up SQS queue for session: {editing_id}")
+    cleanup_response = await client.delete(f"{SETTINGS.FAI_API_URL}/editing-sessions/{editing_id}/queue")
+    if cleanup_response.status_code == 200:
+        LOGGER.info(f"Successfully cleaned up queue for session: {editing_id}")
+    else:
+        LOGGER.warning(
+            f"Failed to cleanup queue for session {editing_id}: "
+            f"{cleanup_response.status_code} - {cleanup_response.text}"
+        )
+
+    try:
+        session_dir = Path(repo_path).parent
+        if session_dir.exists() and session_dir.name.startswith("editing-"):
+            shutil.rmtree(session_dir)
+            LOGGER.info(f"Cleaned up session directory: {session_dir}")
+    except Exception as cleanup_error:
+        LOGGER.warning(f"Failed to cleanup session directory: {cleanup_error}")
+
     return {
         "editing_id": editing_id,
         "session_id": session_id,

servers/fai-lambda/fai-scribe/src/utils/agent.py

Lines changed: 62 additions & 18 deletions

@@ -16,14 +16,13 @@
     LOGGER,
     SETTINGS,
 )
+from .sqs_client import SQSClient
 from .system_prompts import EDITING_SYSTEM_PROMPT
 
 GITHUB_PR_URL_PATTERN = re.compile(r"(?:https?://)?(?:www\.)?github\.com/([^/]+)/([^/]+)/pull/(\d+)", re.IGNORECASE)
 
 
 class SessionInterruptedError(Exception):
-    """Raised when an editing session is interrupted."""
-
     pass
 
 
@@ -45,7 +44,6 @@ async def update_session_status(editing_id: str, status: str) -> bool:
 
 
 async def update_session_metadata(editing_id: str, session_id: str | None = None, pr_url: str | None = None) -> bool:
-    """Update session metadata (session_id and/or pr_url) immediately when available."""
     try:
         payload = {}
         if session_id is not None:
@@ -69,19 +67,15 @@ async def update_session_metadata(editing_id: str, session_id: str | None = None
         return False
 
 
-async def check_if_interrupted(editing_id: str) -> bool:
+def check_if_interrupted_via_sqs(sqs_client: SQSClient) -> tuple[bool, str | None]:
     try:
-        async with httpx.AsyncClient(timeout=5.0) as client:
-            response = await client.get(f"{SETTINGS.FAI_API_URL}/editing-sessions/{editing_id}")
-            if response.status_code == 200:
-                session_data = response.json()
-                return session_data["editing_session"]["status"] == "interrupted"
-            else:
-                LOGGER.warning(f"Failed to check interrupted status: {response.status_code}")
-                return False
+        is_interrupted, receipt_handle = sqs_client.has_interrupt_message()
+        if is_interrupted:
+            LOGGER.info("Received INTERRUPT message from SQS queue")
+        return is_interrupted, receipt_handle
     except Exception as e:
-        LOGGER.warning(f"Error checking interrupted status: {e}")
-        return False
+        LOGGER.warning(f"Error checking SQS for interruption: {e}")
+        return False, None
 
 
 async def run_editing_session(
@@ -90,6 +84,7 @@ async def run_editing_session(
     base_branch: str,
     working_branch: str,
     editing_id: str,
+    queue_url: str,
     resume_session_id: str | None = None,
     existing_pr_url: str | None = None,
 ) -> tuple[str, str | None]:
@@ -134,6 +129,9 @@
     session_id = resume_session_id
     pr_url = existing_pr_url
 
+    sqs_client = SQSClient(queue_url)
+    LOGGER.info(f"Initialized SQS client for queue: {queue_url}")
+
     async with ClaudeSDKClient(options=options) as client:
         await client.query(full_prompt)
 
@@ -152,11 +150,14 @@
                 await update_session_status(editing_id, "active")
                 LOGGER.info(f"Transitioned new session {editing_id} from STARTUP to ACTIVE")
 
-            if await check_if_interrupted(editing_id):
-                LOGGER.warning(f"Session interrupted: {editing_id}")
-                if session_id is not None:
+            if session_id is not None:
+                is_interrupted, receipt_handle = check_if_interrupted_via_sqs(sqs_client)
+                if is_interrupted:
+                    LOGGER.warning(f"Session interrupted via SQS: {editing_id}")
+                    if receipt_handle:
+                        sqs_client.delete_message(receipt_handle)
                     await update_session_status(editing_id, "waiting")
-                raise SessionInterruptedError(f"Editing session {editing_id} was interrupted")
+                    raise SessionInterruptedError(f"Editing session {editing_id} was interrupted")
 
             if isinstance(message, AssistantMessage):
                 for block in message.content:
@@ -185,6 +186,49 @@
     if session_id is None:
         raise RuntimeError("Failed to obtain session ID from Claude")
 
+    LOGGER.info(f"Checking for RESUME messages in queue for session {editing_id}")
+    resume_messages = sqs_client.get_resume_messages()
+
+    if resume_messages:
+        LOGGER.info(f"Found {len(resume_messages)} RESUME messages, batching prompts")
+
+        all_prompts = []
+        for resume_msg in resume_messages:
+            body = resume_msg["body"]
+            prompts = body.get("prompts", [])
+            if prompts:
+                all_prompts.extend(prompts)
+            sqs_client.delete_message(resume_msg["receipt_handle"])
+
+        if all_prompts:
+            batched_prompt = "\n\n".join([f"Request {i+1}: {p}" for i, p in enumerate(all_prompts)])
+            LOGGER.info(f"Processing {len(all_prompts)} batched prompts: {batched_prompt[:200]}...")
+
+            async with ClaudeSDKClient(options=options) as client:
+                await client.query(batched_prompt)
+
+                async for msg in client.receive_response():
+                    is_interrupted, receipt_handle = check_if_interrupted_via_sqs(sqs_client)
+                    if is_interrupted:
+                        LOGGER.warning(f"Session interrupted while processing batched prompts: {editing_id}")
+                        if receipt_handle:
+                            sqs_client.delete_message(receipt_handle)
+                        await update_session_status(editing_id, "waiting")
+                        raise SessionInterruptedError(f"Editing session {editing_id} was interrupted")
+
+                    if isinstance(msg, AssistantMessage):
+                        for block in msg.content:
+                            if isinstance(block, TextBlock):
+                                for line in block.text.split("\n"):
+                                    if "PR_URL:" in line:
+                                        extracted_text = line.split("PR_URL:", 1)[1].strip()
+                                        if extracted_text:
+                                            match = GITHUB_PR_URL_PATTERN.search(extracted_text)
+                                            if match:
+                                                pr_url = match.group(0)
+                                                LOGGER.info(f"Updated PR URL: {pr_url}")
+                                                await update_session_metadata(editing_id, pr_url=pr_url)
+
     await update_session_status(editing_id, "waiting")
 
     return session_id, pr_url
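
Two behavioral notes on the diff above: the interrupt check now runs only once a session_id exists, and queued RESUME prompts are drained and submitted as a single batched query instead of being replayed one at a time. A quick illustration of the batched prompt format:

# Illustration of the prompt batching performed in run_editing_session above.
prompts = ["Fix the typo in the README", "Add a changelog entry"]
batched = "\n\n".join(f"Request {i+1}: {p}" for i, p in enumerate(prompts))
print(batched)
# Request 1: Fix the typo in the README
#
# Request 2: Add a changelog entry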
servers/fai-lambda/fai-scribe/src/utils/sqs_client.py

Lines changed: 104 additions & 0 deletions

@@ -0,0 +1,104 @@
+import json
+from typing import Any
+
+import boto3
+from botocore.exceptions import ClientError
+
+from ..settings import LOGGER
+
+
+class SQSClient:
+    def __init__(self, queue_url: str):
+        self.queue_url = queue_url
+        self.sqs_client = boto3.client("sqs")
+        self._last_checked_count = 0
+        self._cached_messages: list[dict[str, Any]] = []
+
+    def receive_messages(self, max_messages: int = 10) -> list[dict[str, Any]]:
+        try:
+            response = self.sqs_client.receive_message(
+                QueueUrl=self.queue_url,
+                MaxNumberOfMessages=min(max_messages, 10),
+                WaitTimeSeconds=0,  # Short polling for Lambda
+                AttributeNames=["All"],
+            )
+
+            messages = response.get("Messages", [])
+            parsed_messages = []
+
+            for msg in messages:
+                try:
+                    body = json.loads(msg["Body"])
+                    parsed_messages.append(
+                        {
+                            "body": body,
+                            "receipt_handle": msg["ReceiptHandle"],
+                            "message_id": msg["MessageId"],
+                        }
+                    )
+                except json.JSONDecodeError:
+                    LOGGER.warning(f"Failed to parse message body as JSON: {msg['Body']}")
+                    continue
+
+            if parsed_messages:
+                self._last_checked_count += len(parsed_messages)
+                LOGGER.info(f"Received {len(parsed_messages)} messages from queue")
+
+            return parsed_messages
+
+        except ClientError as e:
+            LOGGER.error(f"Failed to receive messages from queue: {e.response['Error']['Message']}")
+            return []
+        except Exception as e:
+            LOGGER.error(f"Unexpected error receiving messages: {str(e)}", exc_info=True)
+            return []
+
+    def delete_message(self, receipt_handle: str) -> bool:
+        try:
+            self.sqs_client.delete_message(QueueUrl=self.queue_url, ReceiptHandle=receipt_handle)
+            LOGGER.debug("Deleted message from queue")
+            return True
+
+        except ClientError as e:
+            LOGGER.error(f"Failed to delete message: {e.response['Error']['Message']}")
+            return False
+        except Exception as e:
+            LOGGER.error(f"Unexpected error deleting message: {str(e)}", exc_info=True)
+            return False
+
+    def has_interrupt_message(self) -> tuple[bool, str | None]:
+        messages = self.receive_messages(max_messages=10)
+
+        for msg in messages:
+            body = msg["body"]
+            if body.get("type") == "INTERRUPT":
+                LOGGER.info("Found INTERRUPT message in queue")
+                return True, msg["receipt_handle"]
+            else:
+                self._cached_messages.append(msg)
+
+        return False, None
+
+    def get_resume_messages(self) -> list[dict[str, Any]]:
+        resume_messages = []
+
+        for msg in self._cached_messages:
+            body = msg["body"]
+            if body.get("type") == "RESUME":
+                resume_messages.append({"body": body, "receipt_handle": msg["receipt_handle"]})
+
+        messages = self.receive_messages(max_messages=10)
+        for msg in messages:
+            body = msg["body"]
+            if body.get("type") == "RESUME":
+                resume_messages.append({"body": body, "receipt_handle": msg["receipt_handle"]})
+
+        if resume_messages:
+            num_resume = len(resume_messages)
+            num_cached = len([m for m in self._cached_messages if m["body"].get("type") == "RESUME"])
+            num_new = len([m for m in messages if m["body"].get("type") == "RESUME"])
+            LOGGER.info(f"Found {num_resume} RESUME messages ({num_cached} cached, {num_new} new)")
+
+        self._cached_messages.clear()
+
+        return resume_messages
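
Taken together with agent.py, the consumer flow is: poll has_interrupt_message() while streaming (non-INTERRUPT messages are cached rather than discarded, since a received message stays hidden from later polls for the visibility timeout), then drain the cache plus one final poll via get_resume_messages() once the session settles. A short usage sketch; the queue URL is a placeholder:

# Usage sketch for SQSClient; the queue URL below is a placeholder.
client = SQSClient("https://sqs.us-east-1.amazonaws.com/123456789012/editing-session-abc123.fifo")

# While streaming: check for an interrupt and acknowledge it if found.
interrupted, receipt_handle = client.has_interrupt_message()
if interrupted and receipt_handle:
    client.delete_message(receipt_handle)

# Once the session settles: collect RESUME messages (cached ones first,
# then a fresh poll), deleting each after its prompts are consumed.
for msg in client.get_resume_messages():
    prompts = msg["body"].get("prompts", [])
    client.delete_message(msg["receipt_handle"])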
