
Commit 29e9fce

add stream_options in streaming mode
1 parent 1b2cdfe commit 29e9fce


7 files changed, +71 -0 lines changed


app/core/runner/llm_backend.py

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,7 @@ def run(
         tools: List = None,
         tool_choice="auto",
         stream=False,
+        stream_options=None,
         extra_body=None,
         temperature=None,
         top_p=None,
@@ -38,6 +39,10 @@ def run(
         if "n" in model_params:
             raise ValueError("n is not allowed in model_params")
         chat_params.update(model_params)
+        if stream_options:
+            if isinstance(stream_options, dict):
+                if "include_usage" in stream_options:
+                    chat_params["stream_options"] = {"include_usage": bool(stream_options["include_usage"])}
         if temperature:
             chat_params["temperature"] = temperature
         if top_p:
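
For reference, a minimal standalone sketch of what the sanitized stream_options asks of the OpenAI SDK: with include_usage enabled, the stream ends with one extra chunk whose choices list is empty and whose usage field is populated. This assumes an openai SDK version that supports stream_options; the model name and prompt are placeholders.

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

stream = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in stream:
    if chunk.usage:  # only the final chunk carries usage; its choices list is empty
        print(chunk.usage)
    elif chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
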

app/core/runner/llm_callback_handler.py

Lines changed: 4 additions & 0 deletions
@@ -41,6 +41,10 @@ def handle_llm_response(
         for chunk in response_stream:
             logging.debug(chunk)

+            if chunk.usage:
+                self.event_handler.pub_message_usage(chunk)
+                continue
+
             if not chunk.choices:
                 continue

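The continue here matters: when include_usage is on, the usage-only chunk arrives with an empty choices list, so letting it fall through to the delta handling below would either do nothing or break code that indexes choices. Roughly, that final chunk has the following shape (illustrative values, field names from the openai SDK):

usage_only_chunk = {
    "id": "chatcmpl-abc123",            # hypothetical id
    "object": "chat.completion.chunk",
    "choices": [],                       # empty on the usage-only chunk
    "usage": {
        "prompt_tokens": 17,
        "completion_tokens": 42,
        "total_tokens": 59,
    },
}
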
app/core/runner/pub_handler.py

Lines changed: 18 additions & 0 deletions
@@ -214,6 +214,24 @@ def pub_message_in_progress(self, message):
             events.ThreadMessageInProgress(data=_data_adjust_message(message), event="thread.message.in_progress")
         )

+    def pub_message_usage(self, chunk):
+        """
+        The stream spec has no usage-related event yet, so thread.message.in_progress is borrowed to carry it; to be revisited once the official API adds one.
+        """
+        data = {
+            "id": chunk.id,
+            "content": [],
+            "created_at": 0,
+            "object": "thread.message",
+            "role": "assistant",
+            "status": "in_progress",
+            "thread_id": "",
+            "metadata": {"usage": chunk.usage.json()}
+        }
+        self.pub_event(
+            events.ThreadMessageInProgress(data=data, event="thread.message.in_progress")
+        )
+
     def pub_message_completed(self, message):
         self.pub_event(
             events.ThreadMessageCompleted(data=_data_adjust_message(message), event="thread.message.completed")
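
Because the usage is tunneled through thread.message.in_progress, subscribers see it under metadata.usage as a JSON string (chunk.usage.json() serializes the CompletionUsage). A sketch of the event body a client receives, with illustrative token counts:

usage_event_data = {
    "id": "chatcmpl-abc123",   # id of the final chat completion chunk (hypothetical)
    "object": "thread.message",
    "role": "assistant",
    "status": "in_progress",
    "content": [],
    "created_at": 0,
    "thread_id": "",
    "metadata": {
        "usage": '{"prompt_tokens": 17, "completion_tokens": 42, "total_tokens": 59}'
    },
}
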

app/core/runner/thread_runner.py

Lines changed: 1 addition & 0 deletions
@@ -129,6 +129,7 @@ def __run_step(
             tools=[tool.openai_function for tool in tools],
             tool_choice="auto" if len(run_steps) < self.max_step else "none",
             stream=True,
+            stream_options=run.stream_options,
             extra_body=run.extra_body,
             temperature=run.temperature,
             top_p=run.top_p,

app/models/run.py

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,7 @@ class RunBase(BaseModel):
     failed_at: Optional[datetime] = Field(default=None)
     additional_instructions: Optional[str] = Field(default=None, max_length=32768, sa_column=Column(TEXT))
     extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
+    stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON))
     incomplete_details: Optional[str] = Field(default=None)  # incomplete details
     max_completion_tokens: Optional[int] = Field(default=None)  # maximum completion length
     max_prompt_tokens: Optional[int] = Field(default=None)  # maximum prompt length
@@ -74,6 +75,7 @@ class RunCreate(BaseModel):
     tools: Optional[list] = []
     extra_body: Optional[dict[str, Union[dict[str, Union[Authentication, Any]], Any]]] = {}
     stream: Optional[bool] = False
+    stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON))
     additional_messages: Optional[list[MessageCreate]] = Field(default=[], sa_column=Column(JSON))  # message list
     max_completion_tokens: Optional[int] = None  # maximum completion length
     max_prompt_tokens: Optional[int] = Field(default=None)  # maximum prompt length
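
With the schema field in place, a run can opt in at creation time. A hedged example of a run-creation payload: assistant_id is assumed and shown as a placeholder, and the exact endpoint path depends on this project's routes.

run_create_payload = {
    "assistant_id": "asst_placeholder",        # hypothetical id
    "stream": True,
    "stream_options": {"include_usage": True},
}
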

examples/run_assistant_stream.py

Lines changed: 10 additions & 0 deletions
@@ -4,6 +4,8 @@
 import logging

 from openai import AssistantEventHandler
+from openai.types.beta import AssistantStreamEvent
+from openai.types.beta.assistant_stream_event import ThreadMessageInProgress
 from openai.types.beta.threads.message import Message
 from openai.types.beta.threads.runs import ToolCall, ToolCallDelta

@@ -47,6 +49,11 @@ def on_text_delta(self, delta, snapshot) -> None:
     def on_text_done(self, text) -> None:
         logging.info("text done: %s\n", text)

+    def on_event(self, event: AssistantStreamEvent) -> None:
+        if isinstance(event, ThreadMessageInProgress):
+            logging.info("event: %s\n", event)
+
+
 if __name__ == "__main__":
     assistant = client.beta.assistants.create(
         name="Assistant Demo",
@@ -70,5 +77,8 @@ def on_text_done(self, text) -> None:
         thread_id=thread.id,
         assistant_id=assistant.id,
         event_handler=event_handler,
+        extra_body={
+            "stream_options": {"include_usage": True}
+        }
     ) as stream:
         stream.until_done()
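
On the client side the tunneled usage arrives in on_event as a ThreadMessageInProgress whose message metadata holds the serialized usage written by pub_message_usage. A small sketch of decoding it, assuming the metadata layout shown in pub_handler.py above:

import json

def extract_usage(event) -> dict | None:
    # event is the ThreadMessageInProgress received in on_event above
    metadata = getattr(event.data, "metadata", None) or {}
    raw = metadata.get("usage")
    return json.loads(raw) if raw else None
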
New Alembic migration (file path not shown in this view)

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+"""empty message
+
+Revision ID: 5b2b73d0fdf6
+Revises: b217fafdb5f0
+Create Date: 2024-11-15 18:30:43.391344
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel
+
+
+# revision identifiers, used by Alembic.
+revision: str = '5b2b73d0fdf6'
+down_revision: Union[str, None] = 'b217fafdb5f0'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('run', sa.Column('stream_options', sa.JSON(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column('run', 'stream_options')
+    # ### end Alembic commands ###
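
Assuming the project's standard Alembic setup, this revision is applied with "alembic upgrade head" and rolled back with "alembic downgrade b217fafdb5f0" (or "alembic downgrade -1"). The new column is nullable, so existing run rows are unaffected.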
