Skip to content

Commit 0a0d99d

Browse files
authored
Option to enable chat streaming (#1516)
* Option to enable chat streaming * Add stream to HttpConfiguration * Refactor based on a review comment
1 parent 6e2c1ff commit 0a0d99d

File tree

5 files changed

+34
-4
lines changed

5 files changed

+34
-4
lines changed

ansible_ai_connect/ai/api/model_pipelines/http/configuration.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,14 @@ def __init__(
4141
timeout: Optional[int],
4242
enable_health_check: Optional[bool],
4343
verify_ssl: bool,
44+
stream: bool = False,
4445
):
4546
super().__init__(inference_url, model_id, timeout, enable_health_check)
4647
self.verify_ssl = verify_ssl
48+
self.stream = stream
4749

4850
verify_ssl: bool
51+
stream: bool
4952

5053

5154
@Register(api_type="http")
@@ -60,10 +63,12 @@ def __init__(self, **kwargs):
6063
timeout=kwargs["timeout"],
6164
enable_health_check=kwargs["enable_health_check"],
6265
verify_ssl=kwargs["verify_ssl"],
66+
stream=kwargs["stream"],
6367
),
6468
)
6569

6670

6771
@Register(api_type="http")
6872
class HttpConfigurationSerializer(BaseConfigSerializer):
6973
verify_ssl = serializers.BooleanField(required=False, default=True)
74+
stream = serializers.BooleanField(required=False, default=False)

ansible_ai_connect/main/settings/legacy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def load_from_env_vars():
189189
"inference_url": chatbot_service_url or "http://localhost:8000",
190190
"model_id": chatbot_service_model_id or "granite3-8b",
191191
"verify_ssl": model_service_verify_ssl,
192+
"stream": False,
192193
},
193194
}
194195

ansible_ai_connect/main/templates/chatbot/index.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@
2121
<div id="user_name" hidden>{{user_name}}</div>
2222
<div id="bot_name" hidden>{{bot_name}}</div>
2323
<div id="debug" hidden>{{debug}}</div>
24+
<div id="stream" hidden>{{stream}}</div>
2425
{% endblock content %}
2526
</html>

ansible_ai_connect/main/tests/test_views.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
from http import HTTPStatus
1919
from textwrap import dedent
2020

21+
from django.apps import apps
2122
from django.contrib.auth import get_user_model
2223
from django.contrib.auth.models import AnonymousUser, Group
2324
from django.http import HttpResponseRedirect
2425
from django.test import RequestFactory, TestCase, modify_settings, override_settings
2526
from django.urls import reverse
2627
from rest_framework.test import APITransactionTestCase
2728

29+
from ansible_ai_connect.ai.api.model_pipelines.pipelines import ModelPipelineChatBot
2830
from ansible_ai_connect.main.settings.base import SOCIAL_AUTH_OIDC_KEY
2931
from ansible_ai_connect.main.views import LoginView
3032
from ansible_ai_connect.test_utils import (
@@ -339,10 +341,21 @@ def test_chatbot_view_with_rh_user(self):
339341
self.assertContains(r, TestChatbotView.CHATBOT_PAGE_TITLE)
340342
self.assertContains(r, self.rh_user.username)
341343
self.assertContains(r, '<div id="debug" hidden>false</div>')
344+
self.assertContains(r, '<div id="stream" hidden>false</div>')
342345

343346
@override_settings(CHATBOT_DEBUG_UI=True)
344347
def test_chatbot_view_with_debug_ui(self):
345348
self.client.force_login(user=self.rh_user)
346349
r = self.client.get(reverse("chatbot"), {"debug": "true"})
347350
self.assertEqual(r.status_code, HTTPStatus.OK)
348351
self.assertContains(r, '<div id="debug" hidden>true</div>')
352+
353+
def test_chatbot_view_with_streaming_enabled(self):
354+
llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline(
355+
ModelPipelineChatBot
356+
)
357+
llm.config.stream = True
358+
self.client.force_login(user=self.rh_user)
359+
r = self.client.get(reverse("chatbot"), {"stream": "true"})
360+
self.assertEqual(r.status_code, HTTPStatus.OK)
361+
self.assertContains(r, '<div id="stream" hidden>true</div>')

ansible_ai_connect/main/views.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,21 @@ class ChatbotView(ProtectedTemplateView):
121121
IsRHInternalUser | IsTestUser,
122122
]
123123

124+
llm: ModelPipelineChatBot
125+
chatbot_enabled: bool
126+
127+
def __init__(self):
128+
super().__init__()
129+
self.llm = apps.get_app_config("ai").get_model_pipeline(ModelPipelineChatBot)
130+
self.chatbot_enabled = (
131+
self.llm.config.inference_url
132+
and self.llm.config.model_id
133+
and settings.CHATBOT_DEFAULT_PROVIDER
134+
)
135+
124136
def get(self, request):
125137
# Open the chatbot page when the chatbot service is configured.
126-
llm: ModelPipelineChatBot = apps.get_app_config("ai").get_model_pipeline(
127-
ModelPipelineChatBot
128-
)
129-
if llm.config.inference_url and llm.config.model_id and settings.CHATBOT_DEFAULT_PROVIDER:
138+
if self.chatbot_enabled:
130139
return super().get(request)
131140

132141
# Otherwise, redirect to the home page.
@@ -139,6 +148,7 @@ def get_context_data(self, **kwargs):
139148
if user and user.is_authenticated:
140149
context["user_name"] = user.username
141150
context["debug"] = "true" if settings.CHATBOT_DEBUG_UI else "false"
151+
context["stream"] = "true" if self.llm.config.stream else "false"
142152

143153
return context
144154

0 commit comments

Comments
 (0)