55
66import httpx
77
8- from cogstack_model_gateway_client .exceptions import retry_if_network_error
8+ from cogstack_model_gateway_client .exceptions import TaskFailedError , retry_if_network_error
99
1010
1111class GatewayClient :
@@ -103,7 +103,11 @@ async def submit_task(
103103 wait_for_completion : bool = False ,
104104 return_result : bool = True ,
105105 ):
106- """Submit a task to the Gateway and return the task info."""
106+ """Submit a task to the Gateway and return the task info.
107+
108+ Raises:
109+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
110+ """
107111 model_name = model_name or self .default_model
108112 if not model_name :
109113 raise ValueError ("Please provide a model name or set a default model for the client." )
@@ -118,7 +122,11 @@ async def submit_task(
118122 task_uuid = task_info ["uuid" ]
119123 task_info = await self .wait_for_task (task_uuid )
120124 if return_result :
121- return await self .get_task_result (task_uuid )
125+ if task_info .get ("status" ) == "succeeded" :
126+ return await self .get_task_result (task_uuid )
127+ else :
128+ error_message = task_info .get ("error_message" , "Unknown error" )
129+ raise TaskFailedError (task_uuid , error_message , task_info )
122130 return task_info
123131
124132 async def process (
@@ -128,7 +136,11 @@ async def process(
128136 wait_for_completion : bool = True ,
129137 return_result : bool = True ,
130138 ):
131- """Generate annotations for the provided text."""
139+ """Generate annotations for the provided text.
140+
141+ Raises:
142+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
143+ """
132144 return await self .submit_task (
133145 model_name = model_name ,
134146 task = "process" ,
@@ -145,7 +157,11 @@ async def process_bulk(
145157 wait_for_completion : bool = True ,
146158 return_result : bool = True ,
147159 ):
148- """Generate annotations for a list of texts."""
160+ """Generate annotations for a list of texts.
161+
162+ Raises:
163+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
164+ """
149165 return await self .submit_task (
150166 model_name = model_name ,
151167 task = "process_bulk" ,
@@ -166,7 +182,11 @@ async def redact(
166182 wait_for_completion : bool = True ,
167183 return_result : bool = True ,
168184 ):
169- """Redact sensitive information from the provided text."""
185+ """Redact sensitive information from the provided text.
186+
187+ Raises:
188+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
189+ """
170190 params = {
171191 k : v
172192 for k , v in {
@@ -238,15 +258,20 @@ async def get_task_result(self, task_uuid: str, parse: bool = True):
238258 async def wait_for_task (
239259 self , task_uuid : str , detail : bool = True , raise_on_error : bool = False
240260 ):
241- """Poll Gateway until the task reaches a final state."""
261+ """Poll Gateway until the task reaches a final state.
262+
263+ Raises:
264+ TaskFailedError: If raise_on_error=True and the task fails.
265+ TimeoutError: If timeout is reached before task completion.
266+ """
242267 start = asyncio .get_event_loop ().time ()
243268 while True :
244269 task = await self .get_task (task_uuid , detail = detail )
245270 status = task .get ("status" )
246271 if status in ("succeeded" , "failed" ):
247272 if status == "failed" and raise_on_error :
248273 error_message = task .get ("error_message" , "Unknown error" )
249- raise RuntimeError ( f"Task ' { task_uuid } ' failed: { error_message } " )
274+ raise TaskFailedError ( task_uuid , error_message , task )
250275 return task
251276 if self .timeout is not None and asyncio .get_event_loop ().time () - start > self .timeout :
252277 raise TimeoutError (f"Timed out waiting for task '{ task_uuid } ' to complete" )
@@ -365,7 +390,11 @@ def submit_task(
365390 wait_for_completion : bool = False ,
366391 return_result : bool = True ,
367392 ):
368- """Submit a task to the Gateway and return the task info."""
393+ """Submit a task to the Gateway and return the task info.
394+
395+ Raises:
396+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
397+ """
369398 return asyncio .run (
370399 self ._client .submit_task (
371400 model_name = model_name ,
@@ -387,7 +416,11 @@ def process(
387416 wait_for_completion : bool = True ,
388417 return_result : bool = True ,
389418 ):
390- """Generate annotations for the provided text."""
419+ """Generate annotations for the provided text.
420+
421+ Raises:
422+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
423+ """
391424 return asyncio .run (
392425 self ._client .process (
393426 text = text ,
@@ -404,7 +437,11 @@ def process_bulk(
404437 wait_for_completion : bool = True ,
405438 return_result : bool = True ,
406439 ):
407- """Generate annotations for a list of texts."""
440+ """Generate annotations for a list of texts.
441+
442+ Raises:
443+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
444+ """
408445 return asyncio .run (
409446 self ._client .process_bulk (
410447 texts = texts ,
@@ -425,7 +462,11 @@ def redact(
425462 wait_for_completion : bool = True ,
426463 return_result : bool = True ,
427464 ):
428- """Redact sensitive information from the provided text."""
465+ """Redact sensitive information from the provided text.
466+
467+ Raises:
468+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
469+ """
429470 return asyncio .run (
430471 self ._client .redact (
431472 text = text ,
@@ -452,7 +493,12 @@ def get_task_result(self, task_uuid: str, parse: bool = True):
452493 return asyncio .run (self ._client .get_task_result (task_uuid = task_uuid , parse = parse ))
453494
454495 def wait_for_task (self , task_uuid : str , detail : bool = True , raise_on_error : bool = False ):
455- """Poll Gateway until the task reaches a final state."""
496+ """Poll Gateway until the task reaches a final state.
497+
498+ Raises:
499+ TaskFailedError: If raise_on_error=True and the task fails.
500+ TimeoutError: If timeout is reached before task completion.
501+ """
456502 return asyncio .run (
457503 self ._client .wait_for_task (
458504 task_uuid = task_uuid , detail = detail , raise_on_error = raise_on_error
0 commit comments