Skip to content

Commit c8be784

Browse files
Add state functions to manipulate conversation
1 parent ba844ec commit c8be784

File tree

3 files changed

+57
-21
lines changed

3 files changed

+57
-21
lines changed

src/graph.clj

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@
9797

9898
(declare stream chat-with-tools)
9999

100+
(defn apply-functions [coll]
101+
(fn [state]
102+
(reduce (fn [m f] (f m)) state coll)))
103+
100104
(defn sub-graph-node
101105
"create a sub-graph node that initializes a conversation from the current one,
102106
creates a new agent graph from the current state and returns the messages to be added
@@ -109,7 +113,13 @@
109113
(stream
110114
((or construct-graph chat-with-tools) state)
111115
(->
112-
((or init-state (comp state/construct-initial-state-from-prompts state/add-prompt-ref)) state)
116+
((or
117+
(and
118+
init-state
119+
(if (coll? init-state)
120+
(apply-functions init-state)
121+
init-state))
122+
(comp state/construct-initial-state-from-prompts state/add-prompt-ref)) state)
113123
(update-in [:opts :level] (fnil inc 0)))))]
114124
((or next-state state/add-last-message-as-tool-call) state sub-graph-state)))))
115125

src/graphs/sql.clj

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -83,19 +83,6 @@
8383
;; how many times should we try to correct because correct-query will always end up back here
8484
:else "correct-query")))
8585

86-
(defn seed-list-tables-conversation [state]
87-
(-> state
88-
(assoc :finish-reason "tool_calls")
89-
(update-in [:functions] (constantly (:tools first-tool-call)))
90-
(update-in [:messages] concat (:messages first-tool-call))))
91-
92-
(defn seed-get-schema-conversation [state]
93-
; inherit full conversation
94-
; no prompts
95-
; add the schema tool
96-
(-> state
97-
(update-in [:functions] (fnil concat []) (:tools model-get-schema))))
98-
9986
(defn seed-correct-query-conversation
10087
[state]
10188
; make one LLM call with the last message (which should be a user query containing the SQL we want to check)
@@ -106,18 +93,27 @@
10693
(state/construct-initial-state-from-prompts)
10794
(update-in [:messages] concat [(last (:messages state))])))
10895

96+
(comment
97+
[state/messages-reset
98+
(state/messages-from-prompt "prompts/sql/query-check.md")
99+
(state/messages-take-last 1)])
100+
109101
;; query-gen has a prompt
110102
;; seed-correct-query-conversation has a prompt
111103
;; prompts/sql/query-gen.md has a hard-coded db file
112104
(defn graph [_]
113105
(graph/construct-graph
114106
[[["start" graph/start]
115-
["list-tables-tool" (graph/sub-graph-node
116-
{:init-state seed-list-tables-conversation
117-
:construct-graph graph/generate-start-with-tool
118-
:next-state (state/take-last-messages 2)})]
107+
["list-tables-tool" (graph/sub-graph-node
108+
{:init-state
109+
[#(assoc % :finish-reason "tool_calls")
110+
(state/tools-set (:tools first-tool-call))
111+
(state/messages-append (:messages first-tool-call))]
112+
:construct-graph graph/generate-start-with-tool
113+
:next-state (state/take-last-messages 2)})]
119114
["model-get-schema" (graph/sub-graph-node
120-
{:init-state seed-get-schema-conversation
115+
{:init-state
116+
[(state/tools-append (:tools model-get-schema))]
121117
:next-state state/append-new-messages})]
122118
["query-gen" query-gen]
123119
[:edge should-continue]]

src/state.clj

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
(ns state
22
(:require
3+
[babashka.fs :as fs]
4+
[clojure.pprint :refer [pprint]]
35
git
46
jsonrpc
57
prompts
6-
tools
7-
[clojure.pprint :refer [pprint]]))
8+
tools))
89

910
(set! *warn-on-reflection* true)
1011

@@ -65,6 +66,35 @@
6566
(format "failure for prompt configuration:\n %s" (with-out-str (pprint (dissoc opts :pat :jwt))))
6667
:exception (str ex)}))))
6768

69+
(defn tools-append [tools]
70+
(fn [state]
71+
(-> state
72+
(update-in [:functions] (fnil concat []) tools))))
73+
74+
(defn tools-set [tools]
75+
(fn [state]
76+
(-> state
77+
(update-in [:functions] (constantly tools)))))
78+
79+
(defn messages-reset [state]
80+
(dissoc state :messages))
81+
82+
(defn messages-take-last [n]
83+
(fn [state]
84+
(-> state
85+
(update-in [:messages] (fnil concat []) (take-last n (:messages state))))))
86+
87+
(defn messages-append [coll]
88+
(fn [state]
89+
(-> state
90+
(update-in [:messages] (fnil concat []) coll))))
91+
92+
(defn messages-from-prompt [s]
93+
(fn [state]
94+
(-> state
95+
(update-in [:opts :prompts] (constantly (fs/file s)))
96+
(construct-initial-state-from-prompts))))
97+
6898
(defn add-prompt-ref
6999
[state]
70100
(let [definition (state/get-function-definition state)

0 commit comments

Comments
 (0)