Skip to content

Commit faeeaf1

Browse files
jameszyaoSimsonW
authored andcommitted
feat: finish sync chat_completion stream
1 parent e89583f commit faeeaf1

File tree

4 files changed

+65
-96
lines changed

4 files changed

+65
-96
lines changed

taskingai/client/api/inference_api.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
import six
1515

1616
from ..api_client import SyncApiClient
17-
from ..stream import Stream
18-
from ..models import INFERENCE_CHAT_COMPLETION_STREAM_CAST_MAP
1917

2018
class InferenceApi(object):
2119

@@ -34,16 +32,7 @@ def chat_completion(self, body, stream = False, **kwargs): # noqa: E501
3432
returns the request thread.
3533
"""
3634
kwargs['_return_http_data_only'] = True
37-
cast_map = INFERENCE_CHAT_COMPLETION_STREAM_CAST_MAP
38-
response = self.chat_completion_with_http_info(body, stream, **kwargs)
39-
if not stream:
40-
return response
41-
else:
42-
return Stream(
43-
cast_map=cast_map,
44-
response=response,
45-
client=self.api_client
46-
)
35+
return self.chat_completion_with_http_info(body, stream, **kwargs)
4736

4837
def chat_completion_with_http_info(self, body, stream, **kwargs): # noqa: E501
4938
"""Chat Completion # noqa: E501

taskingai/client/rest.py

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# python 2 and python 3 compatibility library
1919
import httpx
2020
from httpx import HTTPError
21+
from .stream import Stream
2122

2223
logger = logging.getLogger(__name__)
2324

@@ -71,9 +72,21 @@ def __init__(self, configuration, pools_size=4, maxsize=None):
7172
if configuration.cert_file and configuration.key_file:
7273
self.client.cert = (configuration.cert_file, configuration.key_file)
7374

74-
def request(self, method, url, stream = False, query_params=None, headers=None,
75+
def _stream_generator(self, method, url, query_params, headers, request_body, _request_timeout):
76+
"""Generator function for streaming requests."""
77+
with self.client.stream(
78+
method, url,
79+
params=query_params,
80+
headers=headers,
81+
content=request_body,
82+
timeout=_request_timeout
83+
) as response:
84+
for line in response.iter_lines():
85+
yield line
86+
87+
def request(self, method, url, stream=False, query_params=None, headers=None,
7588
body=None, post_params=None, _preload_content=True,
76-
_request_timeout=None) -> RESTResponse | httpx.Response:
89+
_request_timeout=None) -> RESTResponse | Stream:
7790
"""
7891
Perform asynchronous HTTP requests.
7992
@@ -110,14 +123,9 @@ def request(self, method, url, stream = False, query_params=None, headers=None,
110123

111124
try:
112125
if stream:
113-
with self.client.stream(
114-
method, url,
115-
params=query_params,
116-
headers=headers,
117-
content=request_body,
118-
timeout=_request_timeout
119-
) as r:
120-
return r
126+
return Stream(stream_generator=self._stream_generator(
127+
method, url, query_params, headers, request_body, _request_timeout
128+
))
121129
else:
122130
r = self.client.request(
123131
method, url,
@@ -210,6 +218,7 @@ def PATCH(self, url, stream=False, headers=None, query_params=None, post_params=
210218
_request_timeout=_request_timeout,
211219
body=body)
212220

221+
213222
class RESTAsyncClientObject(object):
214223

215224
def __init__(self, configuration, pools_size=4, maxsize=None):
@@ -297,53 +306,53 @@ async def HEAD(self, url, headers=None, query_params=None, _preload_content=True
297306
query_params=query_params)
298307

299308
async def OPTIONS(self, url, headers=None, query_params=None, post_params=None,
300-
body=None, _preload_content=True, _request_timeout=None):
309+
body=None, _preload_content=True, _request_timeout=None):
301310
return await self.request("OPTIONS", url,
302-
headers=headers,
303-
query_params=query_params,
304-
post_params=post_params,
305-
_preload_content=_preload_content,
306-
_request_timeout=_request_timeout,
307-
body=body)
311+
headers=headers,
312+
query_params=query_params,
313+
post_params=post_params,
314+
_preload_content=_preload_content,
315+
_request_timeout=_request_timeout,
316+
body=body)
308317

309318
async def DELETE(self, url, headers=None, query_params=None, body=None,
310-
_preload_content=True, _request_timeout=None):
319+
_preload_content=True, _request_timeout=None):
311320
return await self.request("DELETE", url,
312-
headers=headers,
313-
query_params=query_params,
314-
_preload_content=_preload_content,
315-
_request_timeout=_request_timeout,
316-
body=body)
321+
headers=headers,
322+
query_params=query_params,
323+
_preload_content=_preload_content,
324+
_request_timeout=_request_timeout,
325+
body=body)
317326

318327
async def POST(self, url, headers=None, query_params=None, post_params=None,
319-
body=None, _preload_content=True, _request_timeout=None):
328+
body=None, _preload_content=True, _request_timeout=None):
320329
return await self.request("POST", url,
321-
headers=headers,
322-
query_params=query_params,
323-
post_params=post_params,
324-
_preload_content=_preload_content,
325-
_request_timeout=_request_timeout,
326-
body=body)
330+
headers=headers,
331+
query_params=query_params,
332+
post_params=post_params,
333+
_preload_content=_preload_content,
334+
_request_timeout=_request_timeout,
335+
body=body)
327336

328337
async def PUT(self, url, headers=None, query_params=None, post_params=None,
329-
body=None, _preload_content=True, _request_timeout=None):
338+
body=None, _preload_content=True, _request_timeout=None):
330339
return await self.request("PUT", url,
331-
headers=headers,
332-
query_params=query_params,
333-
post_params=post_params,
334-
_preload_content=_preload_content,
335-
_request_timeout=_request_timeout,
336-
body=body)
340+
headers=headers,
341+
query_params=query_params,
342+
post_params=post_params,
343+
_preload_content=_preload_content,
344+
_request_timeout=_request_timeout,
345+
body=body)
337346

338347
async def PATCH(self, url, headers=None, query_params=None, post_params=None,
339-
body=None, _preload_content=True, _request_timeout=None):
348+
body=None, _preload_content=True, _request_timeout=None):
340349
return await self.request("PATCH", url,
341-
headers=headers,
342-
query_params=query_params,
343-
post_params=post_params,
344-
_preload_content=_preload_content,
345-
_request_timeout=_request_timeout,
346-
body=body)
350+
headers=headers,
351+
query_params=query_params,
352+
post_params=post_params,
353+
_preload_content=_preload_content,
354+
_request_timeout=_request_timeout,
355+
body=body)
347356

348357

349358
class ApiException(Exception):

taskingai/client/stream.py

Lines changed: 11 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,52 +7,28 @@
77
import httpx
88

99
from .models.entity._base import TaskingaiBaseModel
10-
from .exceptions import ApiException
11-
12-
from .rest import RESTSyncClientObject, RESTAsyncClientObject
1310

11+
from .exceptions import ApiException
1412

1513
class Stream(object):
1614
"""Provides the core interface to iterate over a synchronous stream response."""
1715

18-
response: httpx.Response
19-
20-
def __init__(
21-
self,
22-
*,
23-
cast_map: Dict[str, Type[TaskingaiBaseModel]],
24-
response: httpx.Response,
25-
client: RESTSyncClientObject,
26-
) -> None:
27-
if not isinstance(response, httpx.Response):
28-
raise TypeError("response must be an httpx.Response object")
29-
30-
self.response = response
31-
self._cast_map = cast_map
32-
self._client = client
16+
def __init__(self, stream_generator):
17+
self._stream_generator = stream_generator
3318
self._decoder = SSEDecoder()
3419
self._iterator = self.__stream__()
3520

36-
def __next__(self) -> TaskingaiBaseModel:
21+
def __next__(self):
3722
return self._iterator.__next__()
3823

39-
def __iter__(self) -> Iterator[TaskingaiBaseModel]:
24+
def __iter__(self):
4025
for item in self._iterator:
4126
yield item
4227

43-
def _iter_events(self) -> Iterator[ServerSentEvent]:
44-
yield from self._decoder.iter(self.response.iter_lines())
45-
46-
def _cast(self, obj_dict, class_type) -> TaskingaiBaseModel:
47-
cast_map = self._cast_map
48-
if class_type in cast_map:
49-
return cast_map[class_type](**obj_dict)
50-
else:
51-
raise ValueError(f"No class found for type '{class_type}'")
28+
def _iter_events(self):
29+
yield from self._decoder.iter(self._stream_generator)
5230

53-
def __stream__(self) -> Iterator[TaskingaiBaseModel]:
54-
print("streaming...")
55-
response = self.response
31+
def __stream__(self):
5632
iterator = self._iter_events()
5733

5834
for sse in iterator:
@@ -63,14 +39,11 @@ def __stream__(self) -> Iterator[TaskingaiBaseModel]:
6339
data = sse.json()
6440
if isinstance(data, Dict) and data.get("error"):
6541
raise ApiException(
66-
status=response.status_code,
67-
reason="An error ocurred during streaming",
68-
http_resp=response,
42+
status=400, # or appropriate status code
43+
reason="An error occurred during streaming",
6944
)
7045

71-
object_type = data.get("object")
72-
# todo: raise valid format error
73-
yield self._cast(data, object_type)
46+
yield data
7447

7548
# Ensure the entire stream is consumed
7649
for sse in iterator:

taskingai/inference/chat_completion.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,10 @@ def chat_completion(
9696
stream=stream
9797
)
9898
if not stream:
99-
print("not streaming")
10099
response: ChatCompletionResponse = api_instance.chat_completion(body=body)
101100
chat_completion_result: ChatCompletion = ChatCompletion(**response["data"])
102101
return chat_completion_result
103102
else:
104-
print("streaming")
105103
response: Stream = api_instance.chat_completion(body=body, stream=True)
106104
return response
107105

0 commit comments

Comments
 (0)