Skip to content

Commit 1878365

Browse files
demo-code
1 parent 7481c8f commit 1878365

File tree

10 files changed

+853
-0
lines changed

10 files changed

+853
-0
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
> [!CAUTION]
2+
> Under construction.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import streamlit as st
2+
3+
st.set_page_config(
4+
page_title="Hello",
5+
page_icon="👋",
6+
)
7+
8+
st.write("# Welcome to Streamlit! 👋")
9+
10+
st.sidebar.success("Select a demo above.")
11+
12+
st.markdown(
13+
"""
14+
This is a demo!
15+
"""
16+
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
ID,Message
2+
1,I had to cancel my order because of poor service.
3+
2,"The delivery was late, and the packaging was damaged."
4+
3,I was sent the wrong color of the product.
5+
4,My order was incomplete when it arrived.
6+
5,The product I received was damaged.
7+
6,The quality of the product is much worse than expected.
8+
7,The product stopped working after a short period of time.
9+
8,The product doesn’t match the description on the website.
10+
9,I’ve had to contact customer service multiple times for the same issue.
11+
10,Customer support was not helpful at all.
12+
11,The quality of the product was poor.
13+
12,The product was much smaller than I expected.
14+
13,I had trouble finding the product on your website.
15+
14,The instructions were unclear and hard to follow.
16+
15,The website was difficult to navigate during my purchase.
17+
16,I received the wrong size and need a replacement.
18+
17,I was given false information about the product.
19+
18,The product stopped working after a short period of time.
20+
19,The product arrived damaged and unusable.
21+
20,The product arrived in terrible condition.
22+
21,The product arrived damaged and unusable.
23+
22,The customer service was slow to respond.
24+
23,The product was missing some essential accessories.
25+
24,I didn’t receive any confirmation email for my order.
26+
25,The product wasn’t compatible with my other appliances.
27+
26,The product is faulty and doesn’t work properly.
28+
27,The product didn’t fit as expected.
29+
28,The product was extremely hard to set up.
30+
29,I am unhappy with the design of the product.
31+
30,The website was difficult to navigate during my purchase.
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import json
2+
import logging
3+
from typing import List
4+
5+
from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI
6+
from langchain_community.embeddings import OCIGenAIEmbeddings
7+
from langchain_core.messages import HumanMessage, SystemMessage
8+
from langchain_core.pydantic_v1 import BaseModel
9+
from langgraph.checkpoint.memory import MemorySaver
10+
from langgraph.graph import END, StateGraph
11+
12+
import backend.message_handler as handler
13+
import backend.utils.llm_config as llm_config
14+
15+
# Set up logging
16+
logging.getLogger("oci").setLevel(logging.DEBUG)
17+
messages_path = "ai/generative-ai-service/sentiment+categorization/demo_code/backend/data/complaints_messages.csv"
18+
19+
20+
class AgentState(BaseModel):
21+
messages_info: List = []
22+
categories: List = []
23+
reports: List = []
24+
25+
26+
class FeedbackAgent:
27+
def __init__(self, model_name: str = "cohere_oci"):
28+
self.model_name = model_name
29+
self.model = self.initialize_model()
30+
self.memory = MemorySaver()
31+
self.builder = self.setup_graph()
32+
self.messages = self.read_messages()
33+
34+
def initialize_model(self):
35+
if self.model_name not in llm_config.MODEL_REGISTRY:
36+
raise ValueError(f"Unknown model: {self.model_name}")
37+
38+
model_config = llm_config.MODEL_REGISTRY[self.model_name]
39+
40+
return ChatOCIGenAI(
41+
model_id=model_config["model_id"],
42+
service_endpoint=model_config["service_endpoint"],
43+
compartment_id=model_config["compartment_id"],
44+
provider=model_config["provider"],
45+
auth_type=model_config["auth_type"],
46+
auth_profile=model_config["auth_profile"],
47+
model_kwargs=model_config["model_kwargs"],
48+
)
49+
50+
def initialize_embeddings(self):
51+
if self.model_name not in llm_config.MODEL_REGISTRY:
52+
raise ValueError(f"Unknown model: {self.model_name}")
53+
54+
model_config = llm_config.MODEL_REGISTRY[self.model_name]
55+
56+
embeddings = OCIGenAIEmbeddings(
57+
model_id=model_config["embedding_model"],
58+
service_endpoint=model_config["service_endpoint"],
59+
truncate="NONE",
60+
compartment_id=model_config["compartment_id"],
61+
auth_type=model_config["auth_type"],
62+
auth_profile=model_config["auth_profile"],
63+
)
64+
return embeddings
65+
66+
def read_messages(self):
67+
messages = handler.read_messages(filepath=messages_path)
68+
return handler.batchify(messages, 30)
69+
70+
def summarization_node(self, state: AgentState):
71+
batch = self.messages
72+
response = self.model.invoke(
73+
[
74+
SystemMessage(
75+
content=llm_config.get_prompt(self.model_name, "SUMMARIZATION")
76+
),
77+
HumanMessage(content=f"Message batch: {batch}"),
78+
]
79+
)
80+
state.messages_info = state.messages_info + [json.loads(response.content)]
81+
return {"messages_info": state.messages_info}
82+
83+
def categorization_node(self, state: AgentState):
84+
batch = state.messages_info
85+
response = self.model.invoke(
86+
[
87+
SystemMessage(
88+
content=llm_config.get_prompt(
89+
self.model_name, "CATEGORIZATION_SYSTEM"
90+
)
91+
),
92+
HumanMessage(
93+
content=llm_config.get_prompt(
94+
self.model_name, "CATEGORIZATION_USER"
95+
).format(MESSAGE_BATCH=batch)
96+
),
97+
]
98+
)
99+
content = [json.loads(response.content)]
100+
state.categories = state.categories + handler.match_categories(batch, content)
101+
return {"categories": state.categories}
102+
103+
def generate_report_node(self, state: AgentState):
104+
response = self.model.invoke(
105+
[
106+
SystemMessage(
107+
content=llm_config.get_prompt(self.model_name, "REPORT_GEN")
108+
),
109+
HumanMessage(content=f"Message info: {state.categories}"),
110+
]
111+
)
112+
state.reports = response.content
113+
return {"reports": [response.content]}
114+
115+
def setup_graph(self):
116+
builder = StateGraph(AgentState)
117+
builder.add_node("summarize", self.summarization_node)
118+
builder.add_node("categorize", self.categorization_node)
119+
builder.add_node("generate_report", self.generate_report_node)
120+
121+
builder.set_entry_point("summarize")
122+
builder.add_edge("summarize", "categorize")
123+
builder.add_edge("categorize", "generate_report")
124+
125+
builder.add_edge("generate_report", END)
126+
return builder.compile(checkpointer=self.memory)
127+
128+
def get_graph(self):
129+
return self.builder.get_graph()
130+
131+
def run(self):
132+
thread = {"configurable": {"thread_id": "1"}}
133+
for s in self.builder.stream(
134+
config=thread,
135+
):
136+
print(f"\n \n{s}")
137+
138+
def run_step_by_step(self):
139+
thread = {"configurable": {"thread_id": "1"}}
140+
initial_state = {
141+
"messages_info": [],
142+
"categories": [],
143+
"reports": [],
144+
}
145+
for state in self.builder.stream(initial_state, thread):
146+
yield state # Yield each intermediate step to allow step-by-step execution
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from backend.feedback_agent import FeedbackAgent
2+
3+
4+
class FeedbackAgentWrapper:
5+
def __init__(self):
6+
self.agent = FeedbackAgent()
7+
self.run_graph = self.agent.run_step_by_step()
8+
9+
def get_nodes_edges(self):
10+
graph_data = self.agent.get_graph()
11+
nodes = list(graph_data.nodes.keys())
12+
edges = [(edge.source, edge.target) for edge in graph_data.edges]
13+
return nodes, edges
14+
15+
def run_step_by_step(self):
16+
try:
17+
action_output = next(self.run_graph)
18+
current_node = list(action_output.keys())[0]
19+
except StopIteration:
20+
action_output = {}
21+
current_node = "FINALIZED"
22+
return current_node, action_output
23+
24+
def get_graph(self):
25+
return self.agent.get_graph()
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import csv
2+
from typing import List
3+
4+
5+
def read_messages(
6+
filepath: str, columns: List[str] = ["ID", "Message"]
7+
) -> List[List[str]]:
8+
with open(filepath, newline="", encoding="utf-8") as file:
9+
reader = csv.DictReader(file)
10+
extracted_data = []
11+
12+
for row in reader:
13+
extracted_row = [row[col] for col in columns if col in row]
14+
extracted_data.append(extracted_row)
15+
16+
return extracted_data
17+
18+
19+
def batchify(lst, batch_size):
20+
return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
21+
22+
23+
def match_categories(summaries, categories):
24+
result = []
25+
for i, elem in enumerate(summaries[0]):
26+
if elem["id"] == categories[0][i]["id"]:
27+
elem["primary_category"] = categories[0][i]["primary_category"]
28+
elem["secondary_category"] = categories[0][i]["secondary_category"]
29+
elem["tertiary_category"] = categories[0][i]["tertiary_category"]
30+
result.append(elem)
31+
return result
32+
33+
34+
def group_by_category_level(categories_list):
35+
result = {}
36+
37+
for category in categories_list:
38+
primary = category["primary_category"]
39+
secondary = category["secondary_category"]
40+
tertiary = category["tertiary_category"]
41+
42+
if primary not in result:
43+
result[primary] = {}
44+
45+
if secondary not in result[primary]:
46+
result[primary][secondary] = {}
47+
48+
if tertiary not in result[primary][secondary]:
49+
result[primary][secondary][tertiary] = []
50+
51+
result[primary][secondary][tertiary].append(category["id"])
52+
53+
return result
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# config.py
2+
# Author: Ansh
3+
4+
DB_TYPE = "qdrant" # Options: "oracle", "qdrant"
5+
6+
# OracleDB Configuration
7+
ORACLE_DB_USER = "ansh" # Enter your oracle vector Db username
8+
ORACLE_DB_PWD = "Gena#######" # Enter your oracle vector Db password
9+
ORACLE_DB_HOST_IP = "######" # Enter your oracle vector Db host ip
10+
ORACLE_DB_PORT = 1521 # Enter your oracle vector Db host port
11+
ORACLE_DB_SERVICE = "orclpdb01.sub05101349370.bpivcnllm.oraclevcn.com"
12+
13+
ORACLE_USERNAME = ORACLE_DB_USER
14+
ORACLE_PASSWORD = ORACLE_DB_PWD
15+
ORACLE_DSN = f"{ORACLE_DB_HOST_IP}:{ORACLE_DB_PORT}/{ORACLE_DB_SERVICE}"
16+
ORACLE_TABLE_NAME = (
17+
"policyTable" # name of table where you want to store the embeddings in oracle DB
18+
)
19+
20+
# Qdrant Configuration
21+
QDRANT_LOCATION = ":memory:"
22+
QDRANT_COLLECTION_NAME = (
23+
"my_documents" # name of table where you want to store the embeddings in qdrant DB
24+
)
25+
QDRANT_DISTANCE_FUNC = "Dot"
26+
27+
# Common Configuration
28+
USER_ID = ""
29+
COMPARTMENT_ID = "ocid1.compartment.oc1..XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
30+
OBJECT_STORAGE_LINK = "https://objectstorage.eu-frankfurt-1.oraclecloud.com/n/##############/b/##########/o/"
31+
DIRECTORY = "data" # directory to store the pdf's from where the RAG model should take the documents from
32+
AUTH_TYPE = "API_KEY"
33+
CONFIG_PROFILE = "DEFAULT"
34+
PROMPT_CONTEXT = "You are an AI Assistant trained to give answers based only on the information provided. Given only the above text provided and not prior knowledge, answer the query. If someone asks you a question and you don't know the answer, don't try to make up a response, simply say: I don't know."
35+
ENDPOINT = "https://inference.generativeai.eu-frankfurt-1.oci.oraclecloud.com" #change in case you want to select a diff region
36+
37+
# COHERE data
38+
PROVIDER_COHERE = "cohere"
39+
EMBEDDING_MODEL_COHERE = "cohere.embed-english-v3.0"
40+
GENERATE_MODEL_COHERE = "cohere.command-r-plus-08-2024" # "ocid1.generativeaimodel.oc1.us-chicago-1.amaaaaaask7dceyanrlpnq5ybfu5hnzarg7jomak3q6kyhkzjsl4qj24fyoq"# cohere.command-r-16k or cohere.command-r-plus
41+
42+
# LLAMA data
43+
PROVIDER_LLAMA = "meta"
44+
GENERATE_MODEL_LLAMA_33= "ocid1.generativeaimodel.oc1.eu-frankfurt-1.amaaaaaask7dceya4tdabclcsqbc3yj2mozvvqoq5ccmliv3354hfu3mx6bq"
45+
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# import utils.config as config
2+
# from utils import prompts as prompts
3+
4+
import backend.utils.config as config
5+
from backend.utils import prompts as prompts
6+
7+
8+
def get_prompt(model_name: str, prompt_type: str) -> str:
9+
if model_name not in PROMPT_SETS:
10+
raise ValueError(f"No prompts defined for model {model_name}")
11+
if prompt_type not in PROMPT_SETS[model_name]:
12+
raise ValueError(f"Unknown prompt type: {prompt_type}")
13+
return PROMPT_SETS[model_name][prompt_type]
14+
15+
16+
MODEL_REGISTRY = {
17+
"cohere_oci": {
18+
"model_id": config.GENERATE_MODEL_COHERE,
19+
"service_endpoint": config.ENDPOINT,
20+
"compartment_id": config.COMPARTMENT_ID,
21+
"provider": config.PROVIDER_COHERE,
22+
"auth_type": config.AUTH_TYPE,
23+
"auth_profile": config.CONFIG_PROFILE,
24+
"model_kwargs": {"temperature": 0, "max_tokens": 4000},
25+
"embedding_model": config.EMBEDDING_MODEL_COHERE,
26+
},
27+
"meta_oci": {
28+
"model_id": config.GENERATE_MODEL_LLAMA_33,
29+
"service_endpoint": config.ENDPOINT,
30+
"compartment_id": config.COMPARTMENT_ID,
31+
"provider": config.PROVIDER_LLAMA,
32+
"auth_type": config.AUTH_TYPE,
33+
"auth_profile": config.CONFIG_PROFILE,
34+
"model_kwargs": {"temperature": 0, "max_tokens": 2000},
35+
},
36+
}
37+
38+
PROMPT_SETS = {
39+
"cohere_oci": {
40+
"SUMMARIZATION": prompts.SUMMARIZATION,
41+
"CATEGORIZATION_SYSTEM": prompts.CATEGORIZATION_SYSTEM,
42+
"CATEGORIZATION_USER": prompts.CATEGORIZATION_USER,
43+
"REPORT_GEN": prompts.REPORT_GEN,
44+
},
45+
"meta_oci": {
46+
"SUMMARIZATION_LLAMA": prompts.SUMMARIZATION_LLAMA,
47+
"CATEGORIZATION_LLAMA": prompts.CATEGORIZATION_LLAMA,
48+
"REPORT_GEN_LLAMA": prompts.REPORT_GEN_LLAMA,
49+
},
50+
}

0 commit comments

Comments
 (0)