|
1 | 1 | (ns graphs.sql
|
2 | 2 | (:require
|
| 3 | + [babashka.fs :as fs] |
3 | 4 | [clojure.core.async :as async]
|
4 | 5 | [clojure.string :as string]
|
5 | 6 | [graph]))
|
6 | 7 |
|
7 |
| -(def db-query-tool-call |
8 |
| - {:messages [{:content "" |
9 |
| - :tool_calls [{:name "sql_db_query_tool" |
10 |
| - :arguments "{}" |
11 |
| - :id "tool_abc123"}]}] |
12 |
| - :tools [{:name "sql_db_query_tool" |
13 |
| - :description "List all tables in the database" |
14 |
| - :parameters |
15 |
| - {:type "object" |
16 |
| - :properties |
17 |
| - {:database {:type "string" :description "the database to query"} |
18 |
| - :query {:type "string" :description "the sql statement to run"}}} |
19 |
| - :container |
20 |
| - {:image "vonwig/sqlite:latest" |
21 |
| - :command ["{{database}}" "{{query}}"]}}]}) |
22 |
| - |
23 | 8 | (def first-tool-call
|
24 | 9 | {:messages [{:role "assistant"
|
25 | 10 | :content ""
|
|
55 | 40 | (async/go
|
56 | 41 | first-tool-call))
|
57 | 42 |
|
| 43 | +(defn failed-tool-call-message [tool-call-name] |
| 44 | + ;; this is awful - binds the should-continue edge to the format of this string |
| 45 | + (format |
| 46 | + "Error: The wrong tool was called: %s. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call." |
| 47 | + tool-call-name)) |
| 48 | + |
58 | 49 | (defn query-gen
|
59 | 50 | "Assistant+Tool Node: has it's own prompt but also adds checks for proper final answers"
|
60 |
| - [_] |
61 |
| - ;; note that we should not bind tools here because the only acceptable output is a proper final answer |
62 |
| - ;; so this sub-graph needs to remove all tools |
63 |
| - ) |
64 |
| - |
65 |
| -;; TODO pull out the query check system |
66 |
| -(defn correct-query |
67 |
| - "Assistant Node: double check the query |
68 |
| - - should generate a tool call to execute the query" |
69 |
| - [_]) |
70 |
| - |
71 |
| -(defn execute-query |
72 |
| - "Tool Node: runs db query" |
73 |
| - [_]) |
| 51 | + [state] |
| 52 | + (async/go |
| 53 | + (let [x (-> |
| 54 | + state |
| 55 | + (dissoc :messages) |
| 56 | + (dissoc :functions) |
| 57 | + (update-in [:opts :level] (fnil inc 0)) |
| 58 | + (update-in [:opts :prompts] (constantly (fs/file "prompts/sql/query-gen.md"))) |
| 59 | + (graph/construct-initial-state-from-prompts) |
| 60 | + (update-in [:messages] concat (:messages state))) |
| 61 | + {:keys [messages _finish-reason]} (async/<! (graph/run-llm |
| 62 | + (:messages x) |
| 63 | + (dissoc (:metadata x) :agent) |
| 64 | + (:functions x) |
| 65 | + (:opts x)))] |
| 66 | + |
| 67 | + ; check for bad tool_calls and create failed Tool messages for them |
| 68 | + {:messages |
| 69 | + (concat |
| 70 | + messages |
| 71 | + (->> (:tool_calls (last messages)) |
| 72 | + (filter (complement #(= "SubmitFinalAnswer" (-> % :function :name)))) |
| 73 | + (map (fn [{:keys [id] :as tc}] |
| 74 | + {:role "tool" |
| 75 | + :content (failed-tool-call-message (-> tc :function :name)) |
| 76 | + :tool_call_id id}))))}))) |
74 | 77 |
|
75 | 78 | (defn should-continue
|
76 | 79 | "end, correct-query, or query-gen"
|
77 | 80 | [{:keys [messages]}]
|
78 | 81 | (let [last-message (last messages)]
|
79 | 82 | (cond
|
80 | 83 | (contains? last-message :tool_calls) "end"
|
81 |
| - (string/starts-with? last-message "Error:") "query-gen" |
| 84 | + (string/starts-with? (:content last-message) "Error:") "query-gen" |
82 | 85 | :else "correct-query")))
|
83 | 86 |
|
84 | 87 | (defn seed-get-schema-conversation [state]
|
|
90 | 93 | (update-in [:opts :parameters] (constantly {:database "./Chinook.db"}))
|
91 | 94 | (update-in [:functions] (fnil concat []) (:tools model-get-schema))))
|
92 | 95 |
|
| 96 | +(defn seed-correct-query-conversation |
| 97 | + [state] |
| 98 | + ; make one LLM call with the last message (which should be a user query containing the SQL we want to check) |
| 99 | + ; add the last message to the conversation |
| 100 | + (-> state |
| 101 | + (dissoc :messages) |
| 102 | + (update-in [:opts :level] (fnil inc 0)) |
| 103 | + (update-in [:opts :prompts] (constantly (fs/file "prompts/sql/query-check.md"))) |
| 104 | + (update-in [:opts :parameters] (constantly {:database "./Chinook.db"})) |
| 105 | + (graph/construct-initial-state-from-prompts) |
| 106 | + (update-in [:messages] concat [(last (:messages state))]))) |
| 107 | + |
93 | 108 | (defn graph [_]
|
94 | 109 | (-> {}
|
95 | 110 | (graph/add-node "start" graph/start)
|
|
99 | 114 | (graph/add-node "list-tables-tool" (graph/tool-node nil))
|
100 | 115 | (graph/add-edge "list-tables-inject-tool" "list-tables-tool")
|
101 | 116 |
|
102 |
| - (graph/add-node "model-get-schema" (graph/sub-graph-node {:init-state seed-get-schema-conversation})) ; assistant |
| 117 | + ; TODO replace the conversation state, don't append |
| 118 | + (graph/add-node "model-get-schema" (graph/sub-graph-node {:init-state seed-get-schema-conversation |
| 119 | + :next-state graph/append-new-messages})) ; assistant |
103 | 120 | (graph/add-edge "list-tables-tool" "model-get-schema")
|
104 | 121 |
|
105 |
| - (graph/add-node "end" graph/end) |
106 |
| - (graph/add-edge "model-get-schema" "end") |
107 |
| - ;(graph/add-node "query-gen" query-gen) ; assistant - might just end if it generates the right response |
108 |
| - ;; - might just loop back to query-gen if there's an error |
109 |
| - ;; - otherwise switch to correct-query |
110 |
| - |
111 |
| - ;(graph/add-node "correct-query" correct-query) ; assistant |
112 |
| - ;(graph/add-node "execute-query" execute-query) ; tool |
| 122 | + (graph/add-node "query-gen" query-gen) |
| 123 | + (graph/add-edge "model-get-schema" "query-gen") |
113 | 124 |
|
114 |
| - ;(graph/add-edge "list-table-sub-graph" "model-get-schema") |
115 |
| - ;(graph/add-edge "model-get-schema" "query-gen") |
116 |
| - ;(graph/add-conditional-edges "query-gen" should-continue) |
| 125 | + (graph/add-node "end" graph/end) |
| 126 | + (graph/add-node "correct-query" (graph/sub-graph-node {:init-state seed-correct-query-conversation |
| 127 | + :construct-graph graph/one-tool-call |
| 128 | + :next-state graph/append-new-messages})) |
| 129 | + (graph/add-conditional-edges "query-gen" should-continue) |
117 | 130 |
|
118 |
| - ;(graph/add-edge "correct-query" "execute-query") |
119 |
| - ;(graph/add-edge "execute-query" "query-gen") |
120 |
| - )) |
| 131 | + (graph/add-edge "correct-query" "query-gen"))) |
0 commit comments