@@ -18,10 +18,12 @@ def __init__(
1818 self ,
1919 id : Optional [str ] = "gpt-4o" ,
2020 base_url : str = "https://api.openai.com/v1/chat/completions" ,
21+ extra_headers : Optional [Dict [str , str ]] = None ,
2122 ):
2223 self .model_id = id
2324 self .base_url = base_url
2425 self .api_token = tu .ENV .OPENAI_TOKEN ("" )
26+ self .extra_headers = extra_headers
2527
2628 def set_api_token (self , token : str ) -> None :
2729 self .api_token = token
@@ -131,6 +133,7 @@ def stream_chat(
131133 raw : bool = False ,
132134 ):
133135 headers , messages = self ._process_input (chats , token )
136+ extra_headers = extra_headers or self .extra_headers
134137 if extra_headers :
135138 headers .update (extra_headers )
136139 data = {
@@ -187,8 +190,6 @@ def stream_chat(
187190 yield fn_call
188191 return
189192
190- # def _process_chat_to_string_for_embedding(self, chat: tt.Thread):
191-
192193 def embedding (
193194 self ,
194195 chats : tt .Thread | List [str ] | str ,
@@ -203,6 +204,7 @@ def embedding(
203204 text = []
204205
205206 headers = self ._process_header (token )
207+ extra_headers = extra_headers or self .extra_headers
206208 if extra_headers :
207209 headers .update (extra_headers )
208210 if isinstance (chats , tt .Thread ):
0 commit comments