@@ -36,36 +36,50 @@ def __init__(
3636
3737 @log_elapsed (operation = "ANSWER_PROMPTS" )
3838 def answer_prompt (
39- self , payload : dict [str , Any ], params : Optional [dict [str , str ]] = None
39+ self ,
40+ payload : dict [str , Any ],
41+ params : Optional [dict [str , str ]] = None ,
42+ headers : Optional [dict [str , str ]] = None ,
4043 ) -> dict [str , Any ]:
4144 url_path = "answer-prompt"
4245 if self .is_public_call :
4346 url_path = "answer-prompt-public"
4447 return self ._post_call (
45- url_path = url_path ,
46- payload = payload ,
47- params = params ,
48+ url_path = url_path , payload = payload , params = params , headers = headers
4849 )
4950
5051 def single_pass_extraction (
51- self , payload : dict [str , Any ], params : Optional [dict [str , str ]] = None
52+ self ,
53+ payload : dict [str , Any ],
54+ params : Optional [dict [str , str ]] = None ,
55+ headers : Optional [dict [str , str ]] = None ,
5256 ) -> dict [str , Any ]:
5357 return self ._post_call (
5458 url_path = "single-pass-extraction" ,
5559 payload = payload ,
5660 params = params ,
61+ headers = headers ,
5762 )
5863
5964 def summarize (
60- self , payload : dict [str , Any ], params : Optional [dict [str , str ]] = None
65+ self ,
66+ payload : dict [str , Any ],
67+ params : Optional [dict [str , str ]] = None ,
68+ headers : Optional [dict [str , str ]] = None ,
6169 ) -> dict [str , Any ]:
62- return self ._post_call (url_path = "summarize" , payload = payload , params = params )
70+ return self ._post_call (
71+ url_path = "summarize" ,
72+ payload = payload ,
73+ params = params ,
74+ headers = headers ,
75+ )
6376
6477 def _post_call (
6578 self ,
6679 url_path : str ,
6780 payload : dict [str , Any ],
6881 params : Optional [dict [str , str ]] = None ,
82+ headers : Optional [dict [str , str ]] = None ,
6983 ) -> dict [str , Any ]:
7084 """Invokes and communicates to prompt service to fetch response for the
7185 prompt.
@@ -74,6 +88,7 @@ def _post_call(
7488 url_path (str): URL path to the service endpoint
7589 payload (dict): Payload to send in the request body
7690 params (dict, optional): Query parameters to include in the request
91+ headers (dict, optional): Headers to include in the request
7792
7893 Returns:
7994 dict: Response from the prompt service
@@ -94,13 +109,19 @@ def _post_call(
94109 "status_code" : 500 ,
95110 }
96111 url : str = f"{ self .base_url } /{ url_path } "
97- headers : dict [str , str ] = {}
112+
113+ default_headers = {}
114+
98115 if not self .is_public_call :
99- headers = {"Authorization" : f"Bearer { self .bearer_token } " }
116+ default_headers = {"Authorization" : f"Bearer { self .bearer_token } " }
117+
118+ if headers :
119+ default_headers .update (headers )
120+
100121 response : Response = Response ()
101122 try :
102123 response = requests .post (
103- url = url , json = payload , params = params , headers = headers
124+ url = url , json = payload , params = params , headers = default_headers
104125 )
105126 response .raise_for_status ()
106127 result ["status" ] = "OK"
0 commit comments