Commit 9f653cf
RAG MD Python Code for BYOC and Streamlit application
1 parent c6f0dba

11 files changed: +294 −0 lines
Dockerfile (embedding ingestion server) — 18 additions & 0 deletions

FROM python:3.10

RUN mkdir /app

WORKDIR /app

COPY requirements.txt requirements.txt
COPY main.py main.py
COPY start.sh start.sh

# Install the server dependencies.
RUN pip3 install -r requirements.txt

EXPOSE 8080

RUN chmod +x start.sh

CMD ["./start.sh"]
main.py (embedding ingestion server) — 68 additions & 0 deletions

"""The main model serving HTTP server. Creates the following endpoints:

/predict (POST) - model prediction endpoint
"""
import logging

import uvicorn
from fastapi import FastAPI, Body, Request, Response, status
from fastapi.responses import HTMLResponse, JSONResponse
from langchain.embeddings import LlamaCppEmbeddings
from langchain.vectorstores import Qdrant

fast_app = FastAPI()
model_path = "/opt/ds/model/deployed_model/7B/ggml-model-q4_0.bin"

def load_model(model_folder_directory):
    embedding = LlamaCppEmbeddings(model_path=model_folder_directory)
    return embedding

try:
    logging.info("Loading the model")
    embedding = load_model(model_path)
except Exception as e:
    logging.error("Error loading the model: %s", e)

# Qdrant connection settings. Hardcoded here; better supplied via
# environment variables or a config file.
url = "https://0ad84320-52a6-407d-9c82-375bf60e1fc6.us-east4-0.gcp.cloud.qdrant.io"
api_key = "a675QyMVF8SxqY9wNAssu4dwuIpbHGuXj8aZVDPBKX22AJeBGCOhqw"

qdrant = None
text_count = 0

@fast_app.get("/", response_class=HTMLResponse)
def read_root():
    return """
    <h2>Hello! Welcome to the model serving api.</h2>
    Check the <a href="/docs">api specs</a>.
    """

@fast_app.post("/predict")
def model_predict(request: Request, response: Response, data=Body(None)):
    global embedding, qdrant, text_count, url, api_key
    text = data.decode("utf-8")
    try:
        if qdrant is None:
            # from_texts expects an iterable of texts; wrap the single
            # sentence in a list so the string is not split into characters.
            qdrant = Qdrant.from_texts(
                [text],
                embedding,
                url=url,
                api_key=api_key,
                collection_name="my_documents",
            )
        else:
            qdrant.add_texts([text])
        text_count += 1
        result = "Sentence Added: Total sentences count is " + str(text_count)
    except Exception as e:
        result = "Error " + str(e)
    return result

# Health GET endpoint returning the health status.
@fast_app.get("/health")
def health(request: Request, response: Response):
    return {"status": "success"}

if __name__ == "__main__":
    uvicorn.run("main:fast_app", port=8080, reload=True)
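The /predict endpoint above ingests one sentence per POST body and upserts its embedding into the my_documents Qdrant collection. A minimal client sketch for exercising it — the host/port and the sample sentences are illustrative assumptions, not part of the commit:

import requests

SERVER = "http://localhost:8080"  # assumed address of the deployed container

sentences = [
    "The Eiffel Tower is located in Paris.",    # illustrative data
    "Qdrant is a vector similarity database.",
]

for sentence in sentences:
    # The server reads the raw request body and decodes it as UTF-8.
    resp = requests.post(
        SERVER + "/predict",
        data=sentence.encode("utf-8"),
        headers={"content-type": "application/text"},
    )
    print(resp.text)  # e.g. "Sentence Added: Total sentences count is 1"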
requirements.txt (embedding ingestion server) — 6 additions & 0 deletions

langchain
llama-cpp-python
requests
uvicorn
fastapi
qdrant-client
start.sh (embedding ingestion server) — 3 additions & 0 deletions

#!/bin/bash

uvicorn main:fast_app --port 8080 --host=0.0.0.0
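start.sh is the container entrypoint; it launches uvicorn on port 8080. Once the server is up, a quick smoke test of the root and /health endpoints confirms it started — a sketch, assuming the server is reachable on localhost:8080:

import requests

BASE = "http://localhost:8080"  # assumed address

assert requests.get(BASE + "/health").json() == {"status": "success"}
print(requests.get(BASE + "/").text)  # the HTML welcome page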
README.md (placeholder) — 1 addition & 0 deletions

## TO ADD
Dockerfile (RAG question-answering server) — 18 additions & 0 deletions

FROM python:3.10

RUN mkdir /app

WORKDIR /app

COPY requirements.txt requirements.txt
COPY main.py main.py
COPY start.sh start.sh

# Install the server dependencies.
RUN pip3 install -r requirements.txt

EXPOSE 8080

RUN chmod +x start.sh

CMD ["./start.sh"]
main.py (RAG question-answering server) — 100 additions & 0 deletions

"""The main model serving HTTP server. Creates the following endpoints:

/predict (POST) - model prediction endpoint
"""
import logging

import qdrant_client
import uvicorn
from fastapi import FastAPI, Body, Request, Response, status
from fastapi.responses import HTMLResponse, JSONResponse
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import LlamaCppEmbeddings
from langchain.llms import LlamaCpp
from langchain.prompts.prompt import PromptTemplate
from langchain.vectorstores import Qdrant

fast_app = FastAPI()

model_path = "/opt/ds/model/deployed_model/7B/ggml-model-q4_0.bin"

def load_model(model_folder_directory):
    embedding = LlamaCppEmbeddings(model_path=model_folder_directory, n_gpu_layers=15000)
    return embedding

try:
    logging.info("Loading the model")
    embeddings = load_model(model_path)
except Exception as e:
    logging.error("Error loading the model: %s", e)

url = "QDRANT_URL"
api_key = "API_KEY"

template = """You are an assistant to the user. You are given some context below; please answer the user's query in as much detail as possible.

Context:\"""
{context}
\"""

Question:\"""
{question}
\"""

Answer:"""

client = qdrant_client.QdrantClient(
    url,
    api_key=api_key
)

qdrant = Qdrant(
    client=client, collection_name="my_documents",
    embeddings=embeddings
)

qa_prompt = PromptTemplate.from_template(template)

llm = LlamaCpp(model_path=model_path, n_gpu_layers=15000, n_ctx=2048)
# CPU-only alternative:
# llm = LlamaCpp(model_path=model_path, n_ctx=2048)

@fast_app.get("/", response_class=HTMLResponse)
def read_root():
    return """
    <h2>Hello! Welcome to the model serving api.</h2>
    Check the <a href="/docs">api specs</a>.
    """

@fast_app.post("/predict")
def model_predict(request: Request, response: Response, data=Body(None)):
    global llm, embeddings, qa_prompt, qdrant
    question = data.decode("utf-8")
    chain = load_qa_chain(llm, chain_type="stuff", prompt=qa_prompt)
    # Debug shortcuts: "Hi" confirms the model loaded, "Hello" confirms retrieval.
    if question == "Hi":
        return "I am able to load the embedding"
    if question == "Hello":
        docs = qdrant.similarity_search(question)
        return docs
    try:
        docs = qdrant.similarity_search(question)
    except Exception as e:
        logging.error(e)
        return str(e)
    # The "stuff" chain fills the {context} prompt variable from input_documents,
    # so only the documents and the question need to be passed in.
    answer = chain(
        {"input_documents": docs, "question": question},
        return_only_outputs=True,
    )["output_text"]
    return answer

# Health GET endpoint returning the health status.
@fast_app.get("/health")
def health(request: Request, response: Response):
    return {"status": "success"}

if __name__ == "__main__":
    uvicorn.run("main:fast_app", port=8080, reload=True)
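On the question-answering side, /predict embeds the incoming question, retrieves similar documents from Qdrant, and stuffs them into the prompt before calling the LLM. A client sketch — the address and the question are illustrative assumptions:

import requests

SERVER = "http://localhost:8080"  # assumed address of the QA server

question = "Where is the Eiffel Tower located?"  # illustrative question
resp = requests.post(
    SERVER + "/predict",
    data=question.encode("utf-8"),
    headers={"content-type": "application/text"},
)
print(resp.text)  # the answer grounded in the retrieved documents

Sending the literal strings "Hi" or "Hello" instead exercises the debug branches (model load and retrieval, respectively) without invoking the full chain.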
requirements.txt (RAG question-answering server) — 7 additions & 0 deletions

langchain==0.0.333
llama-cpp-python==0.2.15
oci==2.47.1
requests==2.25.1
uvicorn
fastapi
qdrant-client==1.6.9
start.sh (RAG question-answering server) — 3 additions & 0 deletions

#!/bin/bash

uvicorn main:fast_app --port 8080 --host=0.0.0.0
Streamlit chat application — 66 additions & 0 deletions

import os
import time

import oci
import requests
import streamlit as st
from oci.signer import Signer
from streamlit_chat import message

# Session-token signing, kept for reference for endpoints that require OCI auth:
# token_file = os.path.expanduser("/Users/gagachau/.oci/sessions/OC1/token")
# with open(token_file, 'r') as f:
#     token = f.read()
# private_key = oci.signer.load_private_key_from_file("/Users/gagachau/.oci/sessions/OC1/oci_api_key.pem")
# signer = oci.auth.signers.SecurityTokenSigner(token, private_key)

def generate_response(prompt):
    # global signer
    endpoint = "http://localhost:8080/predict"
    headers = {"content-type": "application/text"}
    # response = requests.post(endpoint, data=prompt, auth=signer, headers=headers)
    response = requests.post(endpoint, data=prompt, headers=headers)
    res = response.text
    # Strip newlines, quotes, and backslashes from the raw response text.
    res = res.replace("\n", "")
    res = res.replace('"', "")
    res = res.replace("'", "")
    res = res.replace("\\", "")
    return res

# Create the page title.
st.set_page_config(page_title="SQuAD Chatbot")

# Create the header and the line underneath it.
header_html = "<h1 style='text-align: center; margin-bottom: 1px;'>🤖 The SQuAD Chatbot 🤖</h1>"
line_html = "<hr style='border: 2px solid green; margin-top: 1px; margin-bottom: 0px;'>"
st.markdown(header_html, unsafe_allow_html=True)
st.markdown(line_html, unsafe_allow_html=True)

# Create lists to store user queries and generated responses.
if "generated" not in st.session_state:
    st.session_state["generated"] = []
if "past" not in st.session_state:
    st.session_state["past"] = []

# Create an input field for user queries.
user_input = st.chat_input("How can I help?")

# Generate a response when a user prompt is submitted.
if user_input:
    output = generate_response(prompt=user_input)
    st.session_state.past.append(user_input)
    st.session_state.generated.append(output)

# Show queries and responses in the user interface.
if st.session_state["generated"]:
    for i in range(len(st.session_state["generated"])):
        message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
        message(st.session_state["generated"][i], key=str(i))
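The commented-out block at the top of the Streamlit app shows session-token signing for endpoints that require OCI authentication. For completeness, a sketch of the API-key variant using oci.signer.Signer — every config value below is a placeholder, not taken from the commit:

import requests
from oci.signer import Signer

# All values are placeholders; take them from your own OCI config.
signer = Signer(
    tenancy="ocid1.tenancy.oc1..example",
    user="ocid1.user.oc1..example",
    fingerprint="aa:bb:cc:dd:ee:ff:00:11:22:33:44:55:66:77:88:99",
    private_key_file_location="~/.oci/oci_api_key.pem",
)

# Signer implements the requests auth interface, so it drops into
# the existing call as auth=signer.
response = requests.post(
    "http://localhost:8080/predict",  # assumed endpoint
    data="Hello",
    auth=signer,
)
print(response.text)

The app itself is launched with streamlit run against this file.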