
Commit b834f28

Merge pull request #17 from emergentmethods/feat/refactor
Refactor SDK and authentication mechanisms
2 parents 2c57e67 + c6a2656 commit b834f28

30 files changed: +2690 additions, -496 deletions

.gitignore

Lines changed: 4 additions & 1 deletion
@@ -118,4 +118,7 @@ dmypy.json
 *.code-workspace

 # Test scripts
-test*.py
+_test*.py
+
+# Ruff
+.ruff_cache/

Taskfile.yml

Lines changed: 21 additions & 2 deletions
@@ -1,12 +1,31 @@
 version: "3"
+
+silent: true
+
 vars:
   PACKAGE_SRC_DIR: asknews_sdk
+
 tasks:
+  # Unit tests
+  test:
+    cmds:
+      - coverage run -m pytest --junitxml=report.xml
+      - coverage report
+      - coverage xml
+      - coverage html -d coverage-report
+
   # Lint
   lint:
     cmds:
-      - poetry run ruff check {{.PACKAGE_SRC_DIR}}
-      - poetry run ruff format --check {{.PACKAGE_SRC_DIR}}
+      - |
+        if [ -n "{{.SRC_DIR}}" ]; then
+          export SRC_DIR="{{.SRC_DIR}}"
+        else
+          export SRC_DIR="{{.PACKAGE_SRC_DIR}}"
+        fi
+      - poetry run ruff check $SRC_DIR {{.CLI_ARGS}}
+      - poetry run ruff format $SRC_DIR {{.CLI_ARGS}}
+
   # Build
   build:
     cmds:

asknews_sdk/api/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 from asknews_sdk.api.news import AsyncNewsAPI, NewsAPI
 from asknews_sdk.api.stories import AsyncStoriesAPI, StoriesAPI

+
 __all__ = (
     "AnalyticsAPI",
     "AsyncAnalyticsAPI",

asknews_sdk/api/analytics.py

Lines changed: 11 additions & 1 deletion
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Literal, Optional, Union
+from typing import Dict, Literal, Optional, Union

 from asknews_sdk.api.base import BaseAPI
 from asknews_sdk.dto.sentiment import FinanceResponse
@@ -44,6 +44,8 @@ def get_asset_sentiment(
         ] = "news_positive",
         date_from: Optional[Union[datetime, str]] = None,
         date_to: Optional[Union[datetime, str]] = None,
+        *,
+        http_headers: Optional[Dict] = None,
     ) -> FinanceResponse:
         """
         Get the timeseries sentiment for an asset.
@@ -58,6 +60,8 @@ def get_asset_sentiment(
         :type date_from: Optional[Union[str, datetime]]
         :param date_to: The end date in ISO format.
         :type date_to: Optional[Union[str, datetime]]
+        :param http_headers: Additional HTTP headers.
+        :type http_headers: Optional[Dict]
         :return: The sentiment response.
         :rtype: FinanceResponse
         """
@@ -75,6 +79,7 @@ def get_asset_sentiment(
                 "date_from": date_from,
                 "date_to": date_to,
             },
+            headers=http_headers,
             accept=[(FinanceResponse.__content_type__, 1.0)],
         )
         return FinanceResponse.model_validate(response.content)
@@ -119,6 +124,8 @@ async def get_asset_sentiment(
         ] = "news_positive",
         date_from: Optional[Union[datetime, str]] = None,
         date_to: Optional[Union[datetime, str]] = None,
+        *,
+        http_headers: Optional[Dict] = None,
     ) -> FinanceResponse:
         """
         Get the timeseries sentiment for an asset.
@@ -133,6 +140,8 @@ async def get_asset_sentiment(
         :type date_from: Optional[Union[str, datetime]]
         :param date_to: The end date in ISO format.
         :type date_to: Optional[Union[str, datetime]]
+        :param http_headers: Additional HTTP headers.
+        :type http_headers: Optional[Dict]
         :return: The sentiment response.
         :rtype: FinanceResponse
         """
@@ -150,6 +159,7 @@ async def get_asset_sentiment(
                 "date_from": date_from,
                 "date_to": date_to,
             },
+            headers=http_headers,
             accept=[(FinanceResponse.__content_type__, 1.0)],
         )
         return FinanceResponse.model_validate(response.content)
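Across the refactor, each endpoint wrapper gains a keyword-only http_headers argument that is forwarded to client.request as extra headers. A minimal caller-side sketch, assuming the package's top-level AskNewsSDK client exposes this API as sdk.analytics (constructor arguments abbreviated; the asset value, credentials, and header name are placeholders, not taken from this diff):

from asknews_sdk import AskNewsSDK

# Placeholder credentials; the real client also accepts scopes and other options.
sdk = AskNewsSDK(client_id="...", client_secret="...")

# The new keyword-only http_headers dict is merged into the outgoing request headers.
sentiment = sdk.analytics.get_asset_sentiment(
    asset="bitcoin",                              # illustrative asset name
    metric="news_positive",
    http_headers={"X-Request-Id": "debug-123"},   # hypothetical tracing header
)
print(sentiment)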

asknews_sdk/api/chat.py

Lines changed: 83 additions & 61 deletions
@@ -5,8 +5,10 @@
     CreateChatCompletionRequest,
     CreateChatCompletionResponse,
     CreateChatCompletionResponseStream,
+    HeadlineQuestionsResponse,
     ListModelResponse,
 )
+from asknews_sdk.response import EventSource


 class ChatAPI(BaseAPI):
@@ -26,9 +28,13 @@ def get_chat_completions(
             "mixtral-8x7b-32768",
         ] = "gpt-3.5-turbo-16k",
         stream: bool = False,
-    ) -> Union[
-        CreateChatCompletionResponse, Iterator[CreateChatCompletionResponseStream]
-    ]:
+        inline_citations: Literal["markdown_link", "numbered", "none"] = "markdown_link",
+        append_references: bool = True,
+        asknews_watermark: bool = True,
+        journalist_mode: bool = True,
+        *,
+        http_headers: Optional[Dict] = None,
+    ) -> Union[CreateChatCompletionResponse, Iterator[CreateChatCompletionResponseStream]]:
         """
         Get chat completions for a given user message.

@@ -43,6 +49,16 @@ def get_chat_completions(
         ]
         :param stream: Whether to stream the response, defaults to False
         :type stream: bool
+        :param inline_citations: Inline citations format, defaults to "markdown_link"
+        :type inline_citations: Literal["markdown_link", "numbered", "none"]
+        :param append_references: Whether to append references, defaults to True
+        :type append_references: bool
+        :param asknews_watermark: Whether to add AskNews watermark, defaults to True
+        :type asknews_watermark: bool
+        :param journalist_mode: Whether to enable journalist mode, defaults to True
+        :type journalist_mode: bool
+        :param http_headers: Additional HTTP headers.
+        :type http_headers: Optional[Dict]
         :return: Chat completions
         :rtype: Union[
             CreateChatCompletionResponse, Iterator[CreateChatCompletionResponseStream]
@@ -55,68 +71,77 @@ def get_chat_completions(
                 messages=messages,
                 model=model,
                 stream=stream,
+                inline_citations=inline_citations,
+                append_references=append_references,
+                asknews_watermark=asknews_watermark,
+                journalist_mode=journalist_mode,
             ).model_dump(mode="json"),
             headers={
+                **(http_headers or {}),
                 "Content-Type": CreateChatCompletionRequest.__content_type__,
             },
             accept=[
                 (CreateChatCompletionResponse.__content_type__, 1.0),
                 (CreateChatCompletionResponseStream.__content_type__, 1.0),
             ],
             stream=stream,
-            stream_type="lines",  # type: ignore
+            stream_type="lines",
         )

         if stream:
-
             def _stream():
-                for chunk in response.content:
-                    if chunk.strip() == "data: [DONE]":
+                for event in EventSource.from_api_response(response):
+                    if event.content == "[DONE]":
                         break
-
-                    if chunk.startswith("data:"):
-                        json_data = chunk.replace("data: ", "").strip()
-                        yield CreateChatCompletionResponseStream.model_validate_json(
-                            json_data
-                        )
-
+                    yield CreateChatCompletionResponseStream.model_validate_json(event.content)
             return _stream()
         else:
             return CreateChatCompletionResponse.model_validate(response.content)

-    def list_chat_models(self) -> ListModelResponse:
+    def list_chat_models(self, *, http_headers: Optional[Dict] = None) -> ListModelResponse:
         """
         List available chat models.

         https://docs.asknews.app/en/reference#get-/v1/openai/models

+        :param http_headers: Additional HTTP headers.
+        :type http_headers: Optional[Dict]
         :return: List of available chat models
         :rtype: ListModelResponse
         """
         response = self.client.request(
             method="GET",
-            endpoint="/v1/openai/chat/models",
+            endpoint="/v1/openai/models",
+            headers=http_headers,
             accept=[(ListModelResponse.__content_type__, 1.0)],
         )
         return ListModelResponse.model_validate(response.content)

     def get_headline_questions(
-        self, queries: Optional[List[str]] = None
-    ) -> Dict[str, List[str]]:
+        self,
+        queries: Optional[List[str]] = None,
+        *,
+        http_headers: Optional[Dict] = None,
+    ) -> HeadlineQuestionsResponse:
         """
         Get headline questions for a given query.

         https://docs.asknews.app/en/reference#get-/v1/chat/questions

         :param queries: List of queries to get headline questions for
         :type queries: Optional[List[str]]
+        :param http_headers: Additional HTTP headers.
+        :type http_headers: Optional[Dict]
         :return: Headline questions
-        :rtype: Dict[str, List[str]]
+        :rtype: HeadlineQuestionsResponse
         """
         response = self.client.request(
-            method="GET", endpoint="/v1/chat/questions", query={"queries": queries}
+            method="GET",
+            endpoint="/v1/chat/questions",
+            headers=http_headers,
+            query={"queries": queries}
         )
-        return response.content
+        return HeadlineQuestionsResponse.model_validate(response.content)


 class AsyncChatAPI(BaseAPI):
@@ -136,23 +161,13 @@ async def get_chat_completions(
             "meta-llama/Meta-Llama-3-70B-Instruct",
         ] = "gpt-3.5-turbo-16k",
         stream: bool = False,
-        inline_citations: Literal[
-            "markdown_link", "numbered", "none"
-        ] = "markdown_link",
+        inline_citations: Literal["markdown_link", "numbered", "none"] = "markdown_link",
         append_references: bool = True,
         asknews_watermark: bool = True,
         journalist_mode: bool = True,
-        temperature: float = 0.5,
-        top_p: float = 1,
-        n: int = 1,
-        stop: Optional[Union[str, List[str]]] = None,
-        max_tokens: int = 1000,
-        presence_penalty: float = 0,
-        frequency_penalty: float = 0,
-        user: Optional[str] = None,
-    ) -> Union[
-        CreateChatCompletionResponse, AsyncIterator[CreateChatCompletionResponseStream]
-    ]:
+        *,
+        http_headers: Optional[Dict] = None,
+    ) -> Union[CreateChatCompletionResponse, AsyncIterator[CreateChatCompletionResponseStream]]:
         """
         Get chat completions for a given user message.

@@ -167,6 +182,16 @@ async def get_chat_completions(
         ]
         :param stream: Whether to stream the response, defaults to False
         :type stream: bool
+        :param inline_citations: Inline citations format, defaults to "markdown_link"
+        :type inline_citations: Literal["markdown_link", "numbered", "none"]
+        :param append_references: Whether to append references, defaults to True
+        :type append_references: bool
+        :param asknews_watermark: Whether to add AskNews watermark, defaults to True
+        :type asknews_watermark: bool
+        :param journalist_mode: Whether to enable journalist mode, defaults to True
+        :type journalist_mode: bool
+        :param http_headers: Additional HTTP headers.
+        :type http_headers: Optional[Dict]
         :return: Chat completions
         :rtype: Union[
             CreateChatCompletionResponse,
@@ -184,73 +209,70 @@ async def get_chat_completions(
                 append_references=append_references,
                 asknews_watermark=asknews_watermark,
                 journalist_mode=journalist_mode,
-                temperature=temperature,
-                top_p=top_p,
-                n=n,
-                stop=stop,
-                max_tokens=max_tokens,
-                presence_penalty=presence_penalty,
-                frequency_penalty=frequency_penalty,
-                user=user,
             ).model_dump(mode="json"),
             headers={
                 "Content-Type": CreateChatCompletionRequest.__content_type__,
+                **(http_headers or {}),
             },
             accept=[
                 (CreateChatCompletionResponse.__content_type__, 1.0),
                 (CreateChatCompletionResponseStream.__content_type__, 1.0),
             ],
             stream=stream,
-            stream_type="lines",  # type: ignore
+            stream_type="lines",
        )

         if stream:
-
             async def _stream():
-                async for chunk in response.content:
-                    if chunk.strip() == "data: [DONE]":
+                async for event in EventSource.from_api_response(response):
+                    if event.content == "[DONE]":
                         break
-
-                    if chunk.startswith("data:"):
-                        json_data = chunk.replace("data: ", "").strip()
-                        yield CreateChatCompletionResponseStream.model_validate_json(
-                            json_data
-                        )
-
+                    yield CreateChatCompletionResponseStream.model_validate_json(event.content)
             return _stream()
         else:
             return CreateChatCompletionResponse.model_validate(response.content)

-    async def list_chat_models(self) -> ListModelResponse:
+    async def list_chat_models(self, *, http_headers: Optional[Dict] = None) -> ListModelResponse:
         """
         List available chat models.

         https://docs.asknews.app/en/reference#get-/v1/openai/models

+        :param http_headers: Additional HTTP headers.
+        :type http_headers: Optional[Dict]
         :return: List of available chat models
         :rtype: ListModelResponse
         """
         response = await self.client.request(
             method="GET",
-            endpoint="/v1/openai/chat/models",
+            endpoint="/v1/openai/models",
+            headers=http_headers,
             accept=[(ListModelResponse.__content_type__, 1.0)],
         )
         return ListModelResponse.model_validate(response.content)

     async def get_headline_questions(
-        self, queries: Optional[List[str]] = None
-    ) -> Dict[str, List[str]]:
+        self,
+        queries: Optional[List[str]] = None,
+        *,
+        http_headers: Optional[Dict] = None,
+    ) -> HeadlineQuestionsResponse:
         """
         Get headline questions for a given query.

         https://docs.asknews.app/en/reference#get-/v1/chat/questions

         :param queries: List of queries to get headline questions for
         :type queries: Optional[List[str]]
+        :param http_headers: Additional HTTP headers.
+        :type http_headers: Optional[Dict]
         :return: Headline questions
-        :rtype: Dict[str, List[str]]
+        :rtype: HeadlineQuestionsResponse
         """
         response = await self.client.request(
-            method="GET", endpoint="/v1/chat/questions", query={"queries": queries}
+            method="GET",
+            endpoint="/v1/chat/questions",
+            headers=http_headers,
+            query={"queries": queries}
         )
-        return response.content
+        return HeadlineQuestionsResponse.model_validate(response.content)
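On the chat side, streamed responses are now decoded through EventSource.from_api_response instead of hand-parsing "data:" lines, list_chat_models targets /v1/openai/models, and get_headline_questions returns a typed HeadlineQuestionsResponse. A rough consumer-side sketch, assuming the top-level AskNewsSDK client exposes this API as sdk.chat and that streamed chunks follow the OpenAI-style delta schema (both assumptions; neither is confirmed by this diff):

from asknews_sdk import AskNewsSDK

sdk = AskNewsSDK(client_id="...", client_secret="...")  # placeholder credentials

# stream=True yields CreateChatCompletionResponseStream chunks as they arrive.
for chunk in sdk.chat.get_chat_completions(
    messages=[{"role": "user", "content": "Summarize today's top tech story."}],
    stream=True,
    inline_citations="numbered",
    journalist_mode=False,
):
    # Assumed OpenAI-style chunk layout; adjust if the DTO differs.
    print(chunk.choices[0].delta.content or "", end="")

# get_headline_questions now returns a validated HeadlineQuestionsResponse model.
questions = sdk.chat.get_headline_questions(queries=["artificial intelligence"])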
