11import os
22from typing import List , Optional
3- import requests
3+
44import aiohttp
5+ import requests
56from runpod import RunPodLogger
7+ from utils .retry import retry_with_backoff , async_retry_with_backoff
8+
69from integrations .directus_client import get_directus_token
710
811logger = RunPodLogger ()
912
1013
14+ def _make_rag_request (url : str , payload : dict , headers : dict ) -> str :
15+ """
16+ Helper function to make the actual RAG API request.
17+
18+ Args:
19+ url: The API endpoint URL
20+ payload: The request payload
21+ headers: The request headers
22+
23+ Returns:
24+ str: RAG prompt string from the server
25+
26+ Raises:
27+ Exception: If the API call fails
28+ """
29+ logger .debug (f"Making RAG API request to { url } " )
30+ response = requests .post (url , json = payload , headers = headers , timeout = 120 )
31+ response .raise_for_status ()
32+
33+ result = response .text
34+ logger .debug ("Successfully retrieved RAG prompt" )
35+ return result
36+
37+
1138def get_rag_prompt (
1239 query : str , segment_ids : Optional [List [str ]] = None , rag_server_url : Optional [str ] = None
1340) -> str :
1441 """
15- Retrieve RAG prompt by calling the external RAG server API.
42+ Retrieve RAG prompt by calling the external RAG server API with retry logic .
1643
1744 Args:
1845 query: The query string to send to the RAG server
@@ -24,7 +51,7 @@ def get_rag_prompt(
2451
2552 Raises:
2653 ValueError: If RAG_SERVER_URL is not set and no URL is provided
27- Exception: If the API call fails
54+ Exception: If the API call fails after all retries
2855 """
2956 if rag_server_url is None :
3057 rag_server_url = os .getenv ("RAG_SERVER_URL" )
@@ -48,28 +75,58 @@ def get_rag_prompt(
4875
4976 headers = {"Content-Type" : "application/json" }
5077 headers ["Authorization" ] = f"Bearer { get_directus_token ()} "
51- try :
52- logger .debug (f"Making RAG API request to { url } " )
53- response = requests .post (url , json = payload , headers = headers , timeout = 120 )
54- response .raise_for_status ()
55-
56- result = response .text
57- logger .debug ("Successfully retrieved RAG prompt" )
58- return result
5978
60- except requests .exceptions .RequestException as e :
61- logger .error (f"Error calling API: { e } " )
79+ try :
80+ return retry_with_backoff (
81+ _make_rag_request ,
82+ max_retries = 3 ,
83+ initial_delay = 2 ,
84+ backoff_factor = 2 ,
85+ jitter = 0.5 ,
86+ logger = logger ,
87+ url = url ,
88+ payload = payload ,
89+ headers = headers ,
90+ )
91+ except Exception as e :
92+ logger .error (f"Error calling API after all retries: { e } " )
6293 if hasattr (e , "response" ) and e .response is not None :
6394 logger .error (f"Response status: { e .response .status_code } " )
6495 logger .error (f"Response text: { e .response .text } " )
6596 raise Exception (f"Failed to get RAG prompt from server: { str (e )} " ) from e
6697
6798
99+ async def _make_rag_request_async (url : str , payload : dict , headers : dict ) -> str :
100+ """
101+ Helper function to make the actual async RAG API request.
102+
103+ Args:
104+ url: The API endpoint URL
105+ payload: The request payload
106+ headers: The request headers
107+
108+ Returns:
109+ str: RAG prompt string from the server
110+
111+ Raises:
112+ Exception: If the API call fails
113+ """
114+ logger .debug (f"Making async RAG API request to { url } " )
115+ async with aiohttp .ClientSession () as session :
116+ async with session .post (
117+ url , json = payload , headers = headers , timeout = aiohttp .ClientTimeout (total = 120 )
118+ ) as response :
119+ response .raise_for_status ()
120+ result = await response .text ()
121+ logger .debug ("Successfully retrieved RAG prompt" )
122+ return result
123+
124+
68125async def get_rag_prompt_async (
69126 query : str , segment_ids : Optional [List [str ]] = None , rag_server_url : Optional [str ] = None
70127) -> str :
71128 """
72- Async version of get_rag_prompt for parallel processing.
129+ Async version of get_rag_prompt for parallel processing with retry logic .
73130 """
74131 if rag_server_url is None :
75132 rag_server_url = os .getenv ("RAG_SERVER_URL" )
@@ -95,16 +152,17 @@ async def get_rag_prompt_async(
95152 headers ["Authorization" ] = f"Bearer { get_directus_token ()} "
96153
97154 try :
98- logger .debug (f"Making async RAG API request to { url } " )
99- async with aiohttp .ClientSession () as session :
100- async with session .post (
101- url , json = payload , headers = headers , timeout = aiohttp .ClientTimeout (total = 120 )
102- ) as response :
103- response .raise_for_status ()
104- result = await response .text ()
105- logger .debug ("Successfully retrieved RAG prompt" )
106- return result
107-
155+ return await async_retry_with_backoff (
156+ _make_rag_request_async ,
157+ max_retries = 3 ,
158+ initial_delay = 2 ,
159+ backoff_factor = 2 ,
160+ jitter = 0.5 ,
161+ logger = logger ,
162+ url = url ,
163+ payload = payload ,
164+ headers = headers ,
165+ )
108166 except Exception as e :
109- logger .error (f"Error calling API: { e } " )
167+ logger .error (f"Error calling API after all retries : { e } " )
110168 raise Exception (f"Failed to get RAG prompt from server: { str (e )} " ) from e
0 commit comments