Skip to content

Commit 5ca40ee

Browse files
committed
feat: Allow passing additional http headers in API methods, fix many docstrings, and update unit tests
1 parent 07e9a1a commit 5ca40ee

File tree

8 files changed

+287
-37
lines changed

8 files changed

+287
-37
lines changed

asknews_sdk/api/analytics.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import Literal, Optional, Union
2+
from typing import Dict, Literal, Optional, Union
33

44
from asknews_sdk.api.base import BaseAPI
55
from asknews_sdk.dto.sentiment import FinanceResponse
@@ -44,6 +44,8 @@ def get_asset_sentiment(
4444
] = "news_positive",
4545
date_from: Optional[Union[datetime, str]] = None,
4646
date_to: Optional[Union[datetime, str]] = None,
47+
*,
48+
http_headers: Optional[Dict] = None,
4749
) -> FinanceResponse:
4850
"""
4951
Get the timeseries sentiment for an asset.
@@ -58,6 +60,8 @@ def get_asset_sentiment(
5860
:type date_from: Optional[Union[str, datetime]]
5961
:param date_to: The end date in ISO format.
6062
:type date_to: Optional[Union[str, datetime]]
63+
:param http_headers: Additional HTTP headers.
64+
:type http_headers: Optional[Dict]
6165
:return: The sentiment response.
6266
:rtype: FinanceResponse
6367
"""
@@ -75,6 +79,7 @@ def get_asset_sentiment(
7579
"date_from": date_from,
7680
"date_to": date_to,
7781
},
82+
headers=http_headers,
7883
accept=[(FinanceResponse.__content_type__, 1.0)],
7984
)
8085
return FinanceResponse.model_validate(response.content)
@@ -119,6 +124,8 @@ async def get_asset_sentiment(
119124
] = "news_positive",
120125
date_from: Optional[Union[datetime, str]] = None,
121126
date_to: Optional[Union[datetime, str]] = None,
127+
*,
128+
http_headers: Optional[Dict] = None,
122129
) -> FinanceResponse:
123130
"""
124131
Get the timeseries sentiment for an asset.
@@ -133,6 +140,8 @@ async def get_asset_sentiment(
133140
:type date_from: Optional[Union[str, datetime]]
134141
:param date_to: The end date in ISO format.
135142
:type date_to: Optional[Union[str, datetime]]
143+
:param http_headers: Additional HTTP headers.
144+
:type http_headers: Optional[Dict]
136145
:return: The sentiment response.
137146
:rtype: FinanceResponse
138147
"""
@@ -150,6 +159,7 @@ async def get_asset_sentiment(
150159
"date_from": date_from,
151160
"date_to": date_to,
152161
},
162+
headers=http_headers,
153163
accept=[(FinanceResponse.__content_type__, 1.0)],
154164
)
155165
return FinanceResponse.model_validate(response.content)

asknews_sdk/api/chat.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def get_chat_completions(
3232
append_references: bool = True,
3333
asknews_watermark: bool = True,
3434
journalist_mode: bool = True,
35+
*,
36+
http_headers: Optional[Dict] = None,
3537
) -> Union[CreateChatCompletionResponse, Iterator[CreateChatCompletionResponseStream]]:
3638
"""
3739
Get chat completions for a given user message.
@@ -53,6 +55,10 @@ def get_chat_completions(
5355
:type append_references: bool
5456
:param asknews_watermark: Whether to add AskNews watermark, defaults to True
5557
:type asknews_watermark: bool
58+
:param journalist_mode: Whether to enable journalist mode, defaults to True
59+
:type journalist_mode: bool
60+
:param http_headers: Additional HTTP headers.
61+
:type http_headers: Optional[Dict]
5662
:return: Chat completions
5763
:rtype: Union[
5864
CreateChatCompletionResponse, Iterator[CreateChatCompletionResponseStream]
@@ -71,6 +77,7 @@ def get_chat_completions(
7177
journalist_mode=journalist_mode,
7278
).model_dump(mode="json"),
7379
headers={
80+
**(http_headers or {}),
7481
"Content-Type": CreateChatCompletionRequest.__content_type__,
7582
},
7683
accept=[
@@ -91,24 +98,30 @@ def _stream():
9198
else:
9299
return CreateChatCompletionResponse.model_validate(response.content)
93100

94-
def list_chat_models(self) -> ListModelResponse:
101+
def list_chat_models(self, *, http_headers: Optional[Dict] = None) -> ListModelResponse:
95102
"""
96103
List available chat models.
97104
98105
https://docs.asknews.app/en/reference#get-/v1/openai/models
99106
107+
:param http_headers: Additional HTTP headers.
108+
:type http_headers: Optional[Dict]
100109
:return: List of available chat models
101110
:rtype: ListModelResponse
102111
"""
103112
response = self.client.request(
104113
method="GET",
105114
endpoint="/v1/openai/models",
115+
headers=http_headers,
106116
accept=[(ListModelResponse.__content_type__, 1.0)],
107117
)
108118
return ListModelResponse.model_validate(response.content)
109119

110120
def get_headline_questions(
111-
self, queries: Optional[List[str]] = None
121+
self,
122+
queries: Optional[List[str]] = None,
123+
*,
124+
http_headers: Optional[Dict] = None,
112125
) -> HeadlineQuestionsResponse:
113126
"""
114127
Get headline questions for a given query.
@@ -117,12 +130,15 @@ def get_headline_questions(
117130
118131
:param queries: List of queries to get headline questions for
119132
:type queries: Optional[List[str]]
133+
:param http_headers: Additional HTTP headers.
134+
:type http_headers: Optional[Dict]
120135
:return: Headline questions
121136
:rtype: HeadlineQuestionsResponse
122137
"""
123138
response = self.client.request(
124139
method="GET",
125140
endpoint="/v1/chat/questions",
141+
headers=http_headers,
126142
query={"queries": queries}
127143
)
128144
return HeadlineQuestionsResponse.model_validate(response.content)
@@ -149,6 +165,8 @@ async def get_chat_completions(
149165
append_references: bool = True,
150166
asknews_watermark: bool = True,
151167
journalist_mode: bool = True,
168+
*,
169+
http_headers: Optional[Dict] = None,
152170
) -> Union[CreateChatCompletionResponse, AsyncIterator[CreateChatCompletionResponseStream]]:
153171
"""
154172
Get chat completions for a given user message.
@@ -170,6 +188,10 @@ async def get_chat_completions(
170188
:type append_references: bool
171189
:param asknews_watermark: Whether to add AskNews watermark, defaults to True
172190
:type asknews_watermark: bool
191+
:param journalist_mode: Whether to enable journalist mode, defaults to True
192+
:type journalist_mode: bool
193+
:param http_headers: Additional HTTP headers.
194+
:type http_headers: Optional[Dict]
173195
:return: Chat completions
174196
:rtype: Union[
175197
CreateChatCompletionResponse,
@@ -190,6 +212,7 @@ async def get_chat_completions(
190212
).model_dump(mode="json"),
191213
headers={
192214
"Content-Type": CreateChatCompletionRequest.__content_type__,
215+
**(http_headers or {}),
193216
},
194217
accept=[
195218
(CreateChatCompletionResponse.__content_type__, 1.0),
@@ -209,24 +232,30 @@ async def _stream():
209232
else:
210233
return CreateChatCompletionResponse.model_validate(response.content)
211234

212-
async def list_chat_models(self) -> ListModelResponse:
235+
async def list_chat_models(self, *, http_headers: Optional[Dict] = None) -> ListModelResponse:
213236
"""
214237
List available chat models.
215238
216239
https://docs.asknews.app/en/reference#get-/v1/openai/models
217240
241+
:param http_headers: Additional HTTP headers.
242+
:type http_headers: Optional[Dict]
218243
:return: List of available chat models
219244
:rtype: ListModelResponse
220245
"""
221246
response = await self.client.request(
222247
method="GET",
223248
endpoint="/v1/openai/models",
249+
headers=http_headers,
224250
accept=[(ListModelResponse.__content_type__, 1.0)],
225251
)
226252
return ListModelResponse.model_validate(response.content)
227253

228254
async def get_headline_questions(
229-
self, queries: Optional[List[str]] = None
255+
self,
256+
queries: Optional[List[str]] = None,
257+
*,
258+
http_headers: Optional[Dict] = None,
230259
) -> HeadlineQuestionsResponse:
231260
"""
232261
Get headline questions for a given query.
@@ -235,12 +264,15 @@ async def get_headline_questions(
235264
236265
:param queries: List of queries to get headline questions for
237266
:type queries: Optional[List[str]]
267+
:param http_headers: Additional HTTP headers.
268+
:type http_headers: Optional[Dict]
238269
:return: Headline questions
239270
:rtype: HeadlineQuestionsResponse
240271
"""
241272
response = await self.client.request(
242273
method="GET",
243274
endpoint="/v1/chat/questions",
275+
headers=http_headers,
244276
query={"queries": queries}
245277
)
246278
return HeadlineQuestionsResponse.model_validate(response.content)

0 commit comments

Comments
 (0)