Skip to content

Commit 348b3d8

Browse files
authored
Merge pull request road-core#333 from TamiTakamiya/TamiTakamiya/event_stream_format_support
Event Stream Format support
2 parents 0846595 + d27d5f4 commit 348b3d8

File tree

4 files changed

+44
-5
lines changed

4 files changed

+44
-5
lines changed

examples/rcsconfig.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ ols_config:
7373
default_provider: my_bam
7474
default_model: ibm/granite-3-8b-instruct
7575
expire_llm_is_ready_persistent_state: -1
76+
enable_event_stream_format: true
7677
# query_filters:
7778
# - name: foo_filter
7879
# pattern: '\b(?:foo)\b'

ols/app/endpoints/streaming_ols.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,19 @@ async def invalid_response_generator() -> AsyncGenerator[str, None]:
129129
yield prompts.INVALID_QUERY_RESP
130130

131131

132+
def format_stream_data(d: dict) -> str:
133+
"""Format outbound data in the Event Stream Format if required."""
134+
data = json.dumps(d)
135+
return f"data: {data}\n\n" if config.ols_config.enable_event_stream_format else data
136+
137+
132138
def stream_start_event(conversation_id: str) -> str:
133139
"""Yield the start of the data stream.
134140
135141
Args:
136142
conversation_id: The conversation ID (UUID).
137143
"""
138-
return json.dumps(
144+
return format_stream_data(
139145
{
140146
"event": "start",
141147
"data": {
@@ -157,7 +163,7 @@ def stream_end_event(
157163
token_counter: Token counter for the whole stream.
158164
"""
159165
if media_type == constants.MEDIA_TYPE_JSON:
160-
return json.dumps(
166+
return format_stream_data(
161167
{
162168
"event": "end",
163169
"data": {
@@ -203,7 +209,7 @@ def prompt_too_long_error(error: PromptTooLongError, media_type: str) -> str:
203209
logger.error("Prompt is too long: %s", error)
204210
if media_type == MEDIA_TYPE_TEXT:
205211
return f"Prompt is too long: {error}"
206-
return json.dumps(
212+
return format_stream_data(
207213
{
208214
"event": "error",
209215
"data": {
@@ -230,7 +236,7 @@ def generic_llm_error(error: Exception, media_type: str) -> str:
230236

231237
if media_type == MEDIA_TYPE_TEXT:
232238
return f"{response}: {cause}"
233-
return json.dumps(
239+
return format_stream_data(
234240
{
235241
"event": "error",
236242
"data": {
@@ -254,7 +260,12 @@ def build_yield_item(item: str, idx: int, media_type: str) -> str:
254260
"""
255261
if media_type == MEDIA_TYPE_TEXT:
256262
return item
257-
return json.dumps({"event": "token", "data": {"id": idx, "token": item}})
263+
return format_stream_data(
264+
{
265+
"event": "token",
266+
"data": {"id": idx, "token": item},
267+
}
268+
)
258269

259270

260271
def store_data(

ols/app/models/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,8 @@ class OLSConfig(BaseModel):
870870
extra_ca: list[FilePath] = []
871871
certificate_directory: Optional[str] = None
872872

873+
enable_event_stream_format: bool = False
874+
873875
def __init__(
874876
self, data: Optional[dict] = None, ignore_missing_certs: bool = False
875877
) -> None:
@@ -920,6 +922,7 @@ def __init__(
920922
self.tls_security_profile = TLSSecurityProfile(
921923
data.get("tlsSecurityProfile", None)
922924
)
925+
self.enable_event_stream_format = data.get("enable_event_stream_format", False)
923926

924927
def __eq__(self, other: object) -> bool:
925928
"""Compare two objects for equality."""

tests/unit/app/endpoints/test_streaming_ols.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from ols.app.endpoints.streaming_ols import (
99
build_referenced_docs,
1010
build_yield_item,
11+
format_stream_data,
1112
generic_llm_error,
1213
invalid_response_generator,
1314
prompt_too_long_error,
@@ -149,3 +150,26 @@ def test_build_referenced_docs():
149150
{"doc_title": "title_1", "doc_url": "url_1"},
150151
{"doc_title": "title_2", "doc_url": "url_2"},
151152
]
153+
154+
155+
@pytest.mark.usefixtures("_load_config")
156+
def test_format_stream_data():
157+
"""Test format_stream_data function."""
158+
stream_data = {
159+
"event": "token",
160+
"data": {"id": "ABC-123456", "token": "***TOKEN***"},
161+
}
162+
stringified_data = json.dumps(stream_data)
163+
data_in_event_stream_data_format = f"data: {stringified_data}\n\n"
164+
165+
saved_value = config.ols_config.enable_event_stream_format
166+
try:
167+
config.ols_config.enable_event_stream_format = False
168+
output = format_stream_data((stream_data))
169+
assert output == stringified_data
170+
171+
config.ols_config.enable_event_stream_format = True
172+
output = format_stream_data((stream_data))
173+
assert output == data_in_event_stream_data_format
174+
finally:
175+
config.ols_config.enable_event_stream_format = saved_value

0 commit comments

Comments
 (0)