Skip to content

Commit 5bb7b2c

Browse files
Update model.py
1 parent 368bb75 commit 5bb7b2c

File tree

1 file changed

+57
-34
lines changed

1 file changed

+57
-34
lines changed

model.py

Lines changed: 57 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,50 +7,73 @@
77
from langchain.llms import CTransformers
88
from langchain.chains import ConversationalRetrievalChain
99

10-
st.title("Conversational Retrieval System")
10+
def add_vertical_space(spaces=1):
    """Add vertical padding to the sidebar by rendering `spaces` divider lines.

    Each "---" is rendered by Streamlit as a horizontal rule; here it is used
    purely as visual spacing between sidebar sections.
    """
    remaining = spaces
    while remaining > 0:
        st.sidebar.markdown("---")
        remaining -= 1
1113

12-
DB_FAISS_PATH = "vectorstore/db_faiss"
13-
TEMP_DIR = "temp"
14+
def main():
    """Streamlit entry point for the Llama-2-GGML CSV chatbot.

    Renders the page and sidebar, accepts a CSV upload, builds a FAISS vector
    index over the file's rows, and answers the user's query with a local
    Llama-2 GGML model via a ConversationalRetrievalChain.

    Side effects: writes the uploaded CSV to a temp directory, persists the
    FAISS index under DB_FAISS_PATH, and removes the temp copy when done.
    """
    st.set_page_config(page_title="Llama-2-GGML CSV Chatbot")
    st.title("Llama-2-GGML CSV Chatbot")

    st.sidebar.title("About")
    st.sidebar.markdown('''
    The Llama-2-GGML CSV Chatbot uses the **Llama-2-7B-Chat-GGML** model.

    ### 🔄Bot evolving, stay tuned!

    ## Useful Links 🔗

    - **Model:** [Llama-2-7B-Chat-GGML](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/tree/main) 📚
    - **GitHub:** [ThisIs-Developer/Llama-2-GGML-CSV-Chatbot](https://github.com/ThisIs-Developer/Llama-2-GGML-CSV-Chatbot) 💬
    ''')

    DB_FAISS_PATH = "vectorstore/db_faiss"
    TEMP_DIR = "temp"

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(TEMP_DIR, exist_ok=True)

    uploaded_file = st.sidebar.file_uploader("Upload CSV file", type=['csv'])

    add_vertical_space(1)
    st.sidebar.write('Made by [@ThisIs-Developer](https://huggingface.co/ThisIs-Developer)')

    if uploaded_file is not None:
        # basename() strips any directory components from the (untrusted)
        # client-supplied filename, preventing path traversal out of TEMP_DIR.
        file_path = os.path.join(TEMP_DIR, os.path.basename(uploaded_file.name))
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getvalue())

        try:
            st.write(f"Uploaded file: {uploaded_file.name}")
            st.write("Processing CSV file...")

            # Load the CSV: one document per row, comma-delimited, UTF-8.
            loader = CSVLoader(file_path=file_path, encoding="utf-8", csv_args={'delimiter': ','})
            data = loader.load()

            # Chunk documents for embedding; 20-char overlap keeps row context
            # across chunk boundaries.
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20)
            text_chunks = text_splitter.split_documents(data)

            st.write(f"Total text chunks: {len(text_chunks)}")

            # Embed chunks and persist the FAISS index to disk.
            embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
            docsearch = FAISS.from_documents(text_chunks, embeddings)
            docsearch.save_local(DB_FAISS_PATH)

            # Local quantized Llama-2 model; low temperature keeps answers
            # close to the retrieved CSV content.
            llm = CTransformers(model="models/llama-2-7b-chat.ggmlv3.q4_0.bin",
                                model_type="llama",
                                max_new_tokens=512,
                                temperature=0.1)

            qa = ConversationalRetrievalChain.from_llm(llm, retriever=docsearch.as_retriever())

            st.write("Enter your query:")
            query = st.text_input("Input Prompt:")
            if query:
                with st.spinner("Processing your question..."):
                    # No multi-turn memory yet: each question starts with a
                    # fresh, empty history.
                    chat_history = []
                    result = qa({"question": query, "chat_history": chat_history})
                st.write("Response:", result['answer'])
        finally:
            # Always clean up the temp copy, even if indexing or QA raised.
            os.remove(file_path)
77+
78+
# Standard script entry point: launch the Streamlit app only when this file
# is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

0 commit comments

Comments (0)