forked from TheAurder/Private-LLM-for-Data-Integration
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
105 lines (79 loc) · 3.95 KB
/
app.py
File metadata and controls
105 lines (79 loc) · 3.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
from llm_chains import load_normal_chain, load_pdf_chat_chain
from langchain.memory import StreamlitChatMessageHistory
from utils import save_chat_history_json, get_timestamp, load_chat_history_json
from pdf_handler import add_documents_to_db
from html_templates import get_bot_template, get_user_template, css
import yaml
import os
# Load the application settings once at import time. The rest of this module
# reads config["chat_history_path"] from the resulting mapping.
# The encoding is given explicitly so the file is decoded the same way on
# every platform instead of depending on the locale's default encoding.
with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
def load_chain(chat_history):
    """Build the LLM chain that matches the current chat mode.

    Args:
        chat_history: Message-history object passed unchanged to the
            chain loader.

    Returns:
        ``load_pdf_chat_chain(chat_history)`` when the ``pdf_chat`` session
        flag is truthy, otherwise ``load_normal_chain(chat_history)``.
    """
    if not st.session_state.pdf_chat:
        return load_normal_chain(chat_history)

    print("loading pdf chat chain")
    return load_pdf_chat_chain(chat_history)
def clear_input_field():
    """Move the typed text into ``user_question`` and blank the input widget.

    Used as a widget callback. When ``user_question`` already holds text the
    function returns without touching anything, so a second callback fired
    for the same interaction cannot overwrite the pending question.
    """
    state = st.session_state
    if state.user_question != "":
        return

    state.user_question = state.user_input
    state.user_input = ""
def set_send_input():
    """Flag that the input should be sent, then capture and clear the text.

    Registered as the ``on_change`` callback of the message text input.
    """
    st.session_state["send_input"] = True
    clear_input_field()
def toggle_pdf_chat():
    """Switch PDF chat mode on.

    Despite the name, this only ever sets the ``pdf_chat`` session flag to
    ``True``; it never turns it off. Registered as the ``on_change`` callback
    of the PDF uploader.
    """
    st.session_state["pdf_chat"] = True
def save_chat_history():
    """Write the current chat history to a JSON file in the history folder.

    Nothing is written when the history is empty. For a chat that is still
    the ``"new_session"`` placeholder, a file name is generated from
    ``get_timestamp()`` and stored in ``st.session_state.new_session_key``;
    otherwise the selected ``session_key`` is used as the file name.
    """
    state = st.session_state
    if not state.history:
        return

    if state.session_key == "new_session":
        # Remember the generated name in session state so later code can
        # tell which file this chat was saved to.
        state.new_session_key = get_timestamp() + ".json"
        file_name = state.new_session_key
    else:
        file_name = state.session_key

    # os.path.join instead of "+" so the path is correct whether or not the
    # configured folder ends with a path separator.
    target_path = os.path.join(config["chat_history_path"], file_name)
    save_chat_history_json(state.history, target_path)
def main():
    """Render the chat page and handle one run of the Streamlit script.

    In order, this function:

    1. Draws the title, CSS and sidebar, and lists the saved sessions found
       in ``config["chat_history_path"]``.
    2. Seeds the session-state keys on the first run and selects the session
       shown in the sidebar selectbox.
    3. Loads the selected session's history (or starts with an empty one).
    4. Draws the message input, the send button and the PDF uploader, and
       passes uploaded PDFs to ``add_documents_to_db``.
    5. Runs the LLM chain on a pending question, renders the messages in
       reverse order, and saves the history.
    """
    st.title("Private LLM Chat App")
    st.write(css, unsafe_allow_html=True)
    st.sidebar.title("Chat Sessions")

    history_dir = config["chat_history_path"]
    chat_sessions = ["new_session"] + os.listdir(history_dir)

    # First run for this browser session: create every key that is read below.
    if "send_input" not in st.session_state:
        st.session_state.session_key = "new_session"
        st.session_state.send_input = False
        st.session_state.user_question = ""
        st.session_state.new_session_key = None
        st.session_state.session_index_tracker = "new_session"

    # A new chat was saved under a generated file name: make the selectbox
    # point at that file and consume the one-shot marker.
    if st.session_state.session_key == "new_session" and st.session_state.new_session_key is not None:
        st.session_state.session_index_tracker = st.session_state.new_session_key
        st.session_state.new_session_key = None

    try:
        index = chat_sessions.index(st.session_state.session_index_tracker)
    except ValueError:
        # The tracked session is not in the folder listing (for example the
        # file was removed); fall back to the "new_session" entry instead of
        # failing on every rerun.
        index = 0

    st.sidebar.selectbox("Select a chat session", chat_sessions, key="session_key", index=index)
    st.sidebar.toggle("PDF Chat", key="pdf_chat", value=False)

    if st.session_state.session_key != "new_session":
        # os.path.join keeps the path valid with or without a trailing
        # separator on the configured folder.
        st.session_state.history = load_chat_history_json(
            os.path.join(history_dir, st.session_state.session_key)
        )
    else:
        st.session_state.history = []

    # Reads and writes st.session_state["history"].
    chat_history = StreamlitChatMessageHistory(key="history")

    st.text_input("Type your message here", key="user_input", on_change=set_send_input)

    # Only the second column is used; the first is left empty.
    _, send_button_column = st.columns(2)
    # Created before the send button so its place on the page is fixed here.
    chat_container = st.container()
    with send_button_column:
        send_button = st.button("Send", key="send_button", on_click=clear_input_field)

    uploaded_pdf = st.sidebar.file_uploader(
        "Upload a pdf file",
        accept_multiple_files=True,
        key="pdf_upload",
        type=["pdf"],
        on_change=toggle_pdf_chat,
    )
    if uploaded_pdf:
        with st.spinner("Processing pdf..."):
            add_documents_to_db(uploaded_pdf)

    if send_button or st.session_state.send_input:
        if st.session_state.user_question != "":
            llm_chain = load_chain(chat_history)
            # The return value is not used here.
            # NOTE(review): presumably the chain records the exchange in
            # chat_history itself; confirm in llm_chains.
            llm_chain.run(st.session_state.user_question)
            st.session_state.user_question = ""
        st.session_state.send_input = False

    messages = chat_history.messages
    if messages:
        with chat_container:
            st.write("Chat History:")
            for message in reversed(messages):
                # NOTE(review): message text is rendered with
                # unsafe_allow_html=True; confirm the templates escape it.
                if message.type == "human":
                    st.write(get_user_template(message.content), unsafe_allow_html=True)
                else:
                    st.write(get_bot_template(message.content), unsafe_allow_html=True)

    save_chat_history()
# Script entry point: build the page only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()