
Commit bc9b055

jameszyao authored and SimsonW committed
feat: add stream feature
1 parent faeeaf1 commit bc9b055


14 files changed (+289 −148 lines)


examples/inference/chat_completion.ipynb

Lines changed: 89 additions & 35 deletions
@@ -11,8 +11,8 @@
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
-    "end_time": "2023-11-28T12:07:45.292065Z",
-    "start_time": "2023-11-28T12:07:45.248068Z"
+    "end_time": "2023-11-30T11:03:32.009173Z",
+    "start_time": "2023-11-30T11:03:30.470410Z"
    }
   },
   "id": "1a6bfd1682fcb23f"
@@ -40,8 +40,8 @@
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
-    "end_time": "2023-11-28T12:07:46.170766Z",
-    "start_time": "2023-11-28T12:07:46.157613Z"
+    "end_time": "2023-11-30T11:03:35.629715Z",
+    "start_time": "2023-11-30T11:03:35.602754Z"
    }
   },
   "id": "49abde692940b09e"
@@ -52,7 +52,7 @@
   "outputs": [
    {
     "data": {
-     "text/plain": "{'created_timestamp': 1701173269243,\n 'finish_reason': 'stop',\n 'message': {'content': 'Hello! How can I assist you today?',\n 'function_call': None,\n 'role': 'assistant'},\n 'object': 'ChatCompletion'}"
+     "text/plain": "ChatCompletion(object='ChatCompletion', finish_reason=<ChatCompletionFinishReason.stop: 'stop'>, message=ChatCompletionAssistantMessage(content='Hello! How can I assist you today?', role=<ChatCompletionRole.assistant: 'assistant'>, function_call=None), created_timestamp=1701342218014)"
    },
    "execution_count": 3,
    "metadata": {},
@@ -61,20 +61,20 @@
   ],
   "source": [
    "# normal \n",
-   "chat_completion = taskingai.inference.chat_completion(\n",
+   "chat_completion_result = taskingai.inference.chat_completion(\n",
    "    model_id=model_id,\n",
    "    messages=[\n",
    "        SystemMessage(\"You are a professional assistant.\"),\n",
    "        UserMessage(\"Hi\"),\n",
    "    ]\n",
    ")\n",
-   "chat_completion"
+   "chat_completion_result"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
-    "end_time": "2023-11-28T12:07:49.355234Z",
-    "start_time": "2023-11-28T12:07:46.700962Z"
+    "end_time": "2023-11-30T11:03:38.108157Z",
+    "start_time": "2023-11-30T11:03:36.349445Z"
    }
   },
   "id": "43dcc632665f0de4"
@@ -85,7 +85,7 @@
   "outputs": [
    {
     "data": {
-     "text/plain": "{'created_timestamp': 1701173272255,\n 'finish_reason': 'stop',\n 'message': {'content': \"Of course! Here's another joke for you: Why don't \"\n \"skeletons fight each other? They don't have the guts!\",\n 'function_call': None,\n 'role': 'assistant'},\n 'object': 'ChatCompletion'}"
+     "text/plain": "ChatCompletion(object='ChatCompletion', finish_reason=<ChatCompletionFinishReason.stop: 'stop'>, message=ChatCompletionAssistantMessage(content=\"Of course! How about this one: Why don't skeletons fight each other? They don't have the guts!\", role=<ChatCompletionRole.assistant: 'assistant'>, function_call=None), created_timestamp=1701342220370)"
    },
    "execution_count": 4,
    "metadata": {},
@@ -94,7 +94,7 @@
   ],
   "source": [
    "# multi round chat completion\n",
-   "chat_completion = taskingai.inference.chat_completion(\n",
+   "chat_completion_result = taskingai.inference.chat_completion(\n",
    "    model_id=model_id,\n",
    "    messages=[\n",
    "        SystemMessage(\"You are a professional assistant.\"),\n",
@@ -105,13 +105,13 @@
    "        UserMessage(\"That's funny. Can you tell me another one?\"),\n",
    "    ]\n",
    ")\n",
-   "chat_completion"
+   "chat_completion_result"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
-    "end_time": "2023-11-28T12:07:52.367618Z",
-    "start_time": "2023-11-28T12:07:50.109888Z"
+    "end_time": "2023-11-30T11:03:40.460984Z",
+    "start_time": "2023-11-30T11:03:39.092991Z"
    }
   },
   "id": "e8933bc07f4b3228"
@@ -122,7 +122,7 @@
   "outputs": [
    {
     "data": {
-     "text/plain": "{'created_timestamp': 1701173274744,\n 'finish_reason': 'length',\n 'message': {'content': \"Of course! Here's\",\n 'function_call': None,\n 'role': 'assistant'},\n 'object': 'ChatCompletion'}"
+     "text/plain": "ChatCompletion(object='ChatCompletion', finish_reason=<ChatCompletionFinishReason.length: 'length'>, message=ChatCompletionAssistantMessage(content=\"Of course! Here's\", role=<ChatCompletionRole.assistant: 'assistant'>, function_call=None), created_timestamp=1701342221823)"
    },
    "execution_count": 5,
    "metadata": {},
@@ -131,7 +131,7 @@
   ],
   "source": [
    "# config max tokens\n",
-   "chat_completion = taskingai.inference.chat_completion(\n",
+   "chat_completion_result = taskingai.inference.chat_completion(\n",
    "    model_id=model_id,\n",
    "    messages=[\n",
    "        SystemMessage(\"You are a professional assistant.\"),\n",
@@ -145,13 +145,13 @@
    "        \"max_tokens\": 5\n",
    "    }\n",
    ")\n",
-   "chat_completion"
+   "chat_completion_result"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
-    "end_time": "2023-11-28T12:07:54.817719Z",
-    "start_time": "2023-11-28T12:07:53.137411Z"
+    "end_time": "2023-11-30T11:03:41.911247Z",
+    "start_time": "2023-11-30T11:03:40.969583Z"
    }
   },
   "id": "f7c1b8be2579d9e0"
@@ -164,13 +164,7 @@
   "name": "stdout",
   "output_type": "stream",
   "text": [
-   "chat_completion = {'created_timestamp': 1701173277776,\n",
-   " 'finish_reason': 'function_call',\n",
-   " 'message': {'content': None,\n",
-   " 'function_call': {'arguments': {'a': 112, 'b': 22},\n",
-   " 'name': 'plus_a_and_b'},\n",
-   " 'role': 'assistant'},\n",
-   " 'object': 'ChatCompletion'}\n",
+   "chat_completion_result = object='ChatCompletion' finish_reason=<ChatCompletionFinishReason.function_calls: 'function_call'> message=ChatCompletionAssistantMessage(content=None, role=<ChatCompletionRole.assistant: 'assistant'>, function_call=ChatCompletionFunctionCall(arguments={'a': 112, 'b': 22}, name='plus_a_and_b')) created_timestamp=1701342223979\n",
    "function name: plus_a_and_b, argument content: {\"a\": 112, \"b\": 22}\n"
   ]
  }
@@ -195,26 +189,26 @@
    "        \"required\": [\"a\", \"b\"]\n",
    "    },\n",
    ")\n",
-   "chat_completion = taskingai.inference.chat_completion(\n",
+   "chat_completion_result = taskingai.inference.chat_completion(\n",
    "    model_id=model_id,\n",
    "    messages=[\n",
    "        SystemMessage(\"You are a professional assistant.\"),\n",
    "        UserMessage(\"What is the result of 112 plus 22?\"),\n",
    "    ],\n",
    "    functions=[function]\n",
    ")\n",
-   "print(f\"chat_completion = {chat_completion}\")\n",
+   "print(f\"chat_completion_result = {chat_completion_result}\")\n",
    "\n",
-   "assistant_function_call_message = chat_completion.message\n",
+   "assistant_function_call_message = chat_completion_result.message\n",
    "fucntion_name = assistant_function_call_message.function_call.name\n",
    "argument_content = json.dumps(assistant_function_call_message.function_call.arguments)\n",
    "print(f\"function name: {fucntion_name}, argument content: {argument_content}\")"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
-    "end_time": "2023-11-28T12:07:57.823570Z",
-    "start_time": "2023-11-28T12:07:55.601317Z"
+    "end_time": "2023-11-30T11:03:44.068090Z",
+    "start_time": "2023-11-30T11:03:42.646835Z"
    }
   },
   "id": "2645bdc3df011e7d"
@@ -233,7 +227,7 @@
   "outputs": [
    {
     "data": {
-     "text/plain": "{'created_timestamp': 1701173282280,\n 'finish_reason': 'stop',\n 'message': {'content': 'The result of 112 plus 22 is 144.',\n 'function_call': None,\n 'role': 'assistant'},\n 'object': 'ChatCompletion'}"
+     "text/plain": "ChatCompletion(object='ChatCompletion', finish_reason=<ChatCompletionFinishReason.stop: 'stop'>, message=ChatCompletionAssistantMessage(content='The result of 112 plus 22 is 134.', role=<ChatCompletionRole.assistant: 'assistant'>, function_call=None), created_timestamp=1701342226882)"
    },
    "execution_count": 7,
    "metadata": {},
@@ -252,16 +246,76 @@
    "    ],\n",
    "    functions=[function]\n",
    ")\n",
-   "chat_completion"
+   "chat_completion_result"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
-    "end_time": "2023-11-28T12:08:02.319026Z",
-    "start_time": "2023-11-28T12:08:00.109622Z"
+    "end_time": "2023-11-30T11:03:46.969143Z",
+    "start_time": "2023-11-30T11:03:46.043885Z"
    }
   },
   "id": "9df9a8b9eafa17d9"
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 14,
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "1 2 3 4 5 6 7 8 9 10 \n",
+     "11 12 13 14 15 16 17 18 19 20 \n",
+     "21 22 23 24 25 26 27 28 29 30 \n",
+     "31 32 33 34 35 36 37 38 39 40 \n",
+     "41 42 43 44 45 46 47 48 49 50 \n",
+     "51 52 53 54 55 56 57 58 59 60 \n",
+     "61 62 63 64 65 66 67 68 69 70 \n",
+     "71 72 73 74 75 76 77 78 79 80 \n",
+     "81 82 83 84 85 86 87 88 89 90 \n",
+     "91 92 93 94 95 96 97 98 99 100\n",
+     " message object: content='1 2 3 4 5 6 7 8 9 10 \\n11 12 13 14 15 16 17 18 19 20 \\n21 22 23 24 25 26 27 28 29 30 \\n31 32 33 34 35 36 37 38 39 40 \\n41 42 43 44 45 46 47 48 49 50 \\n51 52 53 54 55 56 57 58 59 60 \\n61 62 63 64 65 66 67 68 69 70 \\n71 72 73 74 75 76 77 78 79 80 \\n81 82 83 84 85 86 87 88 89 90 \\n91 92 93 94 95 96 97 98 99 100' role=<ChatCompletionRole.assistant: 'assistant'> function_call=None\n"
+    ]
+   }
+  ],
+  "source": [
+   "# generate in stream mode\n",
+   "chat_completion_result = taskingai.inference.chat_completion(\n",
+   "    model_id=model_id,\n",
+   "    messages=[\n",
+   "        UserMessage(\"Please count from 1 to 100, and add a new line after every 10 numbers.\"),\n",
+   "    ],\n",
+   "    stream=True\n",
+   ")\n",
+   "for item in chat_completion_result:\n",
+   "    \n",
+   "    if item.get(\"object\") == \"ChatCompletionChunk\":\n",
+   "        chunk = ChatCompletionChunk(**item)\n",
+   "        print(chunk.delta, end=\"\", flush=True)\n",
+   "    \n",
+   "    elif item.get(\"object\") == \"ChatCompletion\":\n",
+   "        result = ChatCompletion(**item)\n",
+   "        print(f\"\\n message object: {result.message}\")"
+  ],
+  "metadata": {
+   "collapsed": false,
+   "ExecuteTime": {
+    "end_time": "2023-11-30T11:08:05.624547Z",
+    "start_time": "2023-11-30T11:07:58.499361Z"
+   }
+  },
+  "id": "4f3290f87635152a"
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "outputs": [],
+  "source": [],
+  "metadata": {
+   "collapsed": false
+  },
+  "id": "dab2cb3f84463700"
  }
 ],
 "metadata": {

taskingai/assistant/message.py

Lines changed: 48 additions & 20 deletions
@@ -5,6 +5,7 @@
 from taskingai.client.models import MessageCreateRequest, MessageCreateResponse, \
     MessageUpdateRequest, MessageUpdateResponse, \
     MessageGetResponse, MessageListResponse, MessageGenerateRequest
+from taskingai.client.stream import Stream, AsyncStream
 
 __all__ = [
     "Message",
@@ -275,8 +276,9 @@ def generate_assistant_message(
     assistant_id: str,
     chat_id: str,
     system_prompt_variables: Optional[Dict] = None,
-    # todo: add stream and debug support
-) -> Message:
+    stream: bool = False,
+    debug: bool = False,
+) -> Message | Stream:
     """
     Generate a message.
 
@@ -289,23 +291,36 @@ def generate_assistant_message(
 
     api_instance = get_api_instance(ModuleType.assistant)
     body = MessageGenerateRequest(
-        options=MessageGenerationResponseOption(stream=False, debug=False),
-        system_prompt_variables=system_prompt_variables
-    )
-    response = api_instance.generate_assistant_message(
-        assistant_id=assistant_id,
-        chat_id=chat_id,
-        body=body
+        options=MessageGenerationResponseOption(stream=stream, debug=debug),
+        system_prompt_variables=system_prompt_variables,
     )
-    message: Message = Message(**response["data"])
-    return message
 
+    if not stream and not debug:
+        response = api_instance.generate_assistant_message(
+            assistant_id=assistant_id,
+            chat_id=chat_id,
+            body=body,
+            stream=False,
+        )
+        message: Message = Message(**response["data"])
+        return message
+    else:
+        response: Stream = api_instance.generate_assistant_message(
+            assistant_id=assistant_id,
+            chat_id=chat_id,
+            body=body,
+            stream=True,
+            _preload_content=False
+        )
+        return response
 
 async def a_generate_assistant_message(
     assistant_id: str,
     chat_id: str,
     system_prompt_variables: Optional[Dict] = None,
-) -> Message:
+    stream: bool = False,
+    debug: bool = False,
+) -> Message | AsyncStream:
     """
     Generate a message in async mode.
 
@@ -318,14 +333,27 @@ async def a_generate_assistant_message(
 
     api_instance = get_api_instance(ModuleType.assistant, async_client=True)
     body = MessageGenerateRequest(
-        options=MessageGenerationResponseOption(stream=False, debug=False),
+        options=MessageGenerationResponseOption(stream=stream, debug=debug),
         system_prompt_variables=system_prompt_variables
     )
-    response = await api_instance.generate_assistant_message(
-        assistant_id=assistant_id,
-        chat_id=chat_id,
-        body=body
-    )
-    message: Message = Message(**response["data"])
-    return message
+
+    if not stream and not debug:
+        response = await api_instance.generate_assistant_message(
+            assistant_id=assistant_id,
+            chat_id=chat_id,
+            body=body,
+            stream=False,
+        )
+        message: Message = Message(**response["data"])
+        return message
+    else:
+        response: AsyncStream = await api_instance.generate_assistant_message(
+            assistant_id=assistant_id,
+            chat_id=chat_id,
+            body=body,
+            stream=True,
+            _preload_content=False
+        )
+        return response
 
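On the SDK side, generate_assistant_message and its async counterpart a_generate_assistant_message now take stream and debug flags: the default path still returns a fully materialized Message, while either flag makes the call return a raw Stream or AsyncStream (requested with stream=True and _preload_content=False on the underlying API client). A minimal usage sketch (assumptions: both functions are exported from taskingai.assistant, which this diff does not show; the IDs are hypothetical placeholders; and the item type yielded by Stream and AsyncStream is defined in taskingai.client.stream, also outside this diff, so the loops below just print items raw):

import asyncio

import taskingai  # assumed to be configured with API credentials elsewhere
# Assumed export path; only message.py itself appears in this diff.
from taskingai.assistant import a_generate_assistant_message, generate_assistant_message

assistant_id = "YOUR_ASSISTANT_ID"  # hypothetical placeholder
chat_id = "YOUR_CHAT_ID"  # hypothetical placeholder

# Default path (stream=False, debug=False): a fully materialized Message.
message = generate_assistant_message(
    assistant_id=assistant_id,
    chat_id=chat_id,
)
print(message)

# Streaming path: the same call now returns a Stream to iterate over.
stream = generate_assistant_message(
    assistant_id=assistant_id,
    chat_id=chat_id,
    stream=True,
)
for item in stream:  # item schema comes from taskingai.client.stream
    print(item)

async def main() -> None:
    # Async variant: awaiting the call yields an AsyncStream when stream=True.
    async_stream = await a_generate_assistant_message(
        assistant_id=assistant_id,
        chat_id=chat_id,
        stream=True,
    )
    async for item in async_stream:  # assumes AsyncStream is async-iterable
        print(item)

asyncio.run(main())

One review note on the new signatures: the Message | Stream and Message | AsyncStream annotations use PEP 604 union syntax, which is evaluated when the function is defined and therefore needs Python 3.10+ (or a from __future__ import annotations at the top of message.py) to import cleanly on older interpreters.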