Skip to content

Commit 5a5002a

Browse files
Sql agent graph
1 parent bd6d6a6 commit 5a5002a

File tree

7 files changed

+117
-64
lines changed

7 files changed

+117
-64
lines changed

prompts/sql/Chinook.db

12 KB
Binary file not shown.

prompts/sql/prompt.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ agent: sql
44

55
# prompt user
66

7-
find all artists in database ./Chinook.db
7+
find all artists in the database ./Chinook.db
8+

prompts/sql/query-check.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
---
22
tools:
3-
- name: sqlite3
3+
- name: sql_db_query_tool
44
description: execute the DB query
55
parameters:
66
type: object
77
properties:
8+
database:
9+
type: string
10+
description: the database to query
811
sql:
912
type: string
1013
description: the sql statement to run
1114
container:
1215
image: vonwig/sqlite:latest
1316
command:
14-
- "./Chinook.db"
17+
- "{{database}}"
1518
- "{{sql}}"
19+
tool_choice: required
1620
---
1721

1822
# prompt system
@@ -32,6 +36,3 @@ If there are any of the above mistakes, rewrite the query. If there are no mista
3236

3337
You will call the appropriate tool to execute the query after running this check.
3438

35-
# prompt user
36-
37-
SELECT * FROM Artist LIMIT 10;

prompts/sql/query-gen.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
11
---
2+
tools:
3+
- name: SubmitFinalAnswer
4+
description: Submit the final answer to the user based on the query results
5+
parameters:
6+
type: object
7+
properties:
8+
final_answer:
9+
type: string
10+
description: The final answer to the user
211
---
312

413
# prompt system

src/graph.clj

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,14 @@
6666
c))
6767

6868
(defn completion
69-
"get the next llm completion
70-
uses the current conversation messages and outputs
71-
outputs the next set of messages and finish-reason to be added to the conversation"
69+
"generate the next AI message
70+
passes the whole converation to the AI model"
7271
[state]
73-
(run-llm (:messages state) (:metadata state) (:functions state) (:opts state)))
72+
(run-llm (:messages state) (dissoc (:metadata state) :agent) (:functions state) (:opts state)))
7473

7574
;; TODO does the LangGraph Tool Node always search for the a tool_call
7675
(defn tool
77-
"make docker container tool calls"
76+
"execute the tool_calls from the last AI message in the conversation"
7877
[state]
7978
(let [calls (-> (:messages state) last :tool_calls)]
8079
(async/go
@@ -88,7 +87,9 @@
8887
calls)
8988
(async/reduce conj []))))})))
9089

91-
(defn tool-node [_]
90+
(defn tool-node
91+
"add a tool node that will run tool_calls from the last AI message in the conversation"
92+
[_]
9293
tool)
9394

9495
(defn tools-query
@@ -128,6 +129,7 @@
128129
(update-in [:opts :parameters] (constantly arg-context)))))
129130

130131
(comment
132+
;; TODO move this into the thingy
131133
(add-prompt-ref {:messages [{:tool_calls [{:function {:name "sql_db_list_tables"
132134
:arguments "{\"arg\": 1}"}}]}]
133135
:functions [{:function {:name "sql_db_list_tables"
@@ -139,18 +141,31 @@
139141

140142
(declare stream chat-with-tools)
141143

142-
(defn sub-graph-node [{:keys [init-state construct-graph]}]
144+
(defn add-last-message-as-tool-call
145+
[state sub-graph-state]
146+
{:messages [(-> sub-graph-state
147+
:messages
148+
last
149+
(state/add-tool-call-id (-> state :messages last :tool_calls first :id)))]})
150+
151+
(defn append-new-messages
152+
[state sub-graph-state]
153+
{:messages (->> (:messages sub-graph-state)
154+
(filter (complement (fn [m] (some #(= m %) (:messages state))))))})
155+
156+
(defn sub-graph-node
157+
"create a sub-graph node that initializes a conversation from the current one,
158+
creates a new agent graph from the current state and returns the messages to be added
159+
to the parent conversation"
160+
[{:keys [init-state construct-graph next-state]}]
143161
(fn [state]
144162
(async/go
145163
(let [sub-graph-state
146164
(async/<!
147165
(stream
148166
((or construct-graph chat-with-tools) state)
149167
((or init-state (comp construct-initial-state-from-prompts add-prompt-ref)) state)))]
150-
{:messages [(-> sub-graph-state
151-
:messages
152-
last
153-
(state/add-tool-call-id (-> state :messages last :tool_calls first :id)))]}))))
168+
((or next-state add-last-message-as-tool-call) state sub-graph-state)))))
154169

155170
; =====================================================
156171
; edge functions takes state and returns next node
@@ -194,6 +209,7 @@
194209
[state m
195210
node "start"]
196211
(jsonrpc/notify :message {:debug (format "\n-> entering %s\n\n" node)})
212+
#_(jsonrpc/notify :message {:debug (with-out-str (pprint (state/summarize (dissoc state :opts))))})
197213
;; TODO handling bad graphs with missing nodes
198214
(let [enter-node (get-in graph [:nodes node])
199215
new-state (state-reducer state (async/<! (enter-node state)))]
@@ -222,6 +238,18 @@
222238
(add-edge "tools-query" "completion")
223239
(add-conditional-edges "completion" tool-or-end)))
224240

241+
(defn one-tool-call [_]
242+
(-> {}
243+
(add-node "start" start)
244+
(add-node "completion" completion)
245+
(add-node "tool" (tool-node nil))
246+
(add-node "end" end)
247+
(add-node "sub-graph" (sub-graph-node nil))
248+
(add-edge "start" "completion")
249+
(add-edge "sub-graph" "end")
250+
(add-edge "tool" "end")
251+
(add-conditional-edges "completion" tool-or-end)))
252+
225253
(comment
226254
(alter-var-root #'jsonrpc/notify (fn [_] (partial jsonrpc/-println {:debug true})))
227255
(let [x {:prompts (fs/file "/Users/slim/docker/labs-ai-tools-for-devs/prompts/curl/README.md")

src/graphs/sql.clj

Lines changed: 57 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,10 @@
11
(ns graphs.sql
22
(:require
3+
[babashka.fs :as fs]
34
[clojure.core.async :as async]
45
[clojure.string :as string]
56
[graph]))
67

7-
(def db-query-tool-call
8-
{:messages [{:content ""
9-
:tool_calls [{:name "sql_db_query_tool"
10-
:arguments "{}"
11-
:id "tool_abc123"}]}]
12-
:tools [{:name "sql_db_query_tool"
13-
:description "List all tables in the database"
14-
:parameters
15-
{:type "object"
16-
:properties
17-
{:database {:type "string" :description "the database to query"}
18-
:query {:type "string" :description "the sql statement to run"}}}
19-
:container
20-
{:image "vonwig/sqlite:latest"
21-
:command ["{{database}}" "{{query}}"]}}]})
22-
238
(def first-tool-call
249
{:messages [{:role "assistant"
2510
:content ""
@@ -55,30 +40,48 @@
5540
(async/go
5641
first-tool-call))
5742

43+
(defn failed-tool-call-message [tool-call-name]
44+
;; this is awful - binds the should-continue edge to the format of this string
45+
(format
46+
"Error: The wrong tool was called: %s. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call."
47+
tool-call-name))
48+
5849
(defn query-gen
5950
"Assistant+Tool Node: has it's own prompt but also adds checks for proper final answers"
60-
[_]
61-
;; note that we should not bind tools here because the only acceptable output is a proper final answer
62-
;; so this sub-graph needs to remove all tools
63-
)
64-
65-
;; TODO pull out the query check system
66-
(defn correct-query
67-
"Assistant Node: double check the query
68-
- should generate a tool call to execute the query"
69-
[_])
70-
71-
(defn execute-query
72-
"Tool Node: runs db query"
73-
[_])
51+
[state]
52+
(async/go
53+
(let [x (->
54+
state
55+
(dissoc :messages)
56+
(dissoc :functions)
57+
(update-in [:opts :level] (fnil inc 0))
58+
(update-in [:opts :prompts] (constantly (fs/file "prompts/sql/query-gen.md")))
59+
(graph/construct-initial-state-from-prompts)
60+
(update-in [:messages] concat (:messages state)))
61+
{:keys [messages _finish-reason]} (async/<! (graph/run-llm
62+
(:messages x)
63+
(dissoc (:metadata x) :agent)
64+
(:functions x)
65+
(:opts x)))]
66+
67+
; check for bad tool_calls and create failed Tool messages for them
68+
{:messages
69+
(concat
70+
messages
71+
(->> (:tool_calls (last messages))
72+
(filter (complement #(= "SubmitFinalAnswer" (-> % :function :name))))
73+
(map (fn [{:keys [id] :as tc}]
74+
{:role "tool"
75+
:content (failed-tool-call-message (-> tc :function :name))
76+
:tool_call_id id}))))})))
7477

7578
(defn should-continue
7679
"end, correct-query, or query-gen"
7780
[{:keys [messages]}]
7881
(let [last-message (last messages)]
7982
(cond
8083
(contains? last-message :tool_calls) "end"
81-
(string/starts-with? last-message "Error:") "query-gen"
84+
(string/starts-with? (:content last-message) "Error:") "query-gen"
8285
:else "correct-query")))
8386

8487
(defn seed-get-schema-conversation [state]
@@ -90,6 +93,18 @@
9093
(update-in [:opts :parameters] (constantly {:database "./Chinook.db"}))
9194
(update-in [:functions] (fnil concat []) (:tools model-get-schema))))
9295

96+
(defn seed-correct-query-conversation
97+
[state]
98+
; make one LLM call with the last message (which should be a user query containing the SQL we want to check)
99+
; add the last message to the conversation
100+
(-> state
101+
(dissoc :messages)
102+
(update-in [:opts :level] (fnil inc 0))
103+
(update-in [:opts :prompts] (constantly (fs/file "prompts/sql/query-check.md")))
104+
(update-in [:opts :parameters] (constantly {:database "./Chinook.db"}))
105+
(graph/construct-initial-state-from-prompts)
106+
(update-in [:messages] concat [(last (:messages state))])))
107+
93108
(defn graph [_]
94109
(-> {}
95110
(graph/add-node "start" graph/start)
@@ -99,22 +114,18 @@
99114
(graph/add-node "list-tables-tool" (graph/tool-node nil))
100115
(graph/add-edge "list-tables-inject-tool" "list-tables-tool")
101116

102-
(graph/add-node "model-get-schema" (graph/sub-graph-node {:init-state seed-get-schema-conversation})) ; assistant
117+
; TODO replace the conversation state, don't append
118+
(graph/add-node "model-get-schema" (graph/sub-graph-node {:init-state seed-get-schema-conversation
119+
:next-state graph/append-new-messages})) ; assistant
103120
(graph/add-edge "list-tables-tool" "model-get-schema")
104121

105-
(graph/add-node "end" graph/end)
106-
(graph/add-edge "model-get-schema" "end")
107-
;(graph/add-node "query-gen" query-gen) ; assistant - might just end if it generates the right response
108-
;; - might just loop back to query-gen if there's an error
109-
;; - otherwise switch to correct-query
110-
111-
;(graph/add-node "correct-query" correct-query) ; assistant
112-
;(graph/add-node "execute-query" execute-query) ; tool
122+
(graph/add-node "query-gen" query-gen)
123+
(graph/add-edge "model-get-schema" "query-gen")
113124

114-
;(graph/add-edge "list-table-sub-graph" "model-get-schema")
115-
;(graph/add-edge "model-get-schema" "query-gen")
116-
;(graph/add-conditional-edges "query-gen" should-continue)
125+
(graph/add-node "end" graph/end)
126+
(graph/add-node "correct-query" (graph/sub-graph-node {:init-state seed-correct-query-conversation
127+
:construct-graph graph/one-tool-call
128+
:next-state graph/append-new-messages}))
129+
(graph/add-conditional-edges "query-gen" should-continue)
117130

118-
;(graph/add-edge "correct-query" "execute-query")
119-
;(graph/add-edge "execute-query" "query-gen")
120-
))
131+
(graph/add-edge "correct-query" "query-gen")))

src/state.clj

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616

1717
(defn summarize [state]
1818
(-> state
19-
(update :messages (each summarize-content summarize-tool-calls))))
19+
(update :messages (each
20+
;summarize-content
21+
;summarize-tool-calls
22+
))))
2023

2124
(defn prompt? [m]
2225
(= "prompt" (-> m :function :type)))

0 commit comments

Comments
 (0)