|
| 1 | +# GitHub: https://github.com/naotaka1128/llm_app_codes/chapter_010/main.py |
| 2 | + |
| 3 | +from os import getenv |
| 4 | + |
| 5 | +import streamlit as st |
| 6 | +from dotenv import load_dotenv |
| 7 | +from langchain.agents import AgentExecutor, create_tool_calling_agent |
| 8 | +from langchain.memory import ConversationBufferWindowMemory |
| 9 | +from langchain_community.callbacks import StreamlitCallbackHandler |
| 10 | +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| 11 | +from langchain_core.runnables import RunnableConfig |
| 12 | + |
| 13 | +# models |
| 14 | +from langchain_openai import AzureChatOpenAI |
| 15 | +from streamlit.runtime.scriptrunner import get_script_run_ctx |
| 16 | +from tools.fetch_contents import fetch_contents |
| 17 | + |
| 18 | +CUSTOM_SYSTEM_PROMPT = """ |
| 19 | +あなたは上鳥羽製作所の総務係です。 |
| 20 | +社内からのお問い合わせに対して、誠実かつ正確な回答を心がけてください。 |
| 21 | +
|
| 22 | +上鳥羽製作所の社内規則に関する一般的な知識についてのみ答えます。 |
| 23 | +それ以外のトピックに関する質問には、丁重にお断りしてください。 |
| 24 | +
|
| 25 | +回答の正確性を保証するため「上鳥羽製作所」に関する質問を受けた際は、 |
| 26 | +必ずツールを使用して回答を見つけてください。 |
| 27 | +
|
| 28 | +ユーザーが質問に使用した言語で回答してください。 |
| 29 | +例えば、ユーザーが英語で質問された場合は、必ず英語で回答してください。 |
| 30 | +スペイン語ならスペイン語で回答してください。 |
| 31 | +
|
| 32 | +回答する際、不明な点がある場合は、ユーザーに確認しましょう。 |
| 33 | +それにより、ユーザーの意図を把握して、適切な回答を行えます。 |
| 34 | +
|
| 35 | +例えば、ユーザーが「オフィスはどこにありますか?」と質問した場合、 |
| 36 | +まずユーザーの居住都道府県を尋ねてください。 |
| 37 | +
|
| 38 | +日本全国のオフィスの場所を知りたいユーザーはほとんどいません。 |
| 39 | +自分の都道府県内のオフィスの場所を知りたいのです。 |
| 40 | +したがって、日本全国のオフィスを検索して回答するのではなく、 |
| 41 | +ユーザーの意図を本当に理解するまで回答しないでください。 |
| 42 | +
|
| 43 | +あくまでこれは一例です。 |
| 44 | +その他のケースでもユーザーの意図を理解し、適切な回答を行ってください。 |
| 45 | +""" |
| 46 | + |
| 47 | +with st.sidebar: |
| 48 | + azure_openai_endpoint = st.text_input( |
| 49 | + label="AZURE_OPENAI_ENDPOINT", |
| 50 | + value=getenv("AZURE_OPENAI_ENDPOINT"), |
| 51 | + key="AZURE_OPENAI_ENDPOINT", |
| 52 | + type="default", |
| 53 | + ) |
| 54 | + azure_openai_api_key = st.text_input( |
| 55 | + label="AZURE_OPENAI_API_KEY", |
| 56 | + key="AZURE_OPENAI_API_KEY", |
| 57 | + type="password", |
| 58 | + ) |
| 59 | + azure_openai_api_version = st.text_input( |
| 60 | + label="AZURE_OPENAI_API_VERSION", |
| 61 | + value=getenv("AZURE_OPENAI_API_VERSION"), |
| 62 | + key="AZURE_OPENAI_API_VERSION", |
| 63 | + type="default", |
| 64 | + ) |
| 65 | + azure_openai_gpt_model = st.text_input( |
| 66 | + label="AZURE_OPENAI_GPT_MODEL", |
| 67 | + value=getenv("AZURE_OPENAI_GPT_MODEL"), |
| 68 | + key="AZURE_OPENAI_GPT_MODEL", |
| 69 | + type="default", |
| 70 | + ) |
| 71 | + "[Go to Azure Portal to get an Azure OpenAI API key](https://portal.azure.com/)" |
| 72 | + "[Go to Azure OpenAI Studio](https://oai.azure.com/resource/overview)" |
| 73 | + "[View the source code](https://github.com/ks6088ts-labs/workshop-azure-openai/blob/main/apps/4_streamlit_chat_history/main.py)" |
| 74 | + |
| 75 | +if not azure_openai_api_key or not azure_openai_endpoint or not azure_openai_api_version or not azure_openai_gpt_model: |
| 76 | + st.warning("サイドバーに Azure OpenAI の設定を入力してください") |
| 77 | + st.stop() |
| 78 | + |
| 79 | + |
| 80 | +def get_session_id(): |
| 81 | + return get_script_run_ctx().session_id |
| 82 | + |
| 83 | + |
| 84 | +def init_page(): |
| 85 | + st.title("Streamlit Chat") |
| 86 | + st.write(f"Session ID: {get_session_id()}") |
| 87 | + |
| 88 | + |
| 89 | +def init_messages(): |
| 90 | + clear_button = st.sidebar.button("Clear Conversation", key="clear") |
| 91 | + if clear_button or "messages" not in st.session_state: |
| 92 | + welcome_message = "ベアーモバイル カスタマーサポートへようこそ。ご質問をどうぞ🐻" |
| 93 | + st.session_state.messages = [{"role": "assistant", "content": welcome_message}] |
| 94 | + st.session_state["memory"] = ConversationBufferWindowMemory( |
| 95 | + return_messages=True, memory_key="chat_history", k=10 |
| 96 | + ) |
| 97 | + |
| 98 | + |
| 99 | +def select_model(): |
| 100 | + return AzureChatOpenAI( |
| 101 | + temperature=0, |
| 102 | + api_key=azure_openai_api_key, |
| 103 | + api_version=azure_openai_api_version, |
| 104 | + azure_endpoint=azure_openai_endpoint, |
| 105 | + model=azure_openai_gpt_model, |
| 106 | + ) |
| 107 | + |
| 108 | + |
| 109 | +def create_agent(): |
| 110 | + ## https://learn.deeplearning.ai/functions-tools-agents-langchain/lesson/7/conversational-agent |
| 111 | + tools = [ |
| 112 | + fetch_contents, |
| 113 | + ] |
| 114 | + prompt = ChatPromptTemplate.from_messages( |
| 115 | + [ |
| 116 | + ("system", CUSTOM_SYSTEM_PROMPT), |
| 117 | + MessagesPlaceholder(variable_name="chat_history"), |
| 118 | + ("user", "{input}"), |
| 119 | + MessagesPlaceholder(variable_name="agent_scratchpad"), |
| 120 | + ] |
| 121 | + ) |
| 122 | + llm = select_model() |
| 123 | + agent = create_tool_calling_agent(llm, tools, prompt) |
| 124 | + return AgentExecutor(agent=agent, tools=tools, verbose=True, memory=st.session_state["memory"]) |
| 125 | + |
| 126 | + |
| 127 | +def main(): |
| 128 | + init_page() |
| 129 | + init_messages() |
| 130 | + customer_support_agent = create_agent() |
| 131 | + |
| 132 | + for msg in st.session_state["memory"].chat_memory.messages: |
| 133 | + st.chat_message(msg.type).write(msg.content) |
| 134 | + |
| 135 | + if prompt := st.chat_input(placeholder="法人で契約することはできるの?"): |
| 136 | + st.chat_message("user").write(prompt) |
| 137 | + |
| 138 | + with st.chat_message("assistant"): |
| 139 | + st_cb = StreamlitCallbackHandler(st.container(), expand_new_thoughts=True) |
| 140 | + response = customer_support_agent.invoke({"input": prompt}, config=RunnableConfig({"callbacks": [st_cb]})) |
| 141 | + st.write(response["output"]) |
| 142 | + |
| 143 | + |
| 144 | +if __name__ == "__main__": |
| 145 | + load_dotenv() |
| 146 | + |
| 147 | + main() |
0 commit comments