66
77import json
88import requests
9- from typing import Optional , Any , List
109
10+ from typing import Optional , Any , List , Dict
1111
1212import tuneapi .utils as tu
1313import tuneapi .types as tt
@@ -98,6 +98,7 @@ def chat(
9898 max_tokens : int = 1024 ,
9999 temperature : float = 1 ,
100100 token : Optional [str ] = None ,
101+ extra_headers : Optional [Dict [str , str ]] = None ,
101102 ** kwargs ,
102103 ) -> Any :
103104 output = ""
@@ -107,6 +108,8 @@ def chat(
107108 max_tokens = max_tokens ,
108109 temperature = temperature ,
109110 token = token ,
111+ extra_headers = extra_headers ,
112+ raw = False ,
110113 ** kwargs ,
111114 ):
112115 if isinstance (x , dict ):
@@ -123,10 +126,13 @@ def stream_chat(
123126 temperature : float = 1 ,
124127 token : Optional [str ] = None ,
125128 timeout = (5 , 60 ),
126- raw : bool = False ,
129+ extra_headers : Optional [ Dict [ str , str ]] = None ,
127130 debug : bool = False ,
131+ raw : bool = False ,
128132 ):
129133 headers , messages = self ._process_input (chats , token )
134+ if extra_headers :
135+ headers .update (extra_headers )
130136 data = {
131137 "temperature" : temperature ,
132138 "messages" : messages ,
@@ -191,11 +197,14 @@ def embedding(
191197 token : Optional [str ] = None ,
192198 timeout = (5 , 60 ),
193199 raw : bool = False ,
200+ extra_headers : Optional [Dict [str , str ]] = None ,
194201 ):
195202 """If you pass a list then returned items are in the insertion order"""
196203 text = []
197204
198205 headers = self ._process_header (token )
206+ if extra_headers :
207+ headers .update (extra_headers )
199208 if isinstance (chats , tt .Thread ):
200209 _ , messages = self ._process_input (chats , token )
201210 for i , m in enumerate (messages ):
0 commit comments