25 changes: 25 additions & 0 deletions apps/15_streamlit_chat_slm/README.md
@@ -0,0 +1,25 @@
# Streamlit Chat with SLM

## Overview

This app chats with a small language model (SLM), phi3, served locally by Ollama: `chat.py` is a minimal CLI chat, `summarize.py` summarizes a web page with a LangGraph map-reduce pipeline, and `main.py` is a Streamlit chat UI.

```shell
# Run Ollama server
$ ollama serve

# Pull the phi3 model
$ ollama pull phi3

# Run a simple chat with Ollama
$ poetry run python apps/15_streamlit_chat_slm/chat.py

# Run summarization with SLM
$ poetry run python apps/15_streamlit_chat_slm/summarize.py

# Run streamlit app
$ poetry run python -m streamlit run apps/15_streamlit_chat_slm/main.py
```
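
Before running the scripts, you can optionally confirm that the Ollama server is reachable and that the phi3 model has been pulled (a quick sanity check; the exact output depends on your Ollama version):

```shell
# The root endpoint returns a short status message when the server is up
$ curl http://localhost:11434

# phi3 should appear among the locally pulled models
$ ollama list
```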

## References

- [ChatOllama](https://python.langchain.com/docs/integrations/chat/ollama/)
- [Summarize Text](https://python.langchain.com/docs/tutorials/summarization/)
43 changes: 43 additions & 0 deletions apps/15_streamlit_chat_slm/chat.py
@@ -0,0 +1,43 @@
import argparse
import logging

from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langchain_ollama import ChatOllama


def init_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        prog="slm_chat",
        description="Chat with SLM model",
    )
    parser.add_argument("-m", "--model", default="phi3")
    parser.add_argument("-s", "--system", default="You are a helpful assistant.")
    parser.add_argument("-p", "--prompt", default="What is the capital of France?")
    parser.add_argument("-v", "--verbose", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = init_args()

    # Set verbose mode
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    # Parse .env file and set environment variables
    load_dotenv()

    # Create the chat model client backed by the local Ollama server
    llm = ChatOllama(
        model=args.model,
        temperature=0,
    )

    # Send the system and user prompts, then print the full AIMessage as JSON
    ai_msg: AIMessage = llm.invoke(
        input=[
            ("system", args.system),
            ("human", args.prompt),
        ]
    )
    print(ai_msg.model_dump_json(indent=2))
    # print(ai_msg.content)
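
For reference, the flags defined in `init_args` above can be combined on the command line; the values below are simply the script's own defaults spelled out explicitly:

```shell
$ poetry run python apps/15_streamlit_chat_slm/chat.py \
    --model phi3 \
    --system "You are a helpful assistant." \
    --prompt "What is the capital of France?" \
    --verbose
```
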
67 changes: 67 additions & 0 deletions apps/15_streamlit_chat_slm/main.py
@@ -0,0 +1,67 @@
import streamlit as st
from dotenv import load_dotenv
from langchain_ollama import ChatOllama

load_dotenv()

SUPPORTED_MODELS = [
    "phi3",
]
with st.sidebar:
    slm_model = st.selectbox(
        label="Model",
        options=SUPPORTED_MODELS,
        index=0,
    )
    "[Azure Portal](https://portal.azure.com/)"
    "[Azure OpenAI Studio](https://oai.azure.com/resource/overview)"
    "[View the source code](https://github.com/ks6088ts-labs/workshop-azure-openai/blob/main/apps/15_streamlit_chat_slm/main.py)"


def is_configured():
    return slm_model in SUPPORTED_MODELS


st.title("15_streamlit_chat_slm")

if not is_configured():
    st.warning("Please fill in the required fields at the sidebar.")

if "messages" not in st.session_state:
    st.session_state["messages"] = [
        {
            "role": "assistant",
            "content": "Hello! I'm a helpful assistant.",
        }
    ]

# Show chat messages
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

# Receive user input
if prompt := st.chat_input(disabled=not is_configured()):
    client = ChatOllama(
        model=slm_model,
        temperature=0,
    )

    st.session_state.messages.append(
        {
            "role": "user",
            "content": prompt,
        }
    )
    st.chat_message("user").write(prompt)
    with st.spinner("Thinking..."):
        response = client.invoke(
            input=st.session_state.messages,
        )
        msg = response.content
        st.session_state.messages.append(
            {
                "role": "assistant",
                "content": msg,
            }
        )
        st.chat_message("assistant").write(msg)
162 changes: 162 additions & 0 deletions apps/15_streamlit_chat_slm/summarize.py
@@ -0,0 +1,162 @@
import asyncio
import operator
from os import getenv
from typing import Annotated, Literal, TypedDict

from langchain.chains.combine_documents.reduce import acollapse_docs, split_list_of_docs
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI
from langchain_text_splitters import CharacterTextSplitter
from langgraph.constants import Send
from langgraph.graph import END, START, StateGraph

token_max = 1000
url = "https://lilianweng.github.io/posts/2023-06-23-agent/"

llm_ollama = ChatOllama(
    model="phi3",
    temperature=0,
)
llm_azure_openai = AzureChatOpenAI(
    temperature=0,
    api_key=getenv("AZURE_OPENAI_API_KEY"),
    api_version=getenv("AZURE_OPENAI_API_VERSION"),
    azure_endpoint=getenv("AZURE_OPENAI_ENDPOINT"),
    model=getenv("AZURE_OPENAI_GPT_MODEL"),
)
# Use the Ollama model
llm = llm_ollama


def length_function(documents: list[Document]) -> int:
    """Get number of tokens for input contents."""
    return sum(llm.get_num_tokens(doc.page_content) for doc in documents)


# This will be the overall state of the main graph.
# It will contain the input document contents, corresponding
# summaries, and a final summary.
class OverallState(TypedDict):
    # Notice here we use the operator.add
    # This is because we want to combine all the summaries we generate
    # from individual nodes back into one list - this is essentially
    # the "reduce" part
    contents: list[str]
    summaries: Annotated[list, operator.add]
    collapsed_summaries: list[Document]
    final_summary: str


# This will be the state of the node that we will "map" all
# documents to in order to generate summaries
class SummaryState(TypedDict):
    content: str


map_prompt = ChatPromptTemplate.from_messages([("system", "Write a concise summary of the following:\\n\\n{context}")])

map_chain = map_prompt | llm | StrOutputParser()


# Here we generate a summary, given a document
async def generate_summary(state: SummaryState):
    response = await map_chain.ainvoke(state["content"])
    return {"summaries": [response]}


# Here we define the logic to map out over the documents
# We will use this as an edge in the graph
def map_summaries(state: OverallState):
    # We will return a list of `Send` objects
    # Each `Send` object consists of the name of a node in the graph
    # as well as the state to send to that node
    return [Send("generate_summary", {"content": content}) for content in state["contents"]]


def collect_summaries(state: OverallState):
    return {"collapsed_summaries": [Document(summary) for summary in state["summaries"]]}


# Also available via the hub: `hub.pull("rlm/reduce-prompt")`
reduce_template = """
The following is a set of summaries:
{docs}
Take these and distill it into a final, consolidated summary
of the main themes.
"""

reduce_prompt = ChatPromptTemplate([("human", reduce_template)])

reduce_chain = reduce_prompt | llm | StrOutputParser()


# Add node to collapse summaries
async def collapse_summaries(state: OverallState):
    doc_lists = split_list_of_docs(state["collapsed_summaries"], length_function, token_max)
    results = []
    for doc_list in doc_lists:
        results.append(await acollapse_docs(doc_list, reduce_chain.ainvoke))

    return {"collapsed_summaries": results}


# This represents a conditional edge in the graph that determines
# if we should collapse the summaries or not
def should_collapse(
    state: OverallState,
) -> Literal["collapse_summaries", "generate_final_summary"]:
    num_tokens = length_function(state["collapsed_summaries"])
    if num_tokens > token_max:
        return "collapse_summaries"
    else:
        return "generate_final_summary"


# Here we will generate the final summary
async def generate_final_summary(state: OverallState):
    response = await reduce_chain.ainvoke(state["collapsed_summaries"])
    return {"final_summary": response}


async def main():
    # Construct the graph
    # Nodes:
    graph = StateGraph(OverallState)
    graph.add_node("generate_summary", generate_summary)  # same as before
    graph.add_node("collect_summaries", collect_summaries)
    graph.add_node("collapse_summaries", collapse_summaries)
    graph.add_node("generate_final_summary", generate_final_summary)

    # Edges:
    graph.add_conditional_edges(START, map_summaries, ["generate_summary"])
    graph.add_edge("generate_summary", "collect_summaries")
    graph.add_conditional_edges("collect_summaries", should_collapse)
    graph.add_conditional_edges("collapse_summaries", should_collapse)
    graph.add_edge("generate_final_summary", END)

    app = graph.compile()

    # create graph image
    app.get_graph().draw_mermaid_png(output_file_path="docs/images/15_streamlit_chat_slm.summarize_graph.png")

    text_splitter = CharacterTextSplitter.from_tiktoken_encoder(chunk_size=1000, chunk_overlap=0)

    loader = WebBaseLoader(web_path=url)
    docs = loader.load()

    split_docs = text_splitter.split_documents(docs)
    print(f"Generated {len(split_docs)} documents.")

    async for step in app.astream(
        {"contents": [doc.page_content for doc in split_docs]},
        {"recursion_limit": 10},
    ):
        print(list(step.keys()))
        print(step)


asyncio.run(main())
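
The script writes the rendered graph to docs/images/15_streamlit_chat_slm.summarize_graph.png, so it presumably expects to be run from the repository root with that directory present (an assumption based on the relative output_file_path above):

```shell
# Run from the repository root; create the output directory first if it does not exist
$ mkdir -p docs/images
$ poetry run python apps/15_streamlit_chat_slm/summarize.py
```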