
Commit 3fd6e57

feat(client): setup streaming
1 parent fa47f44 commit 3fd6e57


10 files changed: +893 −28 lines


.stats.yml

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 configured_endpoints: 76
 openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/digitalocean%2Fgradientai-e8b3cbc80e18e4f7f277010349f25e1319156704f359911dc464cc21a0d077a6.yml
 openapi_spec_hash: c773d792724f5647ae25a5ae4ccec208
-config_hash: f0976fbc552ea878bb527447b5e663c9
+config_hash: e1b3d85ba9ae21d729a914c789422ba7

api.md

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ Methods:
 Types:

 ```python
-from gradientai.types.agents.chat import CompletionCreateResponse
+from gradientai.types.agents.chat import ChatCompletionChunk, CompletionCreateResponse
 ```

 Methods:

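The new `ChatCompletionChunk` type is what streamed chat completions are expected to yield piece by piece. As a rough, hedged sketch (the resource path `agents.chat.completions.create`, the `stream=True` parameter, and the model name below are assumptions inferred from the `gradientai.types.agents.chat` import path; this diff only confirms the new type), client-side usage would look roughly like this:

```python
# Hedged usage sketch: the resource path, parameters, and model name are
# assumptions for illustration, not confirmed by this commit.
from gradientai import GradientAI

client = GradientAI()

stream = client.agents.chat.completions.create(
    model="example-model",  # hypothetical model identifier
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,  # assumed to return Stream[ChatCompletionChunk]
)

for chunk in stream:  # each item would be deserialized into a ChatCompletionChunk
    print(chunk)
```

The async client would presumably follow the same shape with `AsyncGradientAI` and `async for`.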
src/gradientai/_client.py

Lines changed: 4 additions & 0 deletions
@@ -117,6 +117,8 @@ def __init__(
             _strict_response_validation=_strict_response_validation,
         )

+        self._default_stream_cls = Stream
+
     @cached_property
     def agents(self) -> AgentsResource:
         from .resources.agents import AgentsResource
@@ -355,6 +357,8 @@ def __init__(
             _strict_response_validation=_strict_response_validation,
         )

+        self._default_stream_cls = AsyncStream
+
     @cached_property
     def agents(self) -> AsyncAgentsResource:
         from .resources.agents import AsyncAgentsResource

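Setting `_default_stream_cls` on each client tells the shared request machinery which wrapper to hand back for streaming requests: `Stream` for the synchronous `GradientAI` client and `AsyncStream` for `AsyncGradientAI`. A minimal, self-contained sketch of that dispatch pattern follows; it is not the SDK's actual base-client code, and `FakeStream`, `ExampleClient`, and `request` are stand-ins:

```python
# Illustrative sketch only: a per-class _default_stream_cls attribute decides
# which stream wrapper a stream=True request returns.
from typing import Any, Iterator, Type


class FakeStream:
    """Stand-in for the Stream/AsyncStream wrappers in _streaming.py."""

    def __init__(self, *, response: Any) -> None:
        self.response = response

    def __iter__(self) -> Iterator[Any]:
        yield from self.response


class ExampleClient:
    # The sync client sets this to Stream; the async client to AsyncStream.
    _default_stream_cls: Type[FakeStream] = FakeStream

    def request(self, *, stream: bool = False) -> Any:
        response = ["chunk-1", "chunk-2"]  # pretend raw response events
        if stream:
            return self._default_stream_cls(response=response)
        return response


client = ExampleClient()
for chunk in client.request(stream=True):
    print(chunk)  # chunk-1, chunk-2
```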
src/gradientai/_streaming.py

Lines changed: 40 additions & 3 deletions
@@ -9,7 +9,8 @@

 import httpx

-from ._utils import extract_type_var_from_base
+from ._utils import is_mapping, extract_type_var_from_base
+from ._exceptions import APIError

 if TYPE_CHECKING:
     from ._client import GradientAI, AsyncGradientAI
@@ -55,7 +56,25 @@ def __stream__(self) -> Iterator[_T]:
         iterator = self._iter_events()

         for sse in iterator:
-            yield process_data(data=sse.json(), cast_to=cast_to, response=response)
+            if sse.data.startswith("[DONE]"):
+                break
+
+            data = sse.json()
+            if is_mapping(data) and data.get("error"):
+                message = None
+                error = data.get("error")
+                if is_mapping(error):
+                    message = error.get("message")
+                if not message or not isinstance(message, str):
+                    message = "An error occurred during streaming"
+
+                raise APIError(
+                    message=message,
+                    request=self.response.request,
+                    body=data["error"],
+                )
+
+            yield process_data(data=data, cast_to=cast_to, response=response)

         # Ensure the entire stream is consumed
         for _sse in iterator:
@@ -119,7 +138,25 @@ async def __stream__(self) -> AsyncIterator[_T]:
         iterator = self._iter_events()

         async for sse in iterator:
-            yield process_data(data=sse.json(), cast_to=cast_to, response=response)
+            if sse.data.startswith("[DONE]"):
+                break
+
+            data = sse.json()
+            if is_mapping(data) and data.get("error"):
+                message = None
+                error = data.get("error")
+                if is_mapping(error):
+                    message = error.get("message")
+                if not message or not isinstance(message, str):
+                    message = "An error occurred during streaming"
+
+                raise APIError(
+                    message=message,
+                    request=self.response.request,
+                    body=data["error"],
+                )
+
+            yield process_data(data=data, cast_to=cast_to, response=response)

         # Ensure the entire stream is consumed
         async for _sse in iterator:

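The new loop bodies in `__stream__` terminate on the SSE `[DONE]` sentinel and surface server-sent error payloads as an `APIError` instead of trying to deserialize them into chunks. Below is a small, self-contained sketch of that same termination and error-handling logic, using plain dicts and a local exception rather than the SDK's SSE and `APIError` machinery:

```python
# Minimal sketch mirroring the [DONE] and error-payload handling added above.
# StreamError and iter_chunks are stand-ins, not SDK names.
import json
from typing import Any, Iterator


class StreamError(Exception):
    """Stand-in for gradientai's APIError."""


def iter_chunks(sse_data_lines: Iterator[str]) -> Iterator[Any]:
    for raw in sse_data_lines:
        if raw.startswith("[DONE]"):  # sentinel: end of stream
            break

        data = json.loads(raw)
        if isinstance(data, dict) and data.get("error"):
            error = data.get("error")
            message = error.get("message") if isinstance(error, dict) else None
            if not message or not isinstance(message, str):
                message = "An error occurred during streaming"
            raise StreamError(message)

        yield data  # in the SDK this would be cast to ChatCompletionChunk


# Example: two chunks followed by the sentinel.
events = iter(['{"id": 1}', '{"id": 2}', "[DONE]"])
print(list(iter_chunks(events)))  # [{'id': 1}, {'id': 2}]
```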