11import json
2+ import logging
3+ from typing import Any , Callable
24
35import requests
46from pydantic import BaseModel
7+ from requests import Response
58
69from nrlf .core .constants import Categories , PointerTypes
710from nrlf .core .model import ConnectionMetadata
811
12+ logger = logging .getLogger (__name__ )
13+
914
1015class ClientConfig (BaseModel ):
1116 base_url : str
@@ -34,6 +39,33 @@ def add_pointer_type(self, pointer_type: PointerTypes):
3439 return self
3540
3641
42+ def retry_if (status_codes : list [int ]) -> Callable [..., Any ]:
43+ """
44+ Decorator to retry a function call if it returns certain errors
45+ """
46+
47+ def wrapped_func (func : Callable [..., Response ]) -> Callable [..., Response ]:
48+ def wrapper (* args : Any , ** kwargs : Any ) -> Any :
49+ attempt_responses : list [Response ] = []
50+ for attempt in range (3 ):
51+ response = func (* args , ** kwargs )
52+ if not response .status_code or response .status_code not in status_codes :
53+ return response
54+ attempt_responses .append (response )
55+ logger .warning (
56+ f"Attempt { attempt + 1 } failed with status code { response .status_code } "
57+ )
58+
59+ logger .error (f"All attempts failed with responses: { attempt_responses } " )
60+ raise RuntimeError (
61+ f"Function failed after retries with responses: { attempt_responses } "
62+ )
63+
64+ return wrapper
65+
66+ return wrapped_func
67+
68+
3769class ConsumerTestClient :
3870
3971 def __init__ (self , config : ClientConfig ):
@@ -60,29 +92,32 @@ def __init__(self, config: ClientConfig):
6092
6193 self .request_headers .update (self .config .custom_headers )
6294
63- def read (self , doc_ref_id : str ):
95+ @retry_if ([502 ])
96+ def read (self , doc_ref_id : str ) -> Response :
6497 return requests .get (
6598 f"{ self .api_url } /DocumentReference/{ doc_ref_id } " ,
6699 headers = self .request_headers ,
67100 cert = self .config .client_cert ,
68101 )
69102
70- def count (self , params : dict [str , str ]):
103+ @retry_if ([502 ])
104+ def count (self , params : dict [str , str ]) -> Response :
71105 return requests .get (
72106 f"{ self .api_url } /DocumentReference/_count" ,
73107 params = params ,
74108 headers = self .request_headers ,
75109 cert = self .config .client_cert ,
76110 )
77111
112+ @retry_if ([502 ])
78113 def search (
79114 self ,
80115 nhs_number : str | None = None ,
81116 custodian : str | None = None ,
82117 pointer_type : PointerTypes | None = None ,
83118 category : Categories | None = None ,
84119 extra_params : dict [str , str ] | None = None ,
85- ):
120+ ) -> Response :
86121 params = {** (extra_params or {})}
87122
88123 if nhs_number :
@@ -114,14 +149,15 @@ def search(
114149 cert = self .config .client_cert ,
115150 )
116151
152+ @retry_if ([502 ])
117153 def search_post (
118154 self ,
119155 nhs_number : str | None = None ,
120156 custodian : str | None = None ,
121157 pointer_type : PointerTypes | None = None ,
122158 category : Categories | None = None ,
123159 extra_fields : dict [str , str ] | None = None ,
124- ):
160+ ) -> Response :
125161 body = {** (extra_fields or {})}
126162
127163 if nhs_number :
@@ -156,7 +192,8 @@ def search_post(
156192 cert = self .config .client_cert ,
157193 )
158194
159- def read_capability_statement (self ):
195+ @retry_if ([502 ])
196+ def read_capability_statement (self ) -> Response :
160197 return requests .get (
161198 f"{ self .api_url } /metadata" ,
162199 headers = self .request_headers ,
@@ -189,74 +226,83 @@ def __init__(self, config: ClientConfig):
189226
190227 self .request_headers .update (self .config .custom_headers )
191228
192- def create (self , doc_ref ):
229+ @retry_if ([502 ])
230+ def create (self , doc_ref ) -> Response :
193231 return requests .post (
194232 f"{ self .api_url } /DocumentReference" ,
195233 json = doc_ref ,
196234 headers = self .request_headers ,
197235 cert = self .config .client_cert ,
198236 )
199237
200- def create_text (self , doc_ref ):
238+ @retry_if ([502 ])
239+ def create_text (self , doc_ref ) -> Response :
201240 return requests .post (
202241 f"{ self .api_url } /DocumentReference" ,
203242 data = doc_ref ,
204243 headers = self .request_headers ,
205244 cert = self .config .client_cert ,
206245 )
207246
208- def upsert (self , doc_ref ):
247+ @retry_if ([502 ])
248+ def upsert (self , doc_ref ) -> Response :
209249 return requests .put (
210250 f"{ self .api_url } /DocumentReference" ,
211251 json = doc_ref ,
212252 headers = self .request_headers ,
213253 cert = self .config .client_cert ,
214254 )
215255
216- def upsert_text (self , doc_ref ):
256+ @retry_if ([502 ])
257+ def upsert_text (self , doc_ref ) -> Response :
217258 return requests .put (
218259 f"{ self .api_url } /DocumentReference" ,
219260 data = doc_ref ,
220261 headers = self .request_headers ,
221262 cert = self .config .client_cert ,
222263 )
223264
224- def update (self , doc_ref , doc_ref_id : str ):
265+ @retry_if ([502 ])
266+ def update (self , doc_ref , doc_ref_id : str ) -> Response :
225267 return requests .put (
226268 f"{ self .api_url } /DocumentReference/{ doc_ref_id } " ,
227269 json = doc_ref ,
228270 headers = self .request_headers ,
229271 cert = self .config .client_cert ,
230272 )
231273
232- def update_text (self , doc_ref , doc_ref_id : str ):
274+ @retry_if ([502 ])
275+ def update_text (self , doc_ref , doc_ref_id : str ) -> Response :
233276 return requests .put (
234277 f"{ self .api_url } /DocumentReference/{ doc_ref_id } " ,
235278 data = doc_ref ,
236279 headers = self .request_headers ,
237280 cert = self .config .client_cert ,
238281 )
239282
240- def delete (self , doc_ref_id : str ):
283+ @retry_if ([502 ])
284+ def delete (self , doc_ref_id : str ) -> Response :
241285 return requests .delete (
242286 f"{ self .api_url } /DocumentReference/{ doc_ref_id } " ,
243287 headers = self .request_headers ,
244288 cert = self .config .client_cert ,
245289 )
246290
247- def read (self , doc_ref_id : str ):
291+ @retry_if ([502 ])
292+ def read (self , doc_ref_id : str ) -> Response :
248293 return requests .get (
249294 f"{ self .api_url } /DocumentReference/{ doc_ref_id } " ,
250295 headers = self .request_headers ,
251296 cert = self .config .client_cert ,
252297 )
253298
299+ @retry_if ([502 ])
254300 def search (
255301 self ,
256302 nhs_number : str | None = None ,
257303 pointer_type : PointerTypes | None = None ,
258304 extra_params : dict [str , str ] | None = None ,
259- ):
305+ ) -> Response :
260306 params = {** (extra_params or {})}
261307
262308 if nhs_number :
@@ -277,12 +323,13 @@ def search(
277323 cert = self .config .client_cert ,
278324 )
279325
326+ @retry_if ([502 ])
280327 def search_post (
281328 self ,
282329 nhs_number : str | None = None ,
283330 pointer_type : PointerTypes | None = None ,
284331 extra_fields : dict [str , str ] | None = None ,
285- ):
332+ ) -> Response :
286333 body = {** (extra_fields or {})}
287334
288335 if nhs_number :
@@ -306,7 +353,8 @@ def search_post(
306353 cert = self .config .client_cert ,
307354 )
308355
309- def read_capability_statement (self ):
356+ @retry_if ([502 ])
357+ def read_capability_statement (self ) -> Response :
310358 return requests .get (
311359 f"{ self .api_url } /metadata" ,
312360 headers = self .request_headers ,
0 commit comments