diff --git a/servers/fai-lambda-deploy/scripts/fai-scribe-stack.ts b/servers/fai-lambda-deploy/scripts/fai-scribe-stack.ts
index 16dab9fad3..08f5b143b5 100644
--- a/servers/fai-lambda-deploy/scripts/fai-scribe-stack.ts
+++ b/servers/fai-lambda-deploy/scripts/fai-scribe-stack.ts
@@ -1,9 +1,10 @@
 import { type EnvironmentInfo, EnvironmentType } from "@fern-fern/fern-cloud-sdk/api";
-import { CfnOutput, Duration, RemovalPolicy, Stack, type StackProps } from "aws-cdk-lib";
+import { CfnOutput, Duration, RemovalPolicy, Size, Stack, type StackProps } from "aws-cdk-lib";
 import * as apigateway from "aws-cdk-lib/aws-apigateway";
 import { Certificate } from "aws-cdk-lib/aws-certificatemanager";
 import * as ec2 from "aws-cdk-lib/aws-ec2";
 import * as efs from "aws-cdk-lib/aws-efs";
+import * as iam from "aws-cdk-lib/aws-iam";
 import * as lambda from "aws-cdk-lib/aws-lambda";
 import { LogGroup, RetentionDays } from "aws-cdk-lib/aws-logs";
 import { ARecord, HostedZone, RecordTarget } from "aws-cdk-lib/aws-route53";
@@ -109,6 +110,7 @@ export class FaiScribeStack extends Stack {
             }),
             timeout: Duration.minutes(15),
             memorySize: 512,
+            ephemeralStorageSize: Size.mebibytes(2048),
             logGroup,
             vpc: this.vpc,
             vpcSubnets: {
@@ -125,6 +127,20 @@ export class FaiScribeStack extends Stack {
             filesystem: lambda.FileSystem.fromEfsAccessPoint(accessPoint, "/mnt/efs")
         });
 
+        lambdaFunction.addToRolePolicy(
+            new iam.PolicyStatement({
+                effect: iam.Effect.ALLOW,
+                actions: [
+                    "sqs:ReceiveMessage",
+                    "sqs:DeleteMessage",
+                    "sqs:GetQueueAttributes",
+                    "sqs:GetQueueUrl",
+                    "sqs:ChangeMessageVisibility"
+                ],
+                resources: [`arn:aws:sqs:${this.region}:${this.account}:editing-session-*.fifo`]
+            })
+        );
+
         const apiName = `${lambdaName}-${environmentType.toLowerCase()}`;
 
         const api = new apigateway.RestApi(this, `${lambdaName}-api`, {
diff --git a/servers/fai-lambda/fai-scribe/src/handler.py b/servers/fai-lambda/fai-scribe/src/handler.py
index 3f08b3f195..eda407e061 100644
--- a/servers/fai-lambda/fai-scribe/src/handler.py
+++ b/servers/fai-lambda/fai-scribe/src/handler.py
@@ -1,9 +1,11 @@
 import asyncio
 import json
+import shutil
 from datetime import (
     UTC,
     datetime,
 )
+from pathlib import Path
 from typing import Any
 
 import httpx
@@ -75,6 +77,13 @@ async def handle_editing_request(
     )
     LOGGER.info(f"Repository ready at: {repo_path}")
 
+    queue_url = session.get("queue_url")
+    if not queue_url:
+        raise RuntimeError(
+            f"No queue_url found for editing session {editing_id}. "
+            f"Session is in an invalid state - the queue should have been created during initialization."
+        )
+
     try:
         session_id, pr_url = await run_editing_session(
             repo_path=repo_path,
@@ -82,6 +91,7 @@
             base_branch=session["base_branch"],
             working_branch=session["working_branch"],
             editing_id=editing_id,
+            queue_url=queue_url,
             resume_session_id=session.get("session_id"),
             existing_pr_url=session.get("pr_url"),
         )
@@ -101,6 +111,26 @@
     else:
         LOGGER.info(f"Successfully updated editing session: {editing_id}")
 
+    LOGGER.info(f"Cleaning up SQS queue for session: {editing_id}")
+    cleanup_response = await client.delete(f"{SETTINGS.FAI_API_URL}/editing-sessions/{editing_id}/queue")
+    if cleanup_response.status_code == 200:
+        LOGGER.info(f"Successfully cleaned up queue for session: {editing_id}")
+    else:
+        LOGGER.warning(
+            f"Failed to cleanup queue for session {editing_id}: "
+            f"{cleanup_response.status_code} - {cleanup_response.text}"
+        )
+
+    try:
+        # repo_path is /tmp/editing-<id>/<owner>/<repo>, so walk up to the
+        # "editing-" session root rather than assuming it is the direct parent.
+        session_dir = next((p for p in Path(repo_path).parents if p.name.startswith("editing-")), None)
+        if session_dir is not None and session_dir.exists():
+            shutil.rmtree(session_dir)
+            LOGGER.info(f"Cleaned up session directory: {session_dir}")
+    except Exception as cleanup_error:
+        LOGGER.warning(f"Failed to cleanup session directory: {cleanup_error}")
+
     return {
         "editing_id": editing_id,
         "session_id": session_id,
diff --git a/servers/fai-lambda/fai-scribe/src/utils/agent.py b/servers/fai-lambda/fai-scribe/src/utils/agent.py
index bd1966b58b..dee77ef280 100644
--- a/servers/fai-lambda/fai-scribe/src/utils/agent.py
+++ b/servers/fai-lambda/fai-scribe/src/utils/agent.py
@@ -16,14 +16,13 @@
     LOGGER,
     SETTINGS,
 )
+from .sqs_client import SQSClient
 from .system_prompts import EDITING_SYSTEM_PROMPT
 
 GITHUB_PR_URL_PATTERN = re.compile(r"(?:https?://)?(?:www\.)?github\.com/([^/]+)/([^/]+)/pull/(\d+)", re.IGNORECASE)
 
 
 class SessionInterruptedError(Exception):
-    """Raised when an editing session is interrupted."""
-
     pass
 
 
@@ -45,7 +44,6 @@ async def update_session_status(editing_id: str, status: str) -> bool:
 
 
 async def update_session_metadata(editing_id: str, session_id: str | None = None, pr_url: str | None = None) -> bool:
-    """Update session metadata (session_id and/or pr_url) immediately when available."""
    try:
         payload = {}
         if session_id is not None:
@@ -69,19 +67,15 @@ async def update_session_metadata(editing_id: str, session_id: str | None = None
         return False
 
 
-async def check_if_interrupted(editing_id: str) -> bool:
+def check_if_interrupted_via_sqs(sqs_client: SQSClient) -> tuple[bool, str | None]:
     try:
-        async with httpx.AsyncClient(timeout=5.0) as client:
-            response = await client.get(f"{SETTINGS.FAI_API_URL}/editing-sessions/{editing_id}")
-            if response.status_code == 200:
-                session_data = response.json()
-                return session_data["editing_session"]["status"] == "interrupted"
-            else:
-                LOGGER.warning(f"Failed to check interrupted status: {response.status_code}")
-                return False
+        is_interrupted, receipt_handle = sqs_client.has_interrupt_message()
+        if is_interrupted:
+            LOGGER.info("Received INTERRUPT message from SQS queue")
+        return is_interrupted, receipt_handle
     except Exception as e:
-        LOGGER.warning(f"Error checking interrupted status: {e}")
-        return False
+        LOGGER.warning(f"Error checking SQS for interruption: {e}")
+        return False, None
 
 
 async def run_editing_session(
@@ -90,6 +84,7 @@ async def run_editing_session(
     base_branch: str,
     working_branch: str,
     editing_id: str,
+    queue_url: str,
     resume_session_id: str | None = None,
     existing_pr_url: str | None = None,
 ) -> tuple[str, str | None]:
@@ -134,6 +129,9 @@
session_id = resume_session_id pr_url = existing_pr_url + sqs_client = SQSClient(queue_url) + LOGGER.info(f"Initialized SQS client for queue: {queue_url}") + async with ClaudeSDKClient(options=options) as client: await client.query(full_prompt) @@ -152,11 +150,14 @@ async def run_editing_session( await update_session_status(editing_id, "active") LOGGER.info(f"Transitioned new session {editing_id} from STARTUP to ACTIVE") - if await check_if_interrupted(editing_id): - LOGGER.warning(f"Session interrupted: {editing_id}") - if session_id is not None: + if session_id is not None: + is_interrupted, receipt_handle = check_if_interrupted_via_sqs(sqs_client) + if is_interrupted: + LOGGER.warning(f"Session interrupted via SQS: {editing_id}") + if receipt_handle: + sqs_client.delete_message(receipt_handle) await update_session_status(editing_id, "waiting") - raise SessionInterruptedError(f"Editing session {editing_id} was interrupted") + raise SessionInterruptedError(f"Editing session {editing_id} was interrupted") if isinstance(message, AssistantMessage): for block in message.content: @@ -185,6 +186,49 @@ async def run_editing_session( if session_id is None: raise RuntimeError("Failed to obtain session ID from Claude") + LOGGER.info(f"Checking for RESUME messages in queue for session {editing_id}") + resume_messages = sqs_client.get_resume_messages() + + if resume_messages: + LOGGER.info(f"Found {len(resume_messages)} RESUME messages, batching prompts") + + all_prompts = [] + for resume_msg in resume_messages: + body = resume_msg["body"] + prompts = body.get("prompts", []) + if prompts: + all_prompts.extend(prompts) + sqs_client.delete_message(resume_msg["receipt_handle"]) + + if all_prompts: + batched_prompt = "\n\n".join([f"Request {i+1}: {p}" for i, p in enumerate(all_prompts)]) + LOGGER.info(f"Processing {len(all_prompts)} batched prompts: {batched_prompt[:200]}...") + + async with ClaudeSDKClient(options=options) as client: + await client.query(batched_prompt) + + async for msg in client.receive_response(): + is_interrupted, receipt_handle = check_if_interrupted_via_sqs(sqs_client) + if is_interrupted: + LOGGER.warning(f"Session interrupted while processing batched prompts: {editing_id}") + if receipt_handle: + sqs_client.delete_message(receipt_handle) + await update_session_status(editing_id, "waiting") + raise SessionInterruptedError(f"Editing session {editing_id} was interrupted") + + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + for line in block.text.split("\n"): + if "PR_URL:" in line: + extracted_text = line.split("PR_URL:", 1)[1].strip() + if extracted_text: + match = GITHUB_PR_URL_PATTERN.search(extracted_text) + if match: + pr_url = match.group(0) + LOGGER.info(f"Updated PR URL: {pr_url}") + await update_session_metadata(editing_id, pr_url=pr_url) + await update_session_status(editing_id, "waiting") return session_id, pr_url diff --git a/servers/fai-lambda/fai-scribe/src/utils/sqs_client.py b/servers/fai-lambda/fai-scribe/src/utils/sqs_client.py new file mode 100644 index 0000000000..480086a807 --- /dev/null +++ b/servers/fai-lambda/fai-scribe/src/utils/sqs_client.py @@ -0,0 +1,104 @@ +import json +from typing import Any + +import boto3 +from botocore.exceptions import ClientError + +from ..settings import LOGGER + + +class SQSClient: + def __init__(self, queue_url: str): + self.queue_url = queue_url + self.sqs_client = boto3.client("sqs") + self._last_checked_count = 0 + self._cached_messages: list[dict[str, Any]] = [] + + def 
receive_messages(self, max_messages: int = 10) -> list[dict[str, Any]]: + try: + response = self.sqs_client.receive_message( + QueueUrl=self.queue_url, + MaxNumberOfMessages=min(max_messages, 10), + WaitTimeSeconds=0, # Short polling for Lambda + AttributeNames=["All"], + ) + + messages = response.get("Messages", []) + parsed_messages = [] + + for msg in messages: + try: + body = json.loads(msg["Body"]) + parsed_messages.append( + { + "body": body, + "receipt_handle": msg["ReceiptHandle"], + "message_id": msg["MessageId"], + } + ) + except json.JSONDecodeError: + LOGGER.warning(f"Failed to parse message body as JSON: {msg['Body']}") + continue + + if parsed_messages: + self._last_checked_count += len(parsed_messages) + LOGGER.info(f"Received {len(parsed_messages)} messages from queue") + + return parsed_messages + + except ClientError as e: + LOGGER.error(f"Failed to receive messages from queue: {e.response['Error']['Message']}") + return [] + except Exception as e: + LOGGER.error(f"Unexpected error receiving messages: {str(e)}", exc_info=True) + return [] + + def delete_message(self, receipt_handle: str) -> bool: + try: + self.sqs_client.delete_message(QueueUrl=self.queue_url, ReceiptHandle=receipt_handle) + LOGGER.debug("Deleted message from queue") + return True + + except ClientError as e: + LOGGER.error(f"Failed to delete message: {e.response['Error']['Message']}") + return False + except Exception as e: + LOGGER.error(f"Unexpected error deleting message: {str(e)}", exc_info=True) + return False + + def has_interrupt_message(self) -> tuple[bool, str | None]: + messages = self.receive_messages(max_messages=10) + + for msg in messages: + body = msg["body"] + if body.get("type") == "INTERRUPT": + LOGGER.info("Found INTERRUPT message in queue") + return True, msg["receipt_handle"] + else: + self._cached_messages.append(msg) + + return False, None + + def get_resume_messages(self) -> list[dict[str, Any]]: + resume_messages = [] + + for msg in self._cached_messages: + body = msg["body"] + if body.get("type") == "RESUME": + resume_messages.append({"body": body, "receipt_handle": msg["receipt_handle"]}) + + messages = self.receive_messages(max_messages=10) + for msg in messages: + body = msg["body"] + if body.get("type") == "RESUME": + resume_messages.append({"body": body, "receipt_handle": msg["receipt_handle"]}) + + if resume_messages: + num_resume = len(resume_messages) + num_cached = len([m for m in self._cached_messages if m["body"].get("type") == "RESUME"]) + num_new = len([m for m in messages if m["body"].get("type") == "RESUME"]) + LOGGER.info(f"Found {num_resume} RESUME messages ({num_cached} cached, {num_new} new)") + + self._cached_messages.clear() + + return resume_messages diff --git a/servers/fai-lambda/shared/utils/git.py b/servers/fai-lambda/shared/utils/git.py index 2dc942ba08..17b9324fac 100644 --- a/servers/fai-lambda/shared/utils/git.py +++ b/servers/fai-lambda/shared/utils/git.py @@ -2,13 +2,34 @@ import os import shutil import subprocess +import time from pathlib import Path logger = logging.getLogger() +def cleanup_old_sessions(base_path: Path, max_age_hours: int = 24) -> None: + """Remove session directories older than max_age_hours.""" + try: + if not base_path.exists(): + return + + cutoff_time = time.time() - (max_age_hours * 3600) + + for item in base_path.iterdir(): + if item.is_dir() and (item.name.startswith("editing-") or item.name.startswith("session-")): + try: + if item.stat().st_mtime < cutoff_time: + logger.info(f"Removing old session directory: 
{item.name}") + shutil.rmtree(item) + except Exception as e: + logger.warning(f"Failed to remove old session {item.name}: {e}") + except Exception as e: + logger.warning(f"Failed to cleanup old sessions: {e}") + + def clone_repo(repository: str, session_id: str, session_type: str = "session") -> str: - """Clone a GitHub repository into /tmp. + """Clone a GitHub repository into /tmp with aggressive cleanup. Args: repository: GitHub repository in format 'owner/repo' @@ -22,39 +43,45 @@ def clone_repo(repository: str, session_id: str, session_type: str = "session") if not github_token: raise RuntimeError("GITHUB_TOKEN environment variable not set") - repo_path = Path("/tmp") / f"{session_type}-{session_id}" / repository - repo_path.parent.mkdir(parents=True, exist_ok=True) + base_path = Path("/tmp") + session_path = base_path / f"{session_type}-{session_id}" / repository - if repo_path.exists(): - logger.info(f"Directory {repo_path} already exists, removing it (Lambda container reuse)") - try: - shutil.rmtree(repo_path) - logger.info(f"Successfully removed existing directory at {repo_path}") - except Exception as e: - logger.error(f"Failed to remove existing directory at {repo_path}: {e}") - raise RuntimeError(f"Failed to clean up existing directory: {e}") + cleanup_old_sessions(base_path, max_age_hours=1) + + session_path.parent.mkdir(parents=True, exist_ok=True) + + if session_path.exists(): + logger.info(f"Removing existing session path: {session_path}") + shutil.rmtree(session_path) clone_url = f"https://x-access-token:{github_token}@github.com/{repository}.git" - logger.info(f"Cloning {repository} into {repo_path}") + logger.info(f"Cloning {repository} into {session_path}") try: subprocess.run( - ["git", "clone", clone_url, str(repo_path)], + ["git", "clone", clone_url, str(session_path)], check=True, capture_output=True, text=True, ) + logger.info(f"Successfully cloned {repository}") except subprocess.CalledProcessError as e: logger.error(f"Failed to clone repository: {e.stderr}") raise RuntimeError(f"Failed to clone {repository}: {e.stderr}") - configure_git_auth(str(repo_path)) + configure_git_auth(str(session_path)) - return str(repo_path) + return str(session_path) def configure_git_auth(repo_path: str) -> None: """Configure git user for the repository.""" + repo_path_obj = Path(repo_path) + if not repo_path_obj.exists(): + raise RuntimeError(f"Repository path does not exist: {repo_path}") + if not (repo_path_obj / ".git").exists(): + raise RuntimeError(f"Not a git repository (missing .git directory): {repo_path}") + subprocess.run(["git", "config", "user.name", "fern-support"], cwd=repo_path, check=True) subprocess.run(["git", "config", "user.email", "support@buildwithfern.com"], cwd=repo_path, check=True) diff --git a/servers/fai/alembic/versions/5564af7bb91b_add_queue_url_to_editing_sessions.py b/servers/fai/alembic/versions/5564af7bb91b_add_queue_url_to_editing_sessions.py new file mode 100644 index 0000000000..e8039b2130 --- /dev/null +++ b/servers/fai/alembic/versions/5564af7bb91b_add_queue_url_to_editing_sessions.py @@ -0,0 +1,33 @@ +"""add_queue_url_to_editing_sessions + +Revision ID: 5564af7bb91b +Revises: af951c45da91 +Create Date: 2025-10-31 12:21:38.371377 + +""" + +from typing import ( + Sequence, + Union, +) + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "5564af7bb91b" +down_revision: Union[str, Sequence[str], None] = "af951c45da91" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # Add queue_url column to editing_sessions table + op.add_column("editing_sessions", sa.Column("queue_url", sa.String(), nullable=True)) + + +def downgrade() -> None: + """Downgrade schema.""" + # Remove queue_url column from editing_sessions table + op.drop_column("editing_sessions", "queue_url") diff --git a/servers/fai/deploy/src/deploy-stack.ts b/servers/fai/deploy/src/deploy-stack.ts index 113ef7cccb..d5f6a39df4 100644 --- a/servers/fai/deploy/src/deploy-stack.ts +++ b/servers/fai/deploy/src/deploy-stack.ts @@ -159,6 +159,22 @@ export class FernAiDeployStack extends Stack { ); } + fargateService.taskDefinition.taskRole.addToPrincipalPolicy( + new iam.PolicyStatement({ + effect: iam.Effect.ALLOW, + actions: [ + "sqs:CreateQueue", + "sqs:DeleteQueue", + "sqs:SendMessage", + "sqs:GetQueueUrl", + "sqs:GetQueueAttributes", + "sqs:SetQueueAttributes", + "sqs:TagQueue" + ], + resources: [`arn:aws:sqs:us-east-1:985111089818:editing-session-*.fifo`] + }) + ); + const lbResponseTimeAlarm = new cloudwatch.Alarm( this, `fai-${environmentType.toLowerCase()}-lb-target-response-time-alarm`, diff --git a/servers/fai/src/fai/models/db/editing_session_db.py b/servers/fai/src/fai/models/db/editing_session_db.py index 8adaf8ccda..f3aa5751d1 100644 --- a/servers/fai/src/fai/models/db/editing_session_db.py +++ b/servers/fai/src/fai/models/db/editing_session_db.py @@ -39,6 +39,9 @@ class EditingSessionDb(Base): working_branch = Column(String, nullable=False) pr_url = Column(String, nullable=True) + # SQS queue for session communication + queue_url = Column(String, nullable=True) + created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(UTC)) updated_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(UTC)) @@ -51,6 +54,7 @@ def to_api(self) -> EditingSession: base_branch=self.base_branch, working_branch=self.working_branch, pr_url=self.pr_url, + queue_url=self.queue_url, status=self.status, created_at=self.created_at, updated_at=self.updated_at, diff --git a/servers/fai/src/fai/models/types/editing_session_types.py b/servers/fai/src/fai/models/types/editing_session_types.py index 0c7dfb61f3..be687dbbd6 100644 --- a/servers/fai/src/fai/models/types/editing_session_types.py +++ b/servers/fai/src/fai/models/types/editing_session_types.py @@ -21,6 +21,7 @@ class EditingSession(BaseModel): base_branch: str working_branch: str pr_url: str | None + queue_url: str | None status: EditingSessionStatus created_at: datetime updated_at: datetime diff --git a/servers/fai/src/fai/routes/editing_sessions.py b/servers/fai/src/fai/routes/editing_sessions.py index c31d675299..f5ef1e021c 100644 --- a/servers/fai/src/fai/routes/editing_sessions.py +++ b/servers/fai/src/fai/routes/editing_sessions.py @@ -23,6 +23,12 @@ from fai.models.db.editing_session_db import EditingSessionDb from fai.models.types.editing_session_types import EditingSessionStatus from fai.settings import LOGGER +from fai.utils.sqs_utils import ( + InterruptMessage, + create_session_queue, + delete_session_queue, + send_message_to_queue, +) def _generate_working_branch(repository: str) -> str: @@ -48,6 +54,14 @@ async def create_editing_session( editing_id = str(uuid.uuid4()) working_branch = _generate_working_branch(request.repository) + 
queue_url = await create_session_queue(editing_id) + if queue_url is None: + LOGGER.error(f"Failed to create SQS queue for session {editing_id}") + return JSONResponse( + content={"error": "Failed to create SQS queue for editing session"}, + status_code=500, + ) + db_session = EditingSessionDb( id=editing_id, session_id=None, @@ -55,6 +69,7 @@ async def create_editing_session( base_branch=request.base_branch, working_branch=working_branch, pr_url=None, + queue_url=queue_url, status=EditingSessionStatus.STARTUP, created_at=datetime.now(UTC), updated_at=datetime.now(UTC), @@ -64,7 +79,7 @@ async def create_editing_session( await db.commit() await db.refresh(db_session) - LOGGER.info(f"Created editing session: {editing_id} with branch: {working_branch}") + LOGGER.info(f"Created editing session: {editing_id} with branch: {working_branch} and queue: {queue_url}") return JSONResponse( content=jsonable_encoder(CreateEditingSessionResponse(editing_session=db_session.to_api())), @@ -167,7 +182,7 @@ async def interrupt_editing_session( editing_id: str, db: AsyncSession = Depends(get_db), ) -> JSONResponse: - """Interrupt a running editing session.""" + """Interrupt a running editing session by sending INTERRUPT message to SQS.""" LOGGER.info(f"Interrupting editing session: {editing_id}") try: result = await db.execute(select(EditingSessionDb).where(EditingSessionDb.id == editing_id)) @@ -180,21 +195,40 @@ async def interrupt_editing_session( status_code=404, ) - # Only interrupt if session is active - if db_session.status != EditingSessionStatus.ACTIVE: + if db_session.status in [EditingSessionStatus.INTERRUPTED, EditingSessionStatus.COMPLETED]: LOGGER.warning(f"Cannot interrupt session with status {db_session.status}: {editing_id}") return JSONResponse( content={"error": f"Cannot interrupt session with status: {db_session.status}"}, status_code=400, ) + if not db_session.queue_url: + LOGGER.error(f"No queue URL for session {editing_id}, cannot send interrupt message") + return JSONResponse( + content={"error": "No queue URL found for session"}, + status_code=500, + ) + + interrupt_msg = InterruptMessage( + editing_id=editing_id, + timestamp=datetime.now(UTC).isoformat(), + ) + + success = await send_message_to_queue(db_session.queue_url, interrupt_msg) + if not success: + LOGGER.error(f"Failed to send interrupt message to queue for session {editing_id}") + return JSONResponse( + content={"error": "Failed to send interrupt message to queue"}, + status_code=500, + ) + db_session.status = EditingSessionStatus.INTERRUPTED db_session.updated_at = datetime.now(UTC) await db.commit() await db.refresh(db_session) - LOGGER.info(f"Interrupted editing session: {editing_id}") + LOGGER.info(f"Sent interrupt message and updated status for session: {editing_id}") return JSONResponse( content=jsonable_encoder(InterruptEditingSessionResponse(editing_session=db_session.to_api())), @@ -207,3 +241,57 @@ async def interrupt_editing_session( content={"error": "Failed to interrupt editing session", "details": str(e)}, status_code=500, ) + + +@fai_app.delete( + "/editing-sessions/{editing_id}/queue", + openapi_extra={"x-fern-audiences": ["internal"]}, +) +async def cleanup_session_queue( + editing_id: str, + db: AsyncSession = Depends(get_db), +) -> JSONResponse: + """Delete the SQS queue for an editing session.""" + LOGGER.info(f"Cleaning up queue for editing session: {editing_id}") + try: + result = await db.execute(select(EditingSessionDb).where(EditingSessionDb.id == editing_id)) + db_session = 
result.scalar_one_or_none() + + if db_session is None: + LOGGER.warning(f"Editing session not found: {editing_id}") + return JSONResponse( + content={"error": "Editing session not found"}, + status_code=404, + ) + + if not db_session.queue_url: + LOGGER.info(f"No queue URL for session {editing_id}, nothing to clean up") + return JSONResponse( + content={"message": "No queue to clean up"}, + status_code=200, + ) + + success = await delete_session_queue(db_session.queue_url) + if success: + db_session.queue_url = None + db_session.updated_at = datetime.now(UTC) + await db.commit() + + LOGGER.info(f"Successfully cleaned up queue for session {editing_id}") + return JSONResponse( + content={"message": "Queue cleaned up successfully"}, + status_code=200, + ) + else: + LOGGER.warning(f"Failed to delete queue for session {editing_id}") + return JSONResponse( + content={"error": "Failed to delete queue"}, + status_code=500, + ) + + except Exception as e: + LOGGER.exception(f"Error cleaning up queue for session {editing_id}: {e}") + return JSONResponse( + content={"error": "Failed to clean up queue", "details": str(e)}, + status_code=500, + ) diff --git a/servers/fai/src/fai/routes/slack.py b/servers/fai/src/fai/routes/slack.py index 3e9189851a..bdad353192 100644 --- a/servers/fai/src/fai/routes/slack.py +++ b/servers/fai/src/fai/routes/slack.py @@ -64,7 +64,7 @@ ) from fai.utils.slack.response_qa import log_message_for_qa -MESSAGE_CACHE_TTL = 30 +MESSAGE_CACHE_TTL = 600 async def cleanup_message_cache() -> None: diff --git a/servers/fai/src/fai/utils/slack/edit_handler.py b/servers/fai/src/fai/utils/slack/edit_handler.py index 819954386b..c7a7f0422e 100644 --- a/servers/fai/src/fai/utils/slack/edit_handler.py +++ b/servers/fai/src/fai/utils/slack/edit_handler.py @@ -1,6 +1,8 @@ -import asyncio import json -import logging +from datetime import ( + UTC, + datetime, +) import aioboto3 import httpx @@ -12,11 +14,14 @@ from fai.models.types.editing_session_types import EditingSessionStatus from fai.settings import ( CONFIG, + LOGGER, VARIABLES, ) from fai.utils.github_utils import get_repo_from_docs_domain - -logger = logging.getLogger(__name__) +from fai.utils.sqs_utils import ( + ResumeMessage, + send_message_to_queue, +) async def get_or_create_editing_session_for_thread(team_id: str, channel_id: str, thread_ts: str) -> str | None: @@ -31,7 +36,7 @@ async def get_or_create_editing_session_for_thread(team_id: str, channel_id: str existing = result.scalar_one_or_none() if existing: - logger.info(f"Found existing editing session {existing.editing_id} for thread {thread_ts}") + LOGGER.info(f"Found existing editing session {existing.editing_id} for thread {thread_ts}") return existing.editing_id return None @@ -39,11 +44,6 @@ async def get_or_create_editing_session_for_thread(team_id: str, channel_id: str async def store_editing_session_for_thread(team_id: str, channel_id: str, thread_ts: str, editing_id: str) -> None: async with async_session_maker() as session: - from datetime import ( - UTC, - datetime, - ) - result = await session.execute( select(SlackEditingSessionDb).where( SlackEditingSessionDb.team_id == team_id, @@ -55,7 +55,7 @@ async def store_editing_session_for_thread(team_id: str, channel_id: str, thread if existing: existing.updated_at = datetime.now(UTC) - logger.info(f"Updated editing session {editing_id} for thread {thread_ts}") + LOGGER.info(f"Updated editing session {editing_id} for thread {thread_ts}") else: slack_editing = SlackEditingSessionDb( team_id=team_id, @@ -66,22 +66,21 @@ 
async def store_editing_session_for_thread(team_id: str, channel_id: str, thread updated_at=datetime.now(UTC), ) session.add(slack_editing) - logger.info(f"Stored new editing session {editing_id} for thread {thread_ts}") + LOGGER.info(f"Stored new editing session {editing_id} for thread {thread_ts}") await session.commit() async def get_editing_session_status(editing_id: str) -> EditingSessionStatus | None: - """Get the current status of an editing session.""" try: async with httpx.AsyncClient(timeout=10.0) as client: response = await client.get(f"{CONFIG.FAI_SERVER_URL}/editing-sessions/{editing_id}") if response.status_code == 404: - logger.warning(f"Editing session not found: {editing_id}") + LOGGER.warning(f"Editing session not found: {editing_id}") return None elif response.status_code != 200: - logger.error(f"Failed to fetch editing session: {response.status_code} - {response.text}") + LOGGER.error(f"Failed to fetch editing session: {response.status_code} - {response.text}") return None session_data = response.json() @@ -89,89 +88,66 @@ async def get_editing_session_status(editing_id: str) -> EditingSessionStatus | return EditingSessionStatus(status_str) except Exception as e: - logger.error(f"Error fetching editing session status: {e}", exc_info=True) + LOGGER.error(f"Error fetching editing session status: {e}", exc_info=True) return None async def interrupt_editing_session(editing_id: str) -> bool: - """Interrupt an active editing session. Returns True if successful.""" try: - status = await get_editing_session_status(editing_id) - - if status == EditingSessionStatus.STARTUP: - logger.info(f"Session {editing_id} is in STARTUP state, waiting for it to become ACTIVE") - max_wait: int = 30 - poll_interval: float = 0.5 - elapsed: float = 0 - - while elapsed < max_wait: - await asyncio.sleep(poll_interval) - elapsed += poll_interval - - status = await get_editing_session_status(editing_id) - if status == EditingSessionStatus.ACTIVE: - logger.info(f"Session {editing_id} transitioned to ACTIVE, proceeding with interruption") - break - elif status not in [EditingSessionStatus.STARTUP, EditingSessionStatus.ACTIVE]: - logger.warning(f"Session {editing_id} in unexpected state {status}, aborting interruption") - return False - - if status != EditingSessionStatus.ACTIVE: - logger.warning(f"Timeout waiting for session {editing_id} to become ACTIVE") - return False - async with httpx.AsyncClient(timeout=10.0) as client: response = await client.post(f"{CONFIG.FAI_SERVER_URL}/editing-sessions/{editing_id}/interrupt") if response.status_code == 200: - logger.info(f"Successfully interrupted editing session: {editing_id}") + LOGGER.info(f"Successfully sent interrupt message for session: {editing_id}") return True elif response.status_code == 400: - logger.warning(f"Cannot interrupt session {editing_id} (not in ACTIVE state)") + LOGGER.warning(f"Cannot interrupt session {editing_id} (not in ACTIVE state)") return False elif response.status_code == 404: - logger.warning(f"Editing session not found for interruption: {editing_id}") + LOGGER.warning(f"Editing session not found for interruption: {editing_id}") return False else: - logger.error(f"Failed to interrupt editing session: {response.status_code} - {response.text}") + LOGGER.error(f"Failed to interrupt editing session: {response.status_code} - {response.text}") return False except Exception as e: - logger.error(f"Error interrupting editing session: {e}", exc_info=True) + LOGGER.error(f"Error interrupting editing session: {e}", exc_info=True) return 
False -async def wait_for_interruption(editing_id: str, max_wait_seconds: int = 30, poll_interval: float = 0.5) -> bool: - """ - Wait for an editing session to transition out of INTERRUPTED status. - Returns True if session is ready (WAITING), False if timeout or error. - Polls every poll_interval seconds for up to max_wait_seconds. - """ - logger.info(f"Waiting for session {editing_id} to complete interruption (max {max_wait_seconds}s)") - start_time = asyncio.get_event_loop().time() +async def send_resume_message(editing_id: str, prompts: list[str]) -> bool: + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get(f"{CONFIG.FAI_SERVER_URL}/editing-sessions/{editing_id}") - while True: - elapsed = asyncio.get_event_loop().time() - start_time - if elapsed >= max_wait_seconds: - logger.warning(f"Timeout waiting for session {editing_id} interruption to complete") - return False + if response.status_code != 200: + LOGGER.error(f"Failed to fetch session for RESUME message: {response.status_code}") + return False - status = await get_editing_session_status(editing_id) + session_data = response.json() + queue_url = session_data["editing_session"].get("queue_url") - if status is None: - logger.error(f"Failed to get status for session {editing_id} during interruption wait") - return False + if not queue_url: + LOGGER.error(f"No queue_url found for session {editing_id}, cannot send RESUME message") + return False - if status == EditingSessionStatus.WAITING: - logger.info(f"Session {editing_id} is now WAITING after interruption") - return True + resume_msg = ResumeMessage( + editing_id=editing_id, + prompts=prompts, + timestamp=datetime.now(UTC).isoformat(), + ) - if status in [EditingSessionStatus.INTERRUPTED, EditingSessionStatus.ACTIVE]: - logger.debug(f"Session {editing_id} still in {status} state, waiting...") - await asyncio.sleep(poll_interval) + success = await send_message_to_queue(queue_url, resume_msg) + if success: + LOGGER.info(f"Sent RESUME message with {len(prompts)} prompts to queue for session {editing_id}") else: - logger.warning(f"Unexpected status {status} for session {editing_id} during interruption wait") - return False + LOGGER.error(f"Failed to send RESUME message to queue for session {editing_id}") + + return success + + except Exception as e: + LOGGER.error(f"Error sending RESUME message for session {editing_id}: {e}", exc_info=True) + return False async def create_editing_session(repository: str, base_branch: str = "main") -> str | None: @@ -187,16 +163,16 @@ async def create_editing_session(repository: str, base_branch: str = "main") -> ) if response.status_code != 201: - logger.error(f"Failed to create editing session: {response.status_code} - {response.text}") + LOGGER.error(f"Failed to create editing session: {response.status_code} - {response.text}") return None session_data = response.json() editing_id = session_data["editing_session"]["id"] - logger.info(f"Created new editing session: {editing_id} for repository: {repository}") + LOGGER.info(f"Created new editing session: {editing_id} for repository: {repository}") return editing_id except Exception as e: - logger.error(f"Error creating editing session: {e}", exc_info=True) + LOGGER.error(f"Error creating editing session: {e}", exc_info=True) return None @@ -210,18 +186,18 @@ async def invoke_editing_lambda( thread_ts: str | None = None, ) -> dict[str, str] | None: if not VARIABLES.FAI_LAMBDA_FUNCTION_NAME: - logger.warning("FAI_LAMBDA_FUNCTION_NAME not configured. 
Skipping Lambda invocation.") + LOGGER.warning("FAI_LAMBDA_FUNCTION_NAME not configured. Skipping Lambda invocation.") return None repository = await get_repo_from_docs_domain(domain) if not repository: - logger.warning( + LOGGER.warning( f"No GitHub repository found for domain '{domain}'. Skipping Lambda invocation. " f"Please ensure the domain is registered in FDR with a connected GitHub repository." ) return None - logger.info(f"Resolved domain '{domain}' to repository '{repository}'") + LOGGER.info(f"Resolved domain '{domain}' to repository '{repository}'") try: session = aioboto3.Session() @@ -234,14 +210,14 @@ async def invoke_editing_lambda( if editing_id: body_payload["editing_id"] = editing_id - logger.info(f"Resuming editing session: {editing_id}") + LOGGER.info(f"Resuming editing session: {editing_id}") if team_id and channel_id and thread_ts: callback_url = ( f"{CONFIG.FAI_SERVER_URL}/scribe/callback/slack/edit/" f"{team_id}/{channel_id}/{thread_ts}" ) body_payload["callback_url"] = callback_url - logger.info(f"Including callback URL in Lambda invocation: {callback_url}") + LOGGER.info(f"Including callback URL in Lambda invocation: {callback_url}") payload = {"body": json.dumps(body_payload)} response = await lambda_client.invoke( @@ -249,7 +225,7 @@ async def invoke_editing_lambda( InvocationType="Event", Payload=json.dumps(payload), ) - logger.info( + LOGGER.info( f"Successfully invoked FAI Lambda for editing. " f"StatusCode: {response.get('StatusCode')}, " f"Domain: {domain}, " @@ -260,11 +236,11 @@ async def invoke_editing_lambda( return {"status": "invoked", "repository": repository} except ClientError as e: - logger.error( + LOGGER.error( f"Failed to invoke FAI Lambda: {e.response['Error']['Code']} - {e.response['Error']['Message']}", exc_info=True, ) return None except Exception as e: - logger.error(f"Unexpected error invoking FAI Lambda: {str(e)}", exc_info=True) + LOGGER.error(f"Unexpected error invoking FAI Lambda: {str(e)}", exc_info=True) return None diff --git a/servers/fai/src/fai/utils/slack/message_handler.py b/servers/fai/src/fai/utils/slack/message_handler.py index 3c8af9e923..a1a7694452 100644 --- a/servers/fai/src/fai/utils/slack/message_handler.py +++ b/servers/fai/src/fai/utils/slack/message_handler.py @@ -1,4 +1,3 @@ -import asyncio from dataclasses import dataclass from datetime import ( UTC, @@ -44,8 +43,8 @@ get_or_create_editing_session_for_thread, interrupt_editing_session, invoke_editing_lambda, + send_resume_message, store_editing_session_for_thread, - wait_for_interruption, ) from fai.utils.slack.lambda_invoke import invoke_fai_lambda_for_docs_update from fai.utils.slack.postprocessing import SlackifyMarkdown @@ -385,44 +384,58 @@ async def handle_slack_message( LOGGER.info(f"Existing session {existing_editing_id} status: {session_status}") if session_status and session_status.value == "interrupted": - LOGGER.warning( - f"Session {existing_editing_id} is already being interrupted by another request. " - "Not proceeding to avoid race condition." - ) - error_text = ( - "⏳ Previous edit session is currently being interrupted by another request. " - "Please wait a moment and try again." - ) + LOGGER.info(f"Session {existing_editing_id} is INTERRUPTED, queuing new request via RESUME message") + + await send_resume_message(existing_editing_id, [context.text]) + + info_text = "⏳ Previous edit is being interrupted, your request has been queued." 
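+                # The queued prompt is not lost: once the interrupt lands, the Lambda
+                # drains RESUME messages via SQSClient.get_resume_messages() and batches
+                # them into a follow-up query.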
return SlackMessageResponse( - response_text=error_text, + response_text=info_text, channel=context.channel, thread_ts=context.thread_ts, bot_token=integration.slack_bot_token, query_id=None, user_id=context.user, ) - elif session_status and session_status.value == "startup": - LOGGER.info(f"Session {existing_editing_id} is in STARTUP state, waiting for it to become ACTIVE") - max_wait: int = 30 - poll_interval: float = 0.5 - elapsed: float = 0 - - while elapsed < max_wait: - await asyncio.sleep(poll_interval) - elapsed += poll_interval - - session_status = await get_editing_session_status(existing_editing_id) - if session_status and session_status.value in ["active", "waiting"]: - LOGGER.info( - f"Session {existing_editing_id} transitioned to {session_status.value}, " - f"proceeding with interruption" - ) - break - elif session_status and session_status.value not in ["startup", "active", "waiting"]: - LOGGER.warning( - f"Session {existing_editing_id} in unexpected state {session_status.value}, aborting" + + if session_status and session_status.value in ["startup", "active", "waiting"]: + if session_status.value == "active": + LOGGER.info(f"Interrupting active session {existing_editing_id} for new edit request") + interrupt_success = await interrupt_editing_session(existing_editing_id) + + if not interrupt_success: + LOGGER.error(f"Failed to interrupt session {existing_editing_id}") + error_text = "❌ Failed to interrupt the previous edit session. Please try again." + return SlackMessageResponse( + response_text=error_text, + channel=context.channel, + thread_ts=context.thread_ts, + bot_token=integration.slack_bot_token, + query_id=None, + user_id=context.user, ) - error_text = "❌ Session is in an unexpected state. Please try again." + + if integration.slack_bot_token: + try: + await send_slack_message( + channel=context.channel, + text="⚠️ Interrupted previous edit session to process new request", + bot_token=integration.slack_bot_token, + thread_ts=context.thread_ts, + ) + LOGGER.info(f"Posted interruption message to thread {context.thread_ts}") + except Exception as e: + LOGGER.error(f"Failed to post interruption message: {e}", exc_info=True) + + await send_resume_message(existing_editing_id, [context.text]) + + elif session_status.value == "startup": + LOGGER.info(f"Interrupting STARTUP session {existing_editing_id} for new edit request") + interrupt_success = await interrupt_editing_session(existing_editing_id) + + if not interrupt_success: + LOGGER.error(f"Failed to interrupt STARTUP session {existing_editing_id}") + error_text = "❌ Failed to interrupt the previous edit session. Please try again." return SlackMessageResponse( response_text=error_text, channel=context.channel, @@ -432,11 +445,22 @@ async def handle_slack_message( user_id=context.user, ) - if not session_status or session_status.value not in ["active", "waiting"]: - LOGGER.warning(f"Timeout waiting for session {existing_editing_id} to become ACTIVE/WAITING") - error_text = "⏳ Previous edit session is still starting up. Please wait a moment and try again." 
+ if integration.slack_bot_token: + try: + await send_slack_message( + channel=context.channel, + text="⚠️ Interrupted previous edit session to process new request", + bot_token=integration.slack_bot_token, + thread_ts=context.thread_ts, + ) + LOGGER.info(f"Posted interruption message to thread {context.thread_ts}") + except Exception as e: + LOGGER.error(f"Failed to post interruption message: {e}", exc_info=True) + + await send_resume_message(existing_editing_id, [context.text]) + return SlackMessageResponse( - response_text=error_text, + response_text="", channel=context.channel, thread_ts=context.thread_ts, bot_token=integration.slack_bot_token, @@ -444,58 +468,32 @@ async def handle_slack_message( user_id=context.user, ) - if session_status and session_status.value in ["active", "waiting"]: - LOGGER.info(f"Interrupting {session_status.value} session {existing_editing_id} for new edit request") - - interrupt_success = await interrupt_editing_session(existing_editing_id) - - if interrupt_success: - LOGGER.info(f"Waiting for session {existing_editing_id} interruption to complete") - wait_success = await wait_for_interruption(existing_editing_id) - - if wait_success: - if integration.slack_bot_token: - interruption_msg = "⚠️ Interrupted previous edit session to process new request" - try: - await send_slack_message( - channel=context.channel, - text=interruption_msg, - bot_token=integration.slack_bot_token, - thread_ts=context.thread_ts, - ) - LOGGER.info(f"Posted interruption message to thread {context.thread_ts}") - except Exception as e: - LOGGER.error(f"Failed to post interruption message: {e}", exc_info=True) - else: - LOGGER.error( - f"Timeout waiting for interruption of session {existing_editing_id}. " - "Not proceeding with new request to avoid duplicate processing." - ) - error_text = ( - "⏳ Previous edit session is still processing and couldn't be interrupted " - "within 30 seconds. Please wait a moment and try again." - ) - return SlackMessageResponse( - response_text=error_text, - channel=context.channel, - thread_ts=context.thread_ts, - bot_token=integration.slack_bot_token, - query_id=None, - user_id=context.user, - ) elif session_status.value == "waiting": - LOGGER.info(f"Session {existing_editing_id} is WAITING and ready to resume") - else: - LOGGER.error(f"Failed to interrupt session {existing_editing_id}. Cannot proceed safely.") - error_text = "❌ Failed to interrupt the previous edit session. Please try again." 
+ LOGGER.info(f"Session {existing_editing_id} is WAITING, " f"queuing request without interrupting") + + await send_resume_message(existing_editing_id, [context.text]) + + if integration.slack_bot_token: + try: + await send_slack_message( + channel=context.channel, + text="⏳ Previous edit is waiting - your request has been queued", + bot_token=integration.slack_bot_token, + thread_ts=context.thread_ts, + ) + LOGGER.info(f"Posted queued message to thread {context.thread_ts}") + except Exception as e: + LOGGER.error(f"Failed to post queued message: {e}", exc_info=True) + return SlackMessageResponse( - response_text=error_text, + response_text="", channel=context.channel, thread_ts=context.thread_ts, bot_token=integration.slack_bot_token, query_id=None, user_id=context.user, ) + else: status_str = session_status.value if session_status else "unknown" LOGGER.error(f"Unexpected session status: {status_str} for session {existing_editing_id}") diff --git a/servers/fai/src/fai/utils/sqs_utils.py b/servers/fai/src/fai/utils/sqs_utils.py new file mode 100644 index 0000000000..a9e0fe3c1b --- /dev/null +++ b/servers/fai/src/fai/utils/sqs_utils.py @@ -0,0 +1,196 @@ +import json +from enum import Enum +from typing import Any + +import aioboto3 +from botocore.exceptions import ClientError + +from ..settings import LOGGER + + +class SQSMessageType(str, Enum): + INTERRUPT = "INTERRUPT" + RESUME = "RESUME" + TERMINATE = "TERMINATE" + + +class SQSMessage: + def __init__(self, message_type: SQSMessageType, editing_id: str, timestamp: str): + self.message_type = message_type + self.editing_id = editing_id + self.timestamp = timestamp + + def to_dict(self) -> dict[str, Any]: + return { + "type": self.message_type.value, + "editing_id": self.editing_id, + "timestamp": self.timestamp, + } + + +class InterruptMessage(SQSMessage): + def __init__(self, editing_id: str, timestamp: str): + super().__init__(SQSMessageType.INTERRUPT, editing_id, timestamp) + + +class ResumeMessage(SQSMessage): + def __init__(self, editing_id: str, prompts: list[str], timestamp: str): + super().__init__(SQSMessageType.RESUME, editing_id, timestamp) + self.prompts = prompts + + def to_dict(self) -> dict[str, Any]: + data = super().to_dict() + data["prompts"] = self.prompts + return data + + +class TerminateMessage(SQSMessage): + def __init__(self, editing_id: str, reason: str, timestamp: str): + super().__init__(SQSMessageType.TERMINATE, editing_id, timestamp) + self.reason = reason + + def to_dict(self) -> dict[str, Any]: + data = super().to_dict() + data["reason"] = self.reason + return data + + +async def create_session_queue(editing_id: str, region: str = "us-east-1") -> str | None: + queue_name = f"editing-session-{editing_id}.fifo" + + try: + session = aioboto3.Session() + async with session.client("sqs", region_name=region) as sqs_client: + response = await sqs_client.create_queue( + QueueName=queue_name, + Attributes={ + "FifoQueue": "true", + "ContentBasedDeduplication": "true", + "MessageRetentionPeriod": "3600", + "VisibilityTimeout": "300", + }, + ) + queue_url = response["QueueUrl"] + LOGGER.info(f"Created SQS FIFO queue for session {editing_id}: {queue_url}") + return queue_url + + except ClientError as e: + LOGGER.error(f"Failed to create SQS queue for session {editing_id}: {e.response['Error']['Message']}") + return None + except Exception as e: + LOGGER.error(f"Unexpected error creating SQS queue for session {editing_id}: {str(e)}", exc_info=True) + return None + + +async def delete_session_queue(queue_url: str) -> bool: 
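+    # Best-effort teardown, invoked via DELETE /editing-sessions/{editing_id}/queue
+    # once the Lambda finishes; any remaining messages would expire after the queue's
+    # 3600s MessageRetentionPeriod anyway.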
+ try: + session = aioboto3.Session() + async with session.client("sqs") as sqs_client: + await sqs_client.delete_queue(QueueUrl=queue_url) + LOGGER.info(f"Deleted SQS queue: {queue_url}") + return True + + except ClientError as e: + LOGGER.error(f"Failed to delete SQS queue {queue_url}: {e.response['Error']['Message']}") + return False + except Exception as e: + LOGGER.error(f"Unexpected error deleting SQS queue {queue_url}: {str(e)}", exc_info=True) + return False + + +async def send_message_to_queue(queue_url: str, message: SQSMessage) -> bool: + try: + session = aioboto3.Session() + async with session.client("sqs") as sqs_client: + message_body = json.dumps(message.to_dict()) + + await sqs_client.send_message( + QueueUrl=queue_url, + MessageBody=message_body, + MessageGroupId=message.editing_id, + ) + + LOGGER.info(f"Sent {message.message_type.value} message to queue {queue_url}") + return True + + except ClientError as e: + LOGGER.error(f"Failed to send message to queue {queue_url}: {e.response['Error']['Message']}", exc_info=True) + return False + except Exception as e: + LOGGER.error(f"Unexpected error sending message to queue {queue_url}: {str(e)}", exc_info=True) + return False + + +async def receive_messages_from_queue( + queue_url: str, max_messages: int = 10, wait_time_seconds: int = 0 +) -> list[dict[str, Any]]: + try: + session = aioboto3.Session() + async with session.client("sqs") as sqs_client: + response = await sqs_client.receive_message( + QueueUrl=queue_url, + MaxNumberOfMessages=min(max_messages, 10), + WaitTimeSeconds=wait_time_seconds, + AttributeNames=["All"], + ) + + messages = response.get("Messages", []) + parsed_messages = [] + + for msg in messages: + try: + body = json.loads(msg["Body"]) + parsed_messages.append( + { + "body": body, + "receipt_handle": msg["ReceiptHandle"], + "message_id": msg["MessageId"], + } + ) + except json.JSONDecodeError: + LOGGER.warning(f"Failed to parse message body as JSON: {msg['Body']}") + continue + + if parsed_messages: + LOGGER.info(f"Received {len(parsed_messages)} messages from queue {queue_url}") + + return parsed_messages + + except ClientError as e: + LOGGER.error(f"Failed to receive messages from queue {queue_url}: {e.response['Error']['Message']}") + return [] + except Exception as e: + LOGGER.error(f"Unexpected error receiving messages from queue {queue_url}: {str(e)}", exc_info=True) + return [] + + +async def delete_message_from_queue(queue_url: str, receipt_handle: str) -> bool: + try: + session = aioboto3.Session() + async with session.client("sqs") as sqs_client: + await sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle) + LOGGER.debug(f"Deleted message from queue {queue_url}") + return True + + except ClientError as e: + LOGGER.error(f"Failed to delete message from queue {queue_url}: {e.response['Error']['Message']}") + return False + except Exception as e: + LOGGER.error(f"Unexpected error deleting message from queue {queue_url}: {str(e)}", exc_info=True) + return False + + +async def purge_queue(queue_url: str) -> bool: + try: + session = aioboto3.Session() + async with session.client("sqs") as sqs_client: + await sqs_client.purge_queue(QueueUrl=queue_url) + LOGGER.info(f"Purged all messages from queue {queue_url}") + return True + + except ClientError as e: + LOGGER.error(f"Failed to purge queue {queue_url}: {e.response['Error']['Message']}") + return False + except Exception as e: + LOGGER.error(f"Unexpected error purging queue {queue_url}: {str(e)}", exc_info=True) + return False
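
Reviewer note: below is a minimal, illustrative sketch of the per-session FIFO queue
protocol this PR introduces. The queue URL and editing ID are placeholders; the message
shapes and SQS calls mirror sqs_utils.py (fai producer side) and sqs_client.py
(fai-scribe Lambda consumer side) above.

import json
from datetime import UTC, datetime

import boto3

# Placeholders -- real values come from create_session_queue() and the editing_sessions row.
QUEUE_URL = "https://sqs.us-east-1.amazonaws.com/000000000000/editing-session-example.fifo"
EDITING_ID = "example-editing-id"

sqs = boto3.client("sqs", region_name="us-east-1")

# Producer (fai server): what send_message_to_queue(queue_url, ResumeMessage(...)) puts on the wire.
sqs.send_message(
    QueueUrl=QUEUE_URL,
    MessageBody=json.dumps(
        {
            "type": "RESUME",
            "editing_id": EDITING_ID,
            "timestamp": datetime.now(UTC).isoformat(),
            "prompts": ["Fix the typo on the quickstart page"],
        }
    ),
    MessageGroupId=EDITING_ID,  # FIFO: per-session ordering; dedup is content-based per queue attributes
)

# Consumer (fai-scribe Lambda): SQSClient short-polls, routes on message type, then deletes.
response = sqs.receive_message(QueueUrl=QUEUE_URL, MaxNumberOfMessages=10, WaitTimeSeconds=0)
for message in response.get("Messages", []):
    body = json.loads(message["Body"])
    if body.get("type") == "INTERRUPT":
        pass  # the agent marks the session "waiting" and raises SessionInterruptedError
    elif body.get("type") == "RESUME":
        pass  # the agent batches body["prompts"] into one follow-up Claude query
    sqs.delete_message(QueueUrl=QUEUE_URL, ReceiptHandle=message["ReceiptHandle"])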