@@ -3,10 +3,7 @@
 
 import streamlit as st
 from dotenv import load_dotenv
-from langchain_community.callbacks.streamlit import (
-    StreamlitCallbackHandler,
-)
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI
 from openai import APIConnectionError, APIStatusError, APITimeoutError
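This hunk swaps the StreamlitCallbackHandler import for plain message objects. For reference, the three langchain_core.messages classes used by this app are simple typed content holders; a minimal sketch:

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

# A chat history is a plain list of typed message objects.
history = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Hello!"),
    AIMessage(content="Hi! How can I help you today?"),
]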
@@ -33,11 +30,14 @@ def image_to_base64(image_bytes: bytes) -> str:
     "# Model"
     model_choice = st.radio(
         label="Active Model",
-        options=["azure", "ollama"],
+        options=[
+            "azure",
+            "ollama",
+        ],
         index=0,
         key="model_choice",
     )
-    "## Model Settings"
+    f"## Model Settings for {model_choice.capitalize()}"
     if model_choice == "azure":
         azure_openai_endpoint = st.text_input(
             label="AZURE_OPENAI_ENDPOINT",
@@ -63,18 +63,23 @@ def image_to_base64(image_bytes: bytes) -> str:
             key="AZURE_OPENAI_MODEL_CHAT",
             type="default",
         )
+        "### Documents"
         "[Azure Portal](https://portal.azure.com/)"
         "[Azure OpenAI Studio](https://oai.azure.com/resource/overview)"
         "[View the source code](https://github.com/ks6088ts-labs/template-streamlit)"
-    else:
+    elif model_choice == "ollama":
         ollama_model_chat = st.text_input(
             label="OLLAMA_MODEL_CHAT",
             value=getenv("OLLAMA_MODEL_CHAT"),
             key="OLLAMA_MODEL_CHAT",
             type="default",
         )
+        "### Documents"
         "[Ollama Docs](https://github.com/ollama/ollama)"
         "[View the source code](https://github.com/ks6088ts-labs/template-streamlit)"
+    else:
+        st.error("Invalid model choice. Please select either 'azure' or 'ollama'.")
+        raise ValueError("Invalid model choice. Please select either 'azure' or 'ollama'.")
 
 
 def is_azure_configured():
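The bare string literals in this hunk ("# Model", "### Documents", the markdown links) are not dead code: Streamlit's "magic" feature renders any standalone string expression in the app script as markdown. A minimal sketch of the same idiom:

import streamlit as st

with st.sidebar:
    "# Model"  # a bare string literal is rendered as markdown ("magic")
    "[Streamlit docs](https://docs.streamlit.io/)"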
@@ -110,14 +115,18 @@ def get_model():
     raise ValueError("No model is configured. Please set up the Azure or Ollama model in the sidebar.")
 
 
-st.title("chat app with LangChain SDK")
+st.title("Chat Playground")
 
 if not is_configured():
     st.warning("Please fill in the required fields at the sidebar.")
 
 if "messages" not in st.session_state:
     st.session_state["messages"] = [
-        AIMessage(content="Hello! I'm a helpful assistant."),
+        SystemMessage(
+            content="You are a helpful assistant. Answer concisely. "
+            "If you don't know the answer, just say you don't know. "
+            "Do not make up an answer."
+        ),
     ]
 
 # Show chat messages
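Seeding st.session_state["messages"] with a SystemMessage rather than a greeting AIMessage means every later model call is steered by that instruction, since the whole list is passed to the model. A minimal sketch of the effect, with the Ollama model name assumed for illustration:

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_ollama import ChatOllama

llm = ChatOllama(model="llama3.2")  # model name assumed for illustration
reply = llm.invoke(
    [
        SystemMessage(content="You are a helpful assistant. Answer concisely."),
        HumanMessage(content="What does Streamlit do?"),
    ]
)
print(reply.content)  # the system prompt shapes this answer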
@@ -134,14 +143,35 @@ def get_model():
 
 
 # Receive user input
-uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"], key="file_uploader")
-if prompt := st.chat_input(disabled=not is_configured()):
-    user_message_content = [{"type": "text", "text": prompt}]
-    if uploaded_file:
-        image_bytes = uploaded_file.getvalue()
-        base64_image = image_to_base64(image_bytes)
-        image_url = f"data:image/jpeg;base64,{base64_image}"
-        user_message_content.append({"type": "image_url", "image_url": {"url": image_url}})
+if prompt := st.chat_input(
+    disabled=not is_configured(),
+    accept_file="multiple",
+    file_type=[
+        "png",
+        "jpg",
+        "jpeg",
+        "gif",
+        "webp",
+    ],
+):
+    user_message_content = []
+    for file in prompt.files:
+        if file.type.startswith("image/"):
+            image_bytes = file.getvalue()
+            base64_image = image_to_base64(image_bytes)
+            image_url = f"data:{file.type};base64,{base64_image}"
+            user_message_content.append(
+                {
+                    "type": "image_url",
+                    "image_url": {"url": image_url},
+                }
+            )
+    user_message_content.append(
+        {
+            "type": "text",
+            "text": prompt.text,
+        }
+    )
 
     user_message = HumanMessage(content=user_message_content)
     st.session_state.messages.append(user_message)
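Two details worth noting here. With accept_file set, st.chat_input no longer returns a plain string but a dict-like value exposing .text and .files, which is why the handler iterates prompt.files. And image_to_base64, whose signature appears in the hunk headers, presumably reduces to standard base64 encoding; a minimal sketch, assuming no extra processing:

import base64


def image_to_base64(image_bytes: bytes) -> str:
    # Encode raw image bytes for embedding in a data: URL.
    return base64.b64encode(image_bytes).decode("utf-8")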
@@ -158,8 +188,6 @@ def get_model():
         message_placeholder = st.empty()
         full_response = ""
         llm = get_model()
-        callbacks = [StreamlitCallbackHandler(st.container())]
-
         try:
             if stream_mode:
                 for chunk in llm.stream(st.session_state.messages):
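The diff truncates inside the streaming branch. The surrounding names (message_placeholder, full_response) suggest the common LangChain-plus-Streamlit streaming idiom; a hypothetical sketch of how such a loop typically continues, not the code actually shown in this hunk:

def stream_reply(llm, messages, message_placeholder) -> str:
    # Hypothetical helper mirroring the streaming branch above.
    full_response = ""
    for chunk in llm.stream(messages):
        full_response += chunk.content or ""  # chunks arrive as AIMessageChunk
        message_placeholder.markdown(full_response + "▌")  # live typing cursor
    message_placeholder.markdown(full_response)  # final render, cursor removed
    return full_response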