
Commit dbd2247

Merge pull request #165 from ks6088ts-labs/feature/issue-164_slm
add SLM based chat app
2 parents 54bd03d + 4779767

7 files changed: +2200 -1653 lines changed
Lines changed: 25 additions & 0 deletions
# Streamlit Chat with SLM

## Overview

```shell
# Run Ollama server
$ ollama serve

# Pull the phi3 model
$ ollama pull phi3

# Run a simple chat with Ollama
$ poetry run python apps/15_streamlit_chat_slm/chat.py

# Run summarization with SLM
$ poetry run python apps/15_streamlit_chat_slm/summarize.py

# Run the Streamlit app
$ poetry run python -m streamlit run apps/15_streamlit_chat_slm/main.py
```

## References

- [ChatOllama](https://python.langchain.com/docs/integrations/chat/ollama/)
- [Summarize Text](https://python.langchain.com/docs/tutorials/summarization/)
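
Before the code files, it may help to see the core call in isolation. A minimal sketch (not part of this commit) of the `ChatOllama` usage the apps below build on, assuming `ollama serve` is running locally and `phi3` has already been pulled:

```python
# Minimal sketch: one ChatOllama round trip.
# Assumes a local Ollama server and the phi3 model already pulled.
from langchain_ollama import ChatOllama

llm = ChatOllama(model="phi3", temperature=0)
reply = llm.invoke([("human", "What is the capital of France?")])
print(reply.content)
```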

apps/15_streamlit_chat_slm/chat.py

Lines changed: 43 additions & 0 deletions
import argparse
import logging

from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langchain_ollama import ChatOllama


def init_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        prog="slm_chat",
        description="Chat with SLM model",
    )
    parser.add_argument("-m", "--model", default="phi3")
    parser.add_argument("-s", "--system", default="You are a helpful assistant.")
    parser.add_argument("-p", "--prompt", default="What is the capital of France?")
    parser.add_argument("-v", "--verbose", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = init_args()

    # Set verbose mode
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    # Parse .env file and set environment variables
    load_dotenv()

    llm = ChatOllama(
        model=args.model,
        temperature=0,
    )

    ai_msg: AIMessage = llm.invoke(
        input=[
            ("system", args.system),
            ("human", args.prompt),
        ]
    )
    print(ai_msg.model_dump_json(indent=2))
    # print(ai_msg.content)
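
chat.py prints the serialized `AIMessage` in one shot after `invoke` returns. For longer generations, a streamed variant could print tokens as they arrive; here is a hedged sketch (not part of this commit) using the `stream` method that `ChatOllama` inherits from LangChain's `BaseChatModel`:

```python
# Hypothetical streaming variant of chat.py's single invoke call.
# stream() yields message chunks whose .content holds the new tokens.
from langchain_ollama import ChatOllama

llm = ChatOllama(model="phi3", temperature=0)
for chunk in llm.stream([("human", "What is the capital of France?")]):
    print(chunk.content, end="", flush=True)
print()
```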

apps/15_streamlit_chat_slm/main.py

Lines changed: 67 additions & 0 deletions
import streamlit as st
from dotenv import load_dotenv
from langchain_ollama import ChatOllama

load_dotenv()

SUPPORTED_MODELS = [
    "phi3",
]
with st.sidebar:
    slm_model = st.selectbox(
        label="Model",
        options=SUPPORTED_MODELS,
        index=0,
    )
    # Streamlit "magic": bare markdown strings render as links in the sidebar
    "[Azure Portal](https://portal.azure.com/)"
    "[Azure OpenAI Studio](https://oai.azure.com/resource/overview)"
    "[View the source code](https://github.com/ks6088ts-labs/workshop-azure-openai/blob/main/apps/15_streamlit_chat_slm/main.py)"


def is_configured():
    return slm_model in SUPPORTED_MODELS


st.title("15_streamlit_chat_slm")

if not is_configured():
    st.warning("Please fill in the required fields in the sidebar.")

if "messages" not in st.session_state:
    st.session_state["messages"] = [
        {
            "role": "assistant",
            "content": "Hello! I'm a helpful assistant.",
        }
    ]

# Show chat messages
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

# Receive user input
if prompt := st.chat_input(disabled=not is_configured()):
    client = ChatOllama(
        model=slm_model,
        temperature=0,
    )

    st.session_state.messages.append(
        {
            "role": "user",
            "content": prompt,
        }
    )
    st.chat_message("user").write(prompt)
    with st.spinner("Thinking..."):
        response = client.invoke(
            input=st.session_state.messages,
        )
        msg = response.content
        st.session_state.messages.append(
            {
                "role": "assistant",
                "content": msg,
            }
        )
        st.chat_message("assistant").write(msg)
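
Note that main.py hands the raw `st.session_state.messages` dicts straight to `client.invoke`: LangChain accepts OpenAI-style `role`/`content` dicts and coerces them into message objects. A small sketch of that coercion, assuming `convert_to_messages` from `langchain_core.messages`:

```python
# Sketch of the dict-to-message coercion main.py relies on.
# Assumes convert_to_messages is exported by langchain_core.messages.
from langchain_core.messages import convert_to_messages

history = [
    {"role": "assistant", "content": "Hello! I'm a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
print(convert_to_messages(history))  # [AIMessage(...), HumanMessage(...)]
```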
apps/15_streamlit_chat_slm/summarize.py

Lines changed: 162 additions & 0 deletions
import asyncio
import operator
from os import getenv
from typing import Annotated, Literal, TypedDict

from dotenv import load_dotenv
from langchain.chains.combine_documents.reduce import acollapse_docs, split_list_of_docs
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI
from langchain_text_splitters import CharacterTextSplitter
from langgraph.constants import Send
from langgraph.graph import END, START, StateGraph

# Parse .env file and set environment variables (as chat.py and main.py do);
# without this, the AzureChatOpenAI constructor below fails when the
# AZURE_OPENAI_* variables are only defined in .env
load_dotenv()

token_max = 1000
url = "https://lilianweng.github.io/posts/2023-06-23-agent/"

llm_ollama = ChatOllama(
    model="phi3",
    temperature=0,
)
llm_azure_openai = AzureChatOpenAI(
    temperature=0,
    api_key=getenv("AZURE_OPENAI_API_KEY"),
    api_version=getenv("AZURE_OPENAI_API_VERSION"),
    azure_endpoint=getenv("AZURE_OPENAI_ENDPOINT"),
    model=getenv("AZURE_OPENAI_GPT_MODEL"),
)
# Use the Ollama model
llm = llm_ollama


def length_function(documents: list[Document]) -> int:
    """Get number of tokens for input contents."""
    return sum(llm.get_num_tokens(doc.page_content) for doc in documents)


# This will be the overall state of the main graph.
# It will contain the input document contents, corresponding
# summaries, and a final summary.
class OverallState(TypedDict):
    # Notice here we use operator.add
    # This is because we want to combine all the summaries we generate
    # from individual nodes back into one list - this is essentially
    # the "reduce" part
    contents: list[str]
    summaries: Annotated[list, operator.add]
    collapsed_summaries: list[Document]
    final_summary: str


# This will be the state of the node that we will "map" all
# documents to in order to generate summaries
class SummaryState(TypedDict):
    content: str


map_prompt = ChatPromptTemplate.from_messages([("system", "Write a concise summary of the following:\\n\\n{context}")])

map_chain = map_prompt | llm | StrOutputParser()


# Here we generate a summary, given a document
async def generate_summary(state: SummaryState):
    response = await map_chain.ainvoke(state["content"])
    return {"summaries": [response]}


# Here we define the logic to map out over the documents
# We will use this as an edge in the graph
def map_summaries(state: OverallState):
    # We will return a list of `Send` objects
    # Each `Send` object consists of the name of a node in the graph
    # as well as the state to send to that node
    return [Send("generate_summary", {"content": content}) for content in state["contents"]]


def collect_summaries(state: OverallState):
    return {"collapsed_summaries": [Document(summary) for summary in state["summaries"]]}


# Also available via the hub: `hub.pull("rlm/reduce-prompt")`
reduce_template = """
The following is a set of summaries:
{docs}
Take these and distill it into a final, consolidated summary
of the main themes.
"""

reduce_prompt = ChatPromptTemplate([("human", reduce_template)])

reduce_chain = reduce_prompt | llm | StrOutputParser()


# Add node to collapse summaries
async def collapse_summaries(state: OverallState):
    doc_lists = split_list_of_docs(state["collapsed_summaries"], length_function, token_max)
    results = []
    for doc_list in doc_lists:
        results.append(await acollapse_docs(doc_list, reduce_chain.ainvoke))

    return {"collapsed_summaries": results}


# This represents a conditional edge in the graph that determines
# if we should collapse the summaries or not
def should_collapse(
    state: OverallState,
) -> Literal["collapse_summaries", "generate_final_summary"]:
    num_tokens = length_function(state["collapsed_summaries"])
    if num_tokens > token_max:
        return "collapse_summaries"
    else:
        return "generate_final_summary"


# Here we will generate the final summary
async def generate_final_summary(state: OverallState):
    response = await reduce_chain.ainvoke(state["collapsed_summaries"])
    return {"final_summary": response}


async def main():
    # Construct the graph
    # Nodes:
    graph = StateGraph(OverallState)
    graph.add_node("generate_summary", generate_summary)  # same as before
    graph.add_node("collect_summaries", collect_summaries)
    graph.add_node("collapse_summaries", collapse_summaries)
    graph.add_node("generate_final_summary", generate_final_summary)

    # Edges:
    graph.add_conditional_edges(START, map_summaries, ["generate_summary"])
    graph.add_edge("generate_summary", "collect_summaries")
    graph.add_conditional_edges("collect_summaries", should_collapse)
    graph.add_conditional_edges("collapse_summaries", should_collapse)
    graph.add_edge("generate_final_summary", END)

    app = graph.compile()

    # create graph image
    app.get_graph().draw_mermaid_png(output_file_path="docs/images/15_streamlit_chat_slm.summarize_graph.png")

    text_splitter = CharacterTextSplitter.from_tiktoken_encoder(chunk_size=1000, chunk_overlap=0)

    loader = WebBaseLoader(web_path=url)
    docs = loader.load()

    split_docs = text_splitter.split_documents(docs)
    print(f"Generated {len(split_docs)} documents.")

    async for step in app.astream(
        {"contents": [doc.page_content for doc in split_docs]},
        {"recursion_limit": 10},
    ):
        print(list(step.keys()))
        print(step)


asyncio.run(main())
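
The map step above hinges on LangGraph's `Send`: `map_summaries` fans a single `OverallState` out into one `generate_summary` invocation per document, and the `Annotated[list, operator.add]` reducer merges the results back. A stripped-down sketch of that fan-out pattern with hypothetical node names and no LLM involved:

```python
# Minimal Send fan-out sketch (hypothetical names, no model calls).
# Each item becomes its own handle_item task; operator.add merges results.
import operator
from typing import Annotated, TypedDict

from langgraph.constants import Send
from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    items: list[str]
    results: Annotated[list, operator.add]


def handle_item(state: dict):
    # Receives only the payload that Send carried, not the whole State
    return {"results": [state["item"].upper()]}


def fan_out(state: State):
    return [Send("handle_item", {"item": item}) for item in state["items"]]


graph = StateGraph(State)
graph.add_node("handle_item", handle_item)
graph.add_conditional_edges(START, fan_out, ["handle_item"])
graph.add_edge("handle_item", END)

print(graph.compile().invoke({"items": ["map", "reduce"]})["results"])
# expected: ['MAP', 'REDUCE']
```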
