|
9 | 9 | CreateChatCompletionRequest, |
10 | 10 | CreateChatCompletionResponse, |
11 | 11 | CreateChatCompletionResponseStream, |
| 12 | + CreateChatCompletionResponseStreamError, |
12 | 13 | FilterParamsResponse, |
13 | 14 | ForecastResponse, |
14 | 15 | HeadlineQuestionsResponse, |
|
21 | 22 | CreateDeepNewsResponse, |
22 | 23 | CreateDeepNewsResponseStream, |
23 | 24 | CreateDeepNewsResponseStreamChunk, |
| 25 | + CreateDeepNewsResponseStreamError, |
24 | 26 | CreateDeepNewsResponseStreamSource, |
25 | 27 | ) |
| 28 | +from asknews_sdk.errors import APIError |
26 | 29 | from asknews_sdk.response import EventSource |
27 | 30 |
|
28 | 31 |
|
@@ -120,7 +123,23 @@ def _stream(): |
120 | 123 | for event in EventSource.from_api_response(response): |
121 | 124 | if event.content == "[DONE]": |
122 | 125 | break |
123 | | - yield CreateChatCompletionResponseStream.model_validate_json(event.content) |
| 126 | + |
| 127 | + token = ( |
| 128 | + TypeAdapter(Union[ |
| 129 | + CreateChatCompletionResponseStreamError, |
| 130 | + CreateChatCompletionResponseStream |
| 131 | + ]) |
| 132 | + .validate_json(event.content) |
| 133 | + ) |
| 134 | + |
| 135 | + if isinstance(token, CreateChatCompletionResponseStreamError): |
| 136 | + raise APIError( |
| 137 | + response=response, |
| 138 | + detail=token.error.message, |
| 139 | + code=token.error.code, |
| 140 | + ) |
| 141 | + |
| 142 | + yield token |
124 | 143 |
|
125 | 144 | return _stream() |
126 | 145 | else: |
@@ -524,7 +543,22 @@ def _stream(): |
524 | 543 | if event.content == "[DONE]": |
525 | 544 | break |
526 | 545 |
|
527 | | - yield TypeAdapter(CreateDeepNewsResponseStream).validate_json(event.content) |
| 546 | + token = ( |
| 547 | + TypeAdapter(Union[ |
| 548 | + CreateDeepNewsResponseStreamError, |
| 549 | + CreateDeepNewsResponseStream |
| 550 | + ]) |
| 551 | + .validate_json(event.content) |
| 552 | + ) |
| 553 | + |
| 554 | + if isinstance(token, CreateDeepNewsResponseStreamError): |
| 555 | + raise APIError( |
| 556 | + response=response, |
| 557 | + detail=token.error.message, |
| 558 | + code=token.error.code, |
| 559 | + ) |
| 560 | + |
| 561 | + yield token |
528 | 562 |
|
529 | 563 | return _stream() |
530 | 564 | else: |
@@ -625,7 +659,23 @@ async def _stream(): |
625 | 659 | async for event in EventSource.from_api_response(response): |
626 | 660 | if event.content == "[DONE]": |
627 | 661 | break |
628 | | - yield CreateChatCompletionResponseStream.model_validate_json(event.content) |
| 662 | + |
| 663 | + token = ( |
| 664 | + TypeAdapter(Union[ |
| 665 | + CreateChatCompletionResponseStreamError, |
| 666 | + CreateChatCompletionResponseStream |
| 667 | + ]) |
| 668 | + .validate_json(event.content) |
| 669 | + ) |
| 670 | + |
| 671 | + if isinstance(token, CreateChatCompletionResponseStreamError): |
| 672 | + raise APIError( |
| 673 | + response=response, |
| 674 | + detail=token.error.message, |
| 675 | + code=token.error.code, |
| 676 | + ) |
| 677 | + |
| 678 | + yield token |
629 | 679 |
|
630 | 680 | return _stream() |
631 | 681 | else: |
@@ -1022,13 +1072,27 @@ async def get_deep_news( |
1022 | 1072 | ) |
1023 | 1073 |
|
1024 | 1074 | if stream: |
1025 | | - |
1026 | 1075 | async def _stream(): |
1027 | 1076 | async for event in EventSource.from_api_response(response): |
1028 | 1077 | if event.content == "[DONE]": |
1029 | 1078 | break |
1030 | 1079 |
|
1031 | | - yield TypeAdapter(CreateDeepNewsResponseStream).validate_json(event.content) |
| 1080 | + token = ( |
| 1081 | + TypeAdapter(Union[ |
| 1082 | + CreateDeepNewsResponseStreamError, |
| 1083 | + CreateDeepNewsResponseStream |
| 1084 | + ]) |
| 1085 | + .validate_json(event.content) |
| 1086 | + ) |
| 1087 | + |
| 1088 | + if isinstance(token, CreateDeepNewsResponseStreamError): |
| 1089 | + raise APIError( |
| 1090 | + response=response, |
| 1091 | + detail=token.error.message, |
| 1092 | + code=token.error.code, |
| 1093 | + ) |
| 1094 | + |
| 1095 | + yield token |
1032 | 1096 |
|
1033 | 1097 | return _stream() |
1034 | 1098 | else: |
|
0 commit comments