|
1 | 1 | (ns graphs.sql |
2 | 2 | (:require |
| 3 | + [babashka.fs :as fs] |
3 | 4 | [clojure.core.async :as async] |
4 | 5 | [clojure.string :as string] |
5 | 6 | [graph])) |
6 | 7 |
|
7 | | -(def db-query-tool-call |
8 | | - {:messages [{:content "" |
9 | | - :tool_calls [{:name "sql_db_query_tool" |
10 | | - :arguments "{}" |
11 | | - :id "tool_abc123"}]}] |
12 | | - :tools [{:name "sql_db_query_tool" |
13 | | - :description "List all tables in the database" |
14 | | - :parameters |
15 | | - {:type "object" |
16 | | - :properties |
17 | | - {:database {:type "string" :description "the database to query"} |
18 | | - :query {:type "string" :description "the sql statement to run"}}} |
19 | | - :container |
20 | | - {:image "vonwig/sqlite:latest" |
21 | | - :command ["{{database}}" "{{query}}"]}}]}) |
22 | | - |
23 | 8 | (def first-tool-call |
24 | 9 | {:messages [{:role "assistant" |
25 | 10 | :content "" |
|
55 | 40 | (async/go |
56 | 41 | first-tool-call)) |
57 | 42 |
|
| 43 | +(defn failed-tool-call-message [tool-call-name] |
| 44 | + ;; this is awful - binds the should-continue edge to the format of this string |
| 45 | + (format |
| 46 | + "Error: The wrong tool was called: %s. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call." |
| 47 | + tool-call-name)) |
| 48 | + |
58 | 49 | (defn query-gen |
59 | 50 | "Assistant+Tool Node: has it's own prompt but also adds checks for proper final answers" |
60 | | - [_] |
61 | | - ;; note that we should not bind tools here because the only acceptable output is a proper final answer |
62 | | - ;; so this sub-graph needs to remove all tools |
63 | | - ) |
64 | | - |
65 | | -;; TODO pull out the query check system |
66 | | -(defn correct-query |
67 | | - "Assistant Node: double check the query |
68 | | - - should generate a tool call to execute the query" |
69 | | - [_]) |
70 | | - |
71 | | -(defn execute-query |
72 | | - "Tool Node: runs db query" |
73 | | - [_]) |
| 51 | + [state] |
| 52 | + (async/go |
| 53 | + (let [x (-> |
| 54 | + state |
| 55 | + (dissoc :messages) |
| 56 | + (dissoc :functions) |
| 57 | + (update-in [:opts :level] (fnil inc 0)) |
| 58 | + (update-in [:opts :prompts] (constantly (fs/file "prompts/sql/query-gen.md"))) |
| 59 | + (graph/construct-initial-state-from-prompts) |
| 60 | + (update-in [:messages] concat (:messages state))) |
| 61 | + {:keys [messages _finish-reason]} (async/<! (graph/run-llm |
| 62 | + (:messages x) |
| 63 | + (dissoc (:metadata x) :agent) |
| 64 | + (:functions x) |
| 65 | + (:opts x)))] |
| 66 | + |
| 67 | + ; check for bad tool_calls and create failed Tool messages for them |
| 68 | + {:messages |
| 69 | + (concat |
| 70 | + messages |
| 71 | + (->> (:tool_calls (last messages)) |
| 72 | + (filter (complement #(= "SubmitFinalAnswer" (-> % :function :name)))) |
| 73 | + (map (fn [{:keys [id] :as tc}] |
| 74 | + {:role "tool" |
| 75 | + :content (failed-tool-call-message (-> tc :function :name)) |
| 76 | + :tool_call_id id}))))}))) |
74 | 77 |
|
75 | 78 | (defn should-continue |
76 | 79 | "end, correct-query, or query-gen" |
77 | 80 | [{:keys [messages]}] |
78 | 81 | (let [last-message (last messages)] |
79 | 82 | (cond |
80 | 83 | (contains? last-message :tool_calls) "end" |
81 | | - (string/starts-with? last-message "Error:") "query-gen" |
| 84 | + (string/starts-with? (:content last-message) "Error:") "query-gen" |
82 | 85 | :else "correct-query"))) |
83 | 86 |
|
84 | 87 | (defn seed-get-schema-conversation [state] |
|
90 | 93 | (update-in [:opts :parameters] (constantly {:database "./Chinook.db"})) |
91 | 94 | (update-in [:functions] (fnil concat []) (:tools model-get-schema)))) |
92 | 95 |
|
| 96 | +(defn seed-correct-query-conversation |
| 97 | + [state] |
| 98 | + ; make one LLM call with the last message (which should be a user query containing the SQL we want to check) |
| 99 | + ; add the last message to the conversation |
| 100 | + (-> state |
| 101 | + (dissoc :messages) |
| 102 | + (update-in [:opts :level] (fnil inc 0)) |
| 103 | + (update-in [:opts :prompts] (constantly (fs/file "prompts/sql/query-check.md"))) |
| 104 | + (update-in [:opts :parameters] (constantly {:database "./Chinook.db"})) |
| 105 | + (graph/construct-initial-state-from-prompts) |
| 106 | + (update-in [:messages] concat [(last (:messages state))]))) |
| 107 | + |
93 | 108 | (defn graph [_] |
94 | 109 | (-> {} |
95 | 110 | (graph/add-node "start" graph/start) |
|
99 | 114 | (graph/add-node "list-tables-tool" (graph/tool-node nil)) |
100 | 115 | (graph/add-edge "list-tables-inject-tool" "list-tables-tool") |
101 | 116 |
|
102 | | - (graph/add-node "model-get-schema" (graph/sub-graph-node {:init-state seed-get-schema-conversation})) ; assistant |
| 117 | + ; TODO replace the conversation state, don't append |
| 118 | + (graph/add-node "model-get-schema" (graph/sub-graph-node {:init-state seed-get-schema-conversation |
| 119 | + :next-state graph/append-new-messages})) ; assistant |
103 | 120 | (graph/add-edge "list-tables-tool" "model-get-schema") |
104 | 121 |
|
105 | | - (graph/add-node "end" graph/end) |
106 | | - (graph/add-edge "model-get-schema" "end") |
107 | | - ;(graph/add-node "query-gen" query-gen) ; assistant - might just end if it generates the right response |
108 | | - ;; - might just loop back to query-gen if there's an error |
109 | | - ;; - otherwise switch to correct-query |
110 | | - |
111 | | - ;(graph/add-node "correct-query" correct-query) ; assistant |
112 | | - ;(graph/add-node "execute-query" execute-query) ; tool |
| 122 | + (graph/add-node "query-gen" query-gen) |
| 123 | + (graph/add-edge "model-get-schema" "query-gen") |
113 | 124 |
|
114 | | - ;(graph/add-edge "list-table-sub-graph" "model-get-schema") |
115 | | - ;(graph/add-edge "model-get-schema" "query-gen") |
116 | | - ;(graph/add-conditional-edges "query-gen" should-continue) |
| 125 | + (graph/add-node "end" graph/end) |
| 126 | + (graph/add-node "correct-query" (graph/sub-graph-node {:init-state seed-correct-query-conversation |
| 127 | + :construct-graph graph/one-tool-call |
| 128 | + :next-state graph/append-new-messages})) |
| 129 | + (graph/add-conditional-edges "query-gen" should-continue) |
117 | 130 |
|
118 | | - ;(graph/add-edge "correct-query" "execute-query") |
119 | | - ;(graph/add-edge "execute-query" "query-gen") |
120 | | - )) |
| 131 | + (graph/add-edge "correct-query" "query-gen"))) |
0 commit comments