-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathprompt_agent.py
More file actions
278 lines (239 loc) · 17.4 KB
/
prompt_agent.py
File metadata and controls
278 lines (239 loc) · 17.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
import json
import traceback
from dotenv import load_dotenv
import openai
import pandas as pd
from collections import Counter
from prototxt_parser.prototxt import parse
import os
from solid_step_helper import clean_up_llm_output_func
import networkx as nx
import jsonlines
import json
import re
import time
import sys
import numpy as np
from langchain.prompts import PromptTemplate, FewShotPromptTemplate
from langchain.chains import LLMChain
import warnings
from langchain._api import LangChainDeprecationWarning
from langchain_chroma import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import AzureOpenAIEmbeddings
# Suppress LangChainDeprecationWarning process-wide; installed once at import time.
warnings.simplefilter("ignore", category=LangChainDeprecationWarning)
# Suffix placed after the few-shot examples by the FewShot_* agents (passed as
# `suffix=` to FewShotPromptTemplate). It is a LangChain template string:
# "{input}" is filled with the user question, and "{{" / "}}" are escaped
# braces that render as literal "{" / "}".
prompt_suffix = """Begin! Remember to ensure that you generate valid Python code in the following format:
Answer:
```python
${{Code that will answer the user question or request}}
```
Question: {input}
"""
# Few-shot question/answer pairs shown to the model by the FewShot_* agents.
# Each "answer" is a complete process_graph() that builds and returns
# `return_object` with the 'type' / 'data' / 'updated_graph' keys the prompt
# requires. The answers deliberately avoid literal braces (using dict(...)
# instead of {...}): FewShotPromptTemplate joins prefix, examples and suffix
# into one template and formats it again, so a bare "{" would be parsed as a
# template variable.
EXAMPLE_LIST = [
    {
        "question": "Update the physical capacity value of ju1.a3.m2.s2c4.p10 to 72. Return a graph.",
        "answer": r'''def process_graph(graph_data):
    graph_copy = copy.deepcopy(graph_data)
    for node in graph_copy.nodes(data=True):
        if node[1]['name'] == 'ju1.a3.m2.s2c4.p10' and 'EK_PORT' in node[1]['type']:
            node[1]['physical_capacity_bps'] = 72
            break
    graph_json = nx.readwrite.json_graph.node_link_data(graph_copy)
    # in the return_object, it should be a json object with three keys, 'type', 'data' and 'updated_graph'.
    return_object = dict(type='graph', data=graph_json, updated_graph=graph_json)
    return return_object''',
    },
    {
        "question": "Add new node with name new_EK_PORT_82 type EK_PORT, to ju1.a2.m4.s3c6. Return a graph.",
        "answer": r'''def process_graph(graph_data):
    graph_copy = copy.deepcopy(graph_data)
    graph_copy.add_node('new_EK_PORT_82', type=['EK_PORT'], physical_capacity_bps=1000)
    graph_copy.add_edge('ju1.a2.m4.s3c6', 'new_EK_PORT_82', type=['RK_CONTAINS'])
    graph_json = nx.readwrite.json_graph.node_link_data(graph_copy)
    # in the return_object, it should be a json object with three keys, 'type', 'data' and 'updated_graph'.
    return_object = dict(type='graph', data=graph_json, updated_graph=graph_json)
    return return_object''',
    },
    {
        "question": "Count the EK_PACKET_SWITCH in the ju1.a2.dom. Return only the count number.",
        "answer": r'''def process_graph(graph_data):
    graph_copy = graph_data.copy()
    count = 0
    for node in graph_copy.nodes(data=True):
        if 'EK_PACKET_SWITCH' in node[1]['type'] and node[0].startswith('ju1.a2.'):
            count += 1
    graph_json = nx.readwrite.json_graph.node_link_data(graph_copy)
    # the return_object should be a json object with three keys, 'type', 'data' and 'updated_graph'.
    return_object = dict(type='text', data=str(count), updated_graph=graph_json)
    return return_object''',
    },
    {
        "question": "Remove ju1.a1.m4.s3c6.p1 from the graph. Return a graph.",
        "answer": r'''def process_graph(graph_data):
    graph_copy = graph_data.copy()
    node_to_remove = None
    for node in graph_copy.nodes(data=True):
        if node[0] == 'ju1.a1.m4.s3c6.p1':
            node_to_remove = node[0]
            break
    if node_to_remove:
        graph_copy.remove_node(node_to_remove)
    graph_json = nx.readwrite.json_graph.node_link_data(graph_copy)
    # in the return_object, it should be a json object with three keys, 'type', 'data' and 'updated_graph'.
    return_object = dict(type='graph', data=graph_json, updated_graph=graph_json)
    return return_object''',
    },
]
class BasePromptAgent:
    """Zero-shot agent holding the base code-generation instructions.

    The prompt asks the model to write a ``process_graph(graph_data)`` function
    over a networkx graph that returns a ``return_object`` with the keys
    'type', 'data' and 'updated_graph'.

    Attributes:
        prompt_prefix: The text returned by ``self.generate_prompt()`` at
            construction time (a subclass override, if any, is what is used).
    """

    def __init__(self):
        self.prompt_prefix = self.generate_prompt()

    def generate_prompt(self):
        """Return the base instruction prompt.

        The text is a constant and reads no instance state, so subclasses may
        also call it unbound as ``BasePromptAgent.generate_prompt(self)``.

        Returns:
            str: The instruction prompt, starting and ending with a newline.
        """
        prompt = """
You need to behave like a network engineer who processes graph data to answer user queries about capacity planning.
Your task is to generate the Python code needed to process the network graph to answer the user question or request. The code should take the form of a function named process_graph that accepts a single input argument graph_data and returns a single object return_object.
Graph Structure:
- The input graph_data is a networkx graph object with nodes and edges
- The graph is directed and each node has a 'name' attribute to represent itself
- Each node has a 'type' attribute in the format of EK_TYPE. 'type' must be a list, which can include ['EK_SUPERBLOCK', 'EK_CHASSIS', 'EK_RACK', 'EK_AGG_BLOCK', 'EK_JUPITER', 'EK_PORT', 'EK_SPINEBLOCK', 'EK_PACKET_SWITCH', 'EK_CONTROL_POINT', 'EK_CONTROL_DOMAIN']
- Each node can have other attributes depending on its type
- Each directed edge also has a 'type' attribute, including RK_CONTAINS, RK_CONTROL
Important Guidelines:
- Check relationships based on edge, check name based on node attribute
- Nodes follow this hierarchy: CHASSIS contains PACKET_SWITCH, JUPITER contains SUPERBLOCK, SUPERBLOCK contains AGG_BLOCK, AGG_BLOCK contains PACKET_SWITCH, PACKET_SWITCH contains PORT
- Each PORT node has an attribute 'physical_capacity_bps'. For example, a PORT node name is ju1.a1.m1.s2c1.p3
- When calculating capacity of a node, sum the physical_capacity_bps on the PORT of each hierarchy contained in this node
- When updating a graph, always create a graph copy, do not modify the input graph
- To find a node based on type, check the name and type list. For example, [node[0] == 'ju1.a1.m1.s2c1' and 'EK_PACKET_SWITCH' in node[1]['type']]
- node.startswith will not work for the node name. You have to check the node name with the node['name']
Output Format:
- Do not use multi-layer functions. The output format should only return one object
- The return_object must be a JSON object with three keys: 'type', 'data', and 'updated_graph'
- The 'type' key should indicate the output format depending on the user query or request. It should be one of 'text', 'list', 'table', or 'graph'
- The 'data' key should contain the data needed to render the output:
* If output type is 'text': 'data' should contain a string
* If output type is 'list': 'data' should contain a list of items
* If output type is 'table': 'data' should contain a list of lists where each list represents a row in the table
* If output type is 'graph': 'data' should contain a graph JSON
- The 'updated_graph' key should always contain the updated graph as "graph_json = nx.readwrite.json_graph.node_link_data(graph_copy)"
Response Format:
- Your reply should always start with string "\\nAnswer:\\n"
- You should generate a function called "def process_graph"
- All of your output should only contain the defined function without example usages, no additional text, and displayed in a Python code block
- Do not include any package imports in your answer
"""
        return prompt
class ZeroShot_CoT_PromptAgent:
    """Zero-shot chain-of-thought agent.

    Its prompt is exactly the BasePromptAgent instructions followed by a
    single "think step by step" line.

    Attributes:
        prompt_prefix: The text returned by ``self.generate_prompt()`` at
            construction time (a subclass override, if any, is what is used).
    """

    # Appended after the base instructions to elicit step-by-step reasoning.
    _COT_INSTRUCTION = "Please think step by step and provide your output.\n"

    def __init__(self):
        self.prompt_prefix = self.generate_prompt()

    def generate_prompt(self):
        """Return the base instructions followed by the chain-of-thought cue.

        The base text is taken from BasePromptAgent.generate_prompt, called
        unbound (it reads no instance state), instead of a pasted copy, so the
        zero-shot and CoT prompts cannot drift apart. The class intentionally
        keeps its original hierarchy and does not inherit from BasePromptAgent.

        Returns:
            str: The CoT instruction prompt.
        """
        return BasePromptAgent.generate_prompt(self) + self._COT_INSTRUCTION
class FewShot_Basic_PromptAgent(ZeroShot_CoT_PromptAgent):
    """Few-shot agent that shows every entry of EXAMPLE_LIST before the question.

    Attributes:
        prompt_prefix: Set by the parent ``__init__``.
        examples: The question/answer dicts rendered into the prompt (the
            module-level EXAMPLE_LIST object itself, not a copy).
        cot_prompt_prefix: The parent class's ``generate_prompt()`` text,
            used as the head of the few-shot prompt.
    """

    def __init__(self):
        super().__init__()
        self.examples = EXAMPLE_LIST
        self.cot_prompt_prefix = super().generate_prompt()

    def get_few_shot_prompt(self):
        """Build a few-shot prompt template over all examples.

        Returns:
            A FewShotPromptTemplate whose only input variable is "input" (the
            user question). Its prefix is the CoT prompt plus an example
            header, each example uses the "Question: {question}" / "Answer:
            {answer}" layout, and the suffix is ``prompt_suffix``.
        """
        header = self.cot_prompt_prefix + "Here are some example question-answer pairs:\n"
        per_example = PromptTemplate(
            template="Question: {question}\nAnswer: {answer}",
            input_variables=["question", "answer"],
        )
        return FewShotPromptTemplate(
            prefix=header,
            examples=self.examples,
            example_prompt=per_example,
            suffix=prompt_suffix,
            input_variables=["input"],
        )
class FewShot_Semantic_PromptAgent(ZeroShot_CoT_PromptAgent):
    """Few-shot agent that includes only the example most similar to the query.

    Candidate examples come from EXAMPLE_LIST. They are embedded with Azure
    OpenAI embeddings (model "text-embedding-3-large") into a Chroma vector
    store, and the single closest one (k=1) is rendered into the prompt.

    Attributes:
        prompt_prefix: Set by the parent ``__init__``.
        examples: The candidate question/answer dicts (EXAMPLE_LIST itself).
        cot_prompt_prefix: The parent class's ``generate_prompt()`` text,
            used as the head of the few-shot prompt.
    """

    def __init__(self):
        # The parent __init__ was previously skipped, leaving prompt_prefix
        # unset on this agent only. It just builds the prompt string.
        super().__init__()
        self.examples = EXAMPLE_LIST
        self.cot_prompt_prefix = super().generate_prompt()
        # Built on first use by _get_example_selector().
        self._example_selector = None

    def _get_example_selector(self):
        """Return the semantic example selector, building it on the first call.

        Building it creates an embeddings client and embeds every example into
        a new Chroma store. That was previously redone for every query; it now
        happens once per agent instance.
        NOTE(review): the selector is not rebuilt if ``self.examples`` is
        reassigned after the first call.
        """
        selector = getattr(self, "_example_selector", None)
        if selector is None:
            embeddings = AzureOpenAIEmbeddings(
                model="text-embedding-3-large"
            )
            selector = SemanticSimilarityExampleSelector.from_examples(
                # Examples available to select from.
                self.examples,
                # Embedding model used to measure semantic similarity.
                embeddings,
                # VectorStore class that stores the embeddings and runs the search.
                Chroma,
                # Number of examples to select.
                k=1)
            self._example_selector = selector
        return selector

    def get_few_shot_prompt(self, query):
        """Build a few-shot prompt containing the example closest to ``query``.

        Args:
            query: The user question; it is used to pick the example.

        Returns:
            A FewShotPromptTemplate whose only input variable is "input".
        """
        selected_examples = self._get_example_selector().select_examples({"question": query})
        example_prompt = PromptTemplate(
            input_variables=["question", "answer"],
            template="Question: {question}\nAnswer: {answer}"
        )
        few_shot_prompt = FewShotPromptTemplate(
            examples=selected_examples,
            example_prompt=example_prompt,
            prefix=self.cot_prompt_prefix + "Here are some example question-answer pairs:\n",
            suffix=prompt_suffix,
            input_variables=["input"]
        )
        return few_shot_prompt
class ReAct_PromptAgent(BasePromptAgent):
    """ReAct-style agent: a short tool-use preamble followed by the base prompt.

    Attributes:
        base_prompt_prefix: The BasePromptAgent instruction text.
        prompt_prefix: The preamble followed by ``base_prompt_prefix``.
    """

    def __init__(self):
        # BasePromptAgent.__init__ is bypassed on purpose: it would call
        # self.generate_prompt() before base_prompt_prefix has been assigned.
        self.base_prompt_prefix = BasePromptAgent.generate_prompt(self)
        self.prompt_prefix = self.generate_prompt()

    def generate_prompt(self):
        """Return the ReAct preamble prepended to ``self.base_prompt_prefix``."""
        preamble = (
            "\n"
            "Answer the following question as best you can. Please use a tool if you need to.\n"
        )
        return "".join((preamble, self.base_prompt_prefix))