11import json
2+ from typing import Any , Callable
23
34import requests
45from pydantic import BaseModel
6+ from requests import Response
57
68from nrlf .core .constants import Categories , PointerTypes
79from nrlf .core .model import ConnectionMetadata
@@ -34,6 +36,35 @@ def add_pointer_type(self, pointer_type: PointerTypes):
3436 return self
3537
3638
39+ def retry_if (status_codes : list [int ]) -> Callable [..., Any ]:
40+ """
41+ Decorator to retry a function call if it returns certain errors
42+ """
43+
44+ def wrapped_func (func : Callable [..., Response ]) -> Callable [..., Response ]:
45+ def wrapper (* args : Any , ** kwargs : Any ) -> Any :
46+ attempt_responses : list [Response ] = []
47+ for attempt in range (2 ):
48+ response = func (* args , ** kwargs )
49+ if not response .status_code or response .status_code not in status_codes :
50+ return response
51+ attempt_responses .append (response )
52+ print ( # noqa: T201
53+ f"Retrying due to { response .status_code } error in attempt { attempt + 1 } ..."
54+ )
55+
56+ print ( # noqa: T201
57+ f"All attempts failed with responses: { attempt_responses } "
58+ )
59+ raise Exception (
60+ f"Function failed after retries with responses: { attempt_responses } "
61+ )
62+
63+ return wrapper
64+
65+ return wrapped_func
66+
67+
3768class ConsumerTestClient :
3869
3970 def __init__ (self , config : ClientConfig ):
@@ -60,29 +91,32 @@ def __init__(self, config: ClientConfig):
6091
6192 self .request_headers .update (self .config .custom_headers )
6293
63- def read (self , doc_ref_id : str ):
94+ @retry_if ([502 ])
95+ def read (self , doc_ref_id : str ) -> Response :
6496 return requests .get (
6597 f"{ self .api_url } /DocumentReference/{ doc_ref_id } " ,
6698 headers = self .request_headers ,
6799 cert = self .config .client_cert ,
68100 )
69101
70- def count (self , params : dict [str , str ]):
102+ @retry_if ([502 ])
103+ def count (self , params : dict [str , str ]) -> Response :
71104 return requests .get (
72105 f"{ self .api_url } /DocumentReference/_count" ,
73106 params = params ,
74107 headers = self .request_headers ,
75108 cert = self .config .client_cert ,
76109 )
77110
111+ @retry_if ([502 ])
78112 def search (
79113 self ,
80114 nhs_number : str | None = None ,
81115 custodian : str | None = None ,
82116 pointer_type : PointerTypes | None = None ,
83117 category : Categories | None = None ,
84118 extra_params : dict [str , str ] | None = None ,
85- ):
119+ ) -> Response :
86120 params = {** (extra_params or {})}
87121
88122 if nhs_number :
@@ -114,14 +148,15 @@ def search(
114148 cert = self .config .client_cert ,
115149 )
116150
151+ @retry_if ([502 ])
117152 def search_post (
118153 self ,
119154 nhs_number : str | None = None ,
120155 custodian : str | None = None ,
121156 pointer_type : PointerTypes | None = None ,
122157 category : Categories | None = None ,
123158 extra_fields : dict [str , str ] | None = None ,
124- ):
159+ ) -> Response :
125160 body = {** (extra_fields or {})}
126161
127162 if nhs_number :
@@ -156,7 +191,8 @@ def search_post(
156191 cert = self .config .client_cert ,
157192 )
158193
159- def read_capability_statement (self ):
194+ @retry_if ([502 ])
195+ def read_capability_statement (self ) -> Response :
160196 return requests .get (
161197 f"{ self .api_url } /metadata" ,
162198 headers = self .request_headers ,
@@ -189,74 +225,83 @@ def __init__(self, config: ClientConfig):
189225
190226 self .request_headers .update (self .config .custom_headers )
191227
192- def create (self , doc_ref ):
228+ @retry_if ([502 ])
229+ def create (self , doc_ref ) -> Response :
193230 return requests .post (
194231 f"{ self .api_url } /DocumentReference" ,
195232 json = doc_ref ,
196233 headers = self .request_headers ,
197234 cert = self .config .client_cert ,
198235 )
199236
200- def create_text (self , doc_ref ):
237+ @retry_if ([502 ])
238+ def create_text (self , doc_ref ) -> Response :
201239 return requests .post (
202240 f"{ self .api_url } /DocumentReference" ,
203241 data = doc_ref ,
204242 headers = self .request_headers ,
205243 cert = self .config .client_cert ,
206244 )
207245
208- def upsert (self , doc_ref ):
246+ @retry_if ([502 ])
247+ def upsert (self , doc_ref ) -> Response :
209248 return requests .put (
210249 f"{ self .api_url } /DocumentReference" ,
211250 json = doc_ref ,
212251 headers = self .request_headers ,
213252 cert = self .config .client_cert ,
214253 )
215254
216- def upsert_text (self , doc_ref ):
255+ @retry_if ([502 ])
256+ def upsert_text (self , doc_ref ) -> Response :
217257 return requests .put (
218258 f"{ self .api_url } /DocumentReference" ,
219259 data = doc_ref ,
220260 headers = self .request_headers ,
221261 cert = self .config .client_cert ,
222262 )
223263
224- def update (self , doc_ref , doc_ref_id : str ):
264+ @retry_if ([502 ])
265+ def update (self , doc_ref , doc_ref_id : str ) -> Response :
225266 return requests .put (
226267 f"{ self .api_url } /DocumentReference/{ doc_ref_id } " ,
227268 json = doc_ref ,
228269 headers = self .request_headers ,
229270 cert = self .config .client_cert ,
230271 )
231272
232- def update_text (self , doc_ref , doc_ref_id : str ):
273+ @retry_if ([502 ])
274+ def update_text (self , doc_ref , doc_ref_id : str ) -> Response :
233275 return requests .put (
234276 f"{ self .api_url } /DocumentReference/{ doc_ref_id } " ,
235277 data = doc_ref ,
236278 headers = self .request_headers ,
237279 cert = self .config .client_cert ,
238280 )
239281
240- def delete (self , doc_ref_id : str ):
282+ @retry_if ([502 ])
283+ def delete (self , doc_ref_id : str ) -> Response :
241284 return requests .delete (
242285 f"{ self .api_url } /DocumentReference/{ doc_ref_id } " ,
243286 headers = self .request_headers ,
244287 cert = self .config .client_cert ,
245288 )
246289
247- def read (self , doc_ref_id : str ):
290+ @retry_if ([502 ])
291+ def read (self , doc_ref_id : str ) -> Response :
248292 return requests .get (
249293 f"{ self .api_url } /DocumentReference/{ doc_ref_id } " ,
250294 headers = self .request_headers ,
251295 cert = self .config .client_cert ,
252296 )
253297
298+ @retry_if ([502 ])
254299 def search (
255300 self ,
256301 nhs_number : str | None = None ,
257302 pointer_type : PointerTypes | None = None ,
258303 extra_params : dict [str , str ] | None = None ,
259- ):
304+ ) -> Response :
260305 params = {** (extra_params or {})}
261306
262307 if nhs_number :
@@ -277,12 +322,13 @@ def search(
277322 cert = self .config .client_cert ,
278323 )
279324
325+ @retry_if ([502 ])
280326 def search_post (
281327 self ,
282328 nhs_number : str | None = None ,
283329 pointer_type : PointerTypes | None = None ,
284330 extra_fields : dict [str , str ] | None = None ,
285- ):
331+ ) -> Response :
286332 body = {** (extra_fields or {})}
287333
288334 if nhs_number :
@@ -306,7 +352,8 @@ def search_post(
306352 cert = self .config .client_cert ,
307353 )
308354
309- def read_capability_statement (self ):
355+ @retry_if ([502 ])
356+ def read_capability_statement (self ) -> Response :
310357 return requests .get (
311358 f"{ self .api_url } /metadata" ,
312359 headers = self .request_headers ,
0 commit comments