|
| 1 | +from base64 import b64encode |
| 2 | +from os import getenv |
| 3 | + |
| 4 | +import streamlit as st |
| 5 | +from dotenv import load_dotenv |
| 6 | +from langchain_community.callbacks.streamlit import ( |
| 7 | + StreamlitCallbackHandler, |
| 8 | +) |
| 9 | +from langchain_core.messages import AIMessage, HumanMessage |
| 10 | +from langchain_ollama import ChatOllama |
| 11 | +from langchain_openai import AzureChatOpenAI |
| 12 | +from openai import APIConnectionError, APIStatusError, APITimeoutError |
| 13 | + |
| 14 | +from template_langgraph.loggers import get_logger |
| 15 | + |
| 16 | +load_dotenv(override=True) |
| 17 | +logger = get_logger(__name__) |
| 18 | +logger.setLevel("DEBUG") |
| 19 | + |
| 20 | + |
| 21 | +def image_to_base64(image_bytes: bytes) -> str: |
| 22 | + """Convert image bytes to base64 string.""" |
| 23 | + return b64encode(image_bytes).decode("utf-8") |
| 24 | + |
| 25 | + |
| 26 | +with st.sidebar: |
| 27 | + "# Common Settings" |
| 28 | + stream_mode = st.checkbox( |
| 29 | + label="ストリーム出力を有効にする", |
| 30 | + value=True, |
| 31 | + key="STREAM_MODE", |
| 32 | + ) |
| 33 | + "# Model" |
| 34 | + model_choice = st.radio( |
| 35 | + label="Active Model", |
| 36 | + options=["azure", "ollama"], |
| 37 | + index=0, |
| 38 | + key="model_choice", |
| 39 | + ) |
| 40 | + "## Model Settings" |
| 41 | + if model_choice == "azure": |
| 42 | + azure_openai_endpoint = st.text_input( |
| 43 | + label="AZURE_OPENAI_ENDPOINT", |
| 44 | + value=getenv("AZURE_OPENAI_ENDPOINT"), |
| 45 | + key="AZURE_OPENAI_ENDPOINT", |
| 46 | + type="default", |
| 47 | + ) |
| 48 | + azure_openai_api_key = st.text_input( |
| 49 | + label="AZURE_OPENAI_API_KEY", |
| 50 | + value=getenv("AZURE_OPENAI_API_KEY"), |
| 51 | + key="AZURE_OPENAI_API_KEY", |
| 52 | + type="password", |
| 53 | + ) |
| 54 | + azure_openai_api_version = st.text_input( |
| 55 | + label="AZURE_OPENAI_API_VERSION", |
| 56 | + value=getenv("AZURE_OPENAI_API_VERSION"), |
| 57 | + key="AZURE_OPENAI_API_VERSION", |
| 58 | + type="default", |
| 59 | + ) |
| 60 | + azure_openai_model_chat = st.text_input( |
| 61 | + label="AZURE_OPENAI_MODEL_CHAT", |
| 62 | + value=getenv("AZURE_OPENAI_MODEL_CHAT"), |
| 63 | + key="AZURE_OPENAI_MODEL_CHAT", |
| 64 | + type="default", |
| 65 | + ) |
| 66 | + "[Azure Portal](https://portal.azure.com/)" |
| 67 | + "[Azure OpenAI Studio](https://oai.azure.com/resource/overview)" |
| 68 | + "[View the source code](https://github.com/ks6088ts-labs/template-streamlit)" |
| 69 | + else: |
| 70 | + ollama_model_chat = st.text_input( |
| 71 | + label="OLLAMA_MODEL_CHAT", |
| 72 | + value=getenv("OLLAMA_MODEL_CHAT"), |
| 73 | + key="OLLAMA_MODEL_CHAT", |
| 74 | + type="default", |
| 75 | + ) |
| 76 | + "[Ollama Docs](https://github.com/ollama/ollama)" |
| 77 | + "[View the source code](https://github.com/ks6088ts-labs/template-streamlit)" |
| 78 | + |
| 79 | + |
| 80 | +def is_azure_configured(): |
| 81 | + return ( |
| 82 | + st.session_state.get("AZURE_OPENAI_API_KEY") |
| 83 | + and st.session_state.get("AZURE_OPENAI_ENDPOINT") |
| 84 | + and st.session_state.get("AZURE_OPENAI_API_VERSION") |
| 85 | + and st.session_state.get("AZURE_OPENAI_MODEL_CHAT") |
| 86 | + and st.session_state.get("model_choice") == "azure" |
| 87 | + ) |
| 88 | + |
| 89 | + |
| 90 | +def is_ollama_configured(): |
| 91 | + return st.session_state.get("OLLAMA_MODEL_CHAT") and st.session_state.get("model_choice") == "ollama" |
| 92 | + |
| 93 | + |
| 94 | +def is_configured(): |
| 95 | + return is_azure_configured() or is_ollama_configured() |
| 96 | + |
| 97 | + |
| 98 | +def get_model(): |
| 99 | + if is_azure_configured(): |
| 100 | + return AzureChatOpenAI( |
| 101 | + azure_endpoint=st.session_state.get("AZURE_OPENAI_ENDPOINT"), |
| 102 | + api_key=st.session_state.get("AZURE_OPENAI_API_KEY"), |
| 103 | + openai_api_version=st.session_state.get("AZURE_OPENAI_API_VERSION"), |
| 104 | + azure_deployment=st.session_state.get("AZURE_OPENAI_MODEL_CHAT"), |
| 105 | + ) |
| 106 | + elif is_ollama_configured(): |
| 107 | + return ChatOllama( |
| 108 | + model=st.session_state.get("OLLAMA_MODEL_CHAT", ""), |
| 109 | + ) |
| 110 | + raise ValueError("No model is configured. Please set up the Azure or Ollama model in the sidebar.") |
| 111 | + |
| 112 | + |
| 113 | +st.title("chat app with LangChain SDK") |
| 114 | + |
| 115 | +if not is_configured(): |
| 116 | + st.warning("Please fill in the required fields at the sidebar.") |
| 117 | + |
| 118 | +if "messages" not in st.session_state: |
| 119 | + st.session_state["messages"] = [ |
| 120 | + AIMessage(content="Hello! I'm a helpful assistant."), |
| 121 | + ] |
| 122 | + |
| 123 | +# Show chat messages |
| 124 | +for message in st.session_state.messages: |
| 125 | + with st.chat_message(message.type): |
| 126 | + if isinstance(message.content, str): |
| 127 | + st.markdown(message.content) |
| 128 | + else: |
| 129 | + for item in message.content: |
| 130 | + if item["type"] == "text": |
| 131 | + st.markdown(item["text"]) |
| 132 | + elif item["type"] == "image_url": |
| 133 | + st.image(item["image_url"]["url"]) |
| 134 | + |
| 135 | + |
| 136 | +# Receive user input |
| 137 | +uploaded_file = st.file_uploader("画像をアップロード", type=["png", "jpg", "jpeg"], key="file_uploader") |
| 138 | +if prompt := st.chat_input(disabled=not is_configured()): |
| 139 | + user_message_content = [{"type": "text", "text": prompt}] |
| 140 | + if uploaded_file: |
| 141 | + image_bytes = uploaded_file.getvalue() |
| 142 | + base64_image = image_to_base64(image_bytes) |
| 143 | + image_url = f"data:image/jpeg;base64,{base64_image}" |
| 144 | + user_message_content.append({"type": "image_url", "image_url": {"url": image_url}}) |
| 145 | + |
| 146 | + user_message = HumanMessage(content=user_message_content) |
| 147 | + st.session_state.messages.append(user_message) |
| 148 | + |
| 149 | + with st.chat_message("user"): |
| 150 | + for item in user_message_content: |
| 151 | + if item["type"] == "text": |
| 152 | + st.markdown(item["text"]) |
| 153 | + elif item["type"] == "image_url": |
| 154 | + st.image(item["image_url"]["url"]) |
| 155 | + |
| 156 | + with st.spinner("Thinking..."): |
| 157 | + with st.chat_message("assistant"): |
| 158 | + message_placeholder = st.empty() |
| 159 | + full_response = "" |
| 160 | + llm = get_model() |
| 161 | + callbacks = [StreamlitCallbackHandler(st.container())] |
| 162 | + |
| 163 | + try: |
| 164 | + if stream_mode: |
| 165 | + for chunk in llm.stream(st.session_state.messages): |
| 166 | + if chunk.content is not None: |
| 167 | + full_response += chunk.content |
| 168 | + message_placeholder.markdown(full_response + "▌") |
| 169 | + message_placeholder.markdown(full_response) |
| 170 | + else: |
| 171 | + response = llm.invoke(input=st.session_state.messages) |
| 172 | + full_response = response.content if hasattr(response, "content") else str(response) |
| 173 | + message_placeholder.markdown(full_response) |
| 174 | + |
| 175 | + st.session_state.messages.append(AIMessage(content=full_response)) |
| 176 | + |
| 177 | + except APITimeoutError as e: |
| 178 | + logger.exception(f"APIタイムアウトエラーが発生しました: {e}") |
| 179 | + st.error(f"APIタイムアウトエラーが発生しました: {e}") |
| 180 | + st.warning("再度お試しいただくか、接続を確認してください。") |
| 181 | + except APIConnectionError as e: |
| 182 | + logger.exception(f"API接続エラーが発生しました: {e}") |
| 183 | + st.error(f"API接続エラーが発生しました: {e}") |
| 184 | + st.warning("ネットワーク接続を確認してください。") |
| 185 | + except APIStatusError as e: |
| 186 | + logger.exception(f"APIステータスエラーが発生しました: {e.status_code} - {e.response}") |
| 187 | + st.error(f"APIステータスエラーが発生しました: {e.status_code} - {e.response}") |
| 188 | + st.warning("Azure OpenAIの設定(デプロイメント名、APIバージョンなど)を確認してください。") |
| 189 | + except Exception as e: |
| 190 | + logger.exception(f"予期せぬエラーが発生しました: {e}") |
| 191 | + st.error(f"予期せぬエラーが発生しました: {e}") |
0 commit comments