1616 LOGGER ,
1717 SETTINGS ,
1818)
19+ from .sqs_client import SQSClient
1920from .system_prompts import EDITING_SYSTEM_PROMPT
2021
2122GITHUB_PR_URL_PATTERN = re .compile (r"(?:https?://)?(?:www\.)?github\.com/([^/]+)/([^/]+)/pull/(\d+)" , re .IGNORECASE )
2223
2324
2425class SessionInterruptedError (Exception ):
25- """Raised when an editing session is interrupted."""
26-
2726 pass
2827
2928
@@ -45,7 +44,6 @@ async def update_session_status(editing_id: str, status: str) -> bool:
4544
4645
4746async def update_session_metadata (editing_id : str , session_id : str | None = None , pr_url : str | None = None ) -> bool :
48- """Update session metadata (session_id and/or pr_url) immediately when available."""
4947 try :
5048 payload = {}
5149 if session_id is not None :
@@ -69,19 +67,15 @@ async def update_session_metadata(editing_id: str, session_id: str | None = None
6967 return False
7068
7169
72- async def check_if_interrupted ( editing_id : str ) -> bool :
70+ def check_if_interrupted_via_sqs ( sqs_client : SQSClient ) -> tuple [ bool , str | None ] :
7371 try :
74- async with httpx .AsyncClient (timeout = 5.0 ) as client :
75- response = await client .get (f"{ SETTINGS .FAI_API_URL } /editing-sessions/{ editing_id } " )
76- if response .status_code == 200 :
77- session_data = response .json ()
78- return session_data ["editing_session" ]["status" ] == "interrupted"
79- else :
80- LOGGER .warning (f"Failed to check interrupted status: { response .status_code } " )
81- return False
72+ is_interrupted , receipt_handle = sqs_client .has_interrupt_message ()
73+ if is_interrupted :
74+ LOGGER .info ("Received INTERRUPT message from SQS queue" )
75+ return is_interrupted , receipt_handle
8276 except Exception as e :
83- LOGGER .warning (f"Error checking interrupted status : { e } " )
84- return False
77+ LOGGER .warning (f"Error checking SQS for interruption : { e } " )
78+ return False , None
8579
8680
8781async def run_editing_session (
@@ -90,6 +84,7 @@ async def run_editing_session(
9084 base_branch : str ,
9185 working_branch : str ,
9286 editing_id : str ,
87+ queue_url : str ,
9388 resume_session_id : str | None = None ,
9489 existing_pr_url : str | None = None ,
9590) -> tuple [str , str | None ]:
@@ -134,6 +129,9 @@ async def run_editing_session(
134129 session_id = resume_session_id
135130 pr_url = existing_pr_url
136131
132+ sqs_client = SQSClient (queue_url )
133+ LOGGER .info (f"Initialized SQS client for queue: { queue_url } " )
134+
137135 async with ClaudeSDKClient (options = options ) as client :
138136 await client .query (full_prompt )
139137
@@ -152,11 +150,14 @@ async def run_editing_session(
152150 await update_session_status (editing_id , "active" )
153151 LOGGER .info (f"Transitioned new session { editing_id } from STARTUP to ACTIVE" )
154152
155- if await check_if_interrupted (editing_id ):
156- LOGGER .warning (f"Session interrupted: { editing_id } " )
157- if session_id is not None :
153+ if session_id is not None :
154+ is_interrupted , receipt_handle = check_if_interrupted_via_sqs (sqs_client )
155+ if is_interrupted :
156+ LOGGER .warning (f"Session interrupted via SQS: { editing_id } " )
157+ if receipt_handle :
158+ sqs_client .delete_message (receipt_handle )
158159 await update_session_status (editing_id , "waiting" )
159- raise SessionInterruptedError (f"Editing session { editing_id } was interrupted" )
160+ raise SessionInterruptedError (f"Editing session { editing_id } was interrupted" )
160161
161162 if isinstance (message , AssistantMessage ):
162163 for block in message .content :
@@ -185,6 +186,49 @@ async def run_editing_session(
185186 if session_id is None :
186187 raise RuntimeError ("Failed to obtain session ID from Claude" )
187188
189+ LOGGER .info (f"Checking for RESUME messages in queue for session { editing_id } " )
190+ resume_messages = sqs_client .get_resume_messages ()
191+
192+ if resume_messages :
193+ LOGGER .info (f"Found { len (resume_messages )} RESUME messages, batching prompts" )
194+
195+ all_prompts = []
196+ for resume_msg in resume_messages :
197+ body = resume_msg ["body" ]
198+ prompts = body .get ("prompts" , [])
199+ if prompts :
200+ all_prompts .extend (prompts )
201+ sqs_client .delete_message (resume_msg ["receipt_handle" ])
202+
203+ if all_prompts :
204+ batched_prompt = "\n \n " .join ([f"Request { i + 1 } : { p } " for i , p in enumerate (all_prompts )])
205+ LOGGER .info (f"Processing { len (all_prompts )} batched prompts: { batched_prompt [:200 ]} ..." )
206+
207+ async with ClaudeSDKClient (options = options ) as client :
208+ await client .query (batched_prompt )
209+
210+ async for msg in client .receive_response ():
211+ is_interrupted , receipt_handle = check_if_interrupted_via_sqs (sqs_client )
212+ if is_interrupted :
213+ LOGGER .warning (f"Session interrupted while processing batched prompts: { editing_id } " )
214+ if receipt_handle :
215+ sqs_client .delete_message (receipt_handle )
216+ await update_session_status (editing_id , "waiting" )
217+ raise SessionInterruptedError (f"Editing session { editing_id } was interrupted" )
218+
219+ if isinstance (msg , AssistantMessage ):
220+ for block in msg .content :
221+ if isinstance (block , TextBlock ):
222+ for line in block .text .split ("\n " ):
223+ if "PR_URL:" in line :
224+ extracted_text = line .split ("PR_URL:" , 1 )[1 ].strip ()
225+ if extracted_text :
226+ match = GITHUB_PR_URL_PATTERN .search (extracted_text )
227+ if match :
228+ pr_url = match .group (0 )
229+ LOGGER .info (f"Updated PR URL: { pr_url } " )
230+ await update_session_metadata (editing_id , pr_url = pr_url )
231+
188232 await update_session_status (editing_id , "waiting" )
189233
190234 return session_id , pr_url
0 commit comments