Skip to content

Commit 4efa407

Browse files
committed
(human) more work towards refactoring chat post-env/agent merge, some other random improvements
1 parent 5ab7733 commit 4efa407

File tree

6 files changed

+126
-56
lines changed

6 files changed

+126
-56
lines changed

lib/agent/cmd/runtime.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func prepRuntime(args []string, rflags flags.RootPflagpole) (*runtime.Runtime, *
5858

5959
ar.BackfillAgentic()
6060

61-
ws.SetupHandlers(ar)
61+
ws.SetupHandlers(r, ar)
6262
return r, ar, nil
6363
}
6464

lib/agent/cmd/tui/msgs.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ func renderMessages(width int, session session.Session) []string {
4545
author = agentStyle.Render(evt.Author)
4646
if evt.Content != nil {
4747
for _, part := range evt.Content.Parts {
48+
4849
if part.Text != "" {
4950
out, err := glam.Render(part.Text)
5051
if err != nil {
@@ -53,6 +54,7 @@ func renderMessages(width int, session session.Session) []string {
5354
body += out + "\n"
5455
}
5556
}
57+
5658
if part.FunctionCall != nil {
5759
r := part.FunctionCall
5860
var extra string
@@ -75,9 +77,12 @@ func renderMessages(width int, session session.Session) []string {
7577
text := dmp.DiffPrettyText(diffs)
7678
extra += fmt.Sprintf("-------\n%s\n-------\n\n", text)
7779
}
80+
case "exec":
81+
extra = fmt.Sprintf("`%s`", r.Args["script"])
7882
}
7983
body += fmt.Sprintf(" %s: %s ...\n", funcStyle.Render("┃"+r.Name), extra)
8084
}
85+
8186
if part.FunctionResponse != nil {
8287
r := part.FunctionResponse
8388
var extra string
@@ -88,9 +93,12 @@ func renderMessages(width int, session session.Session) []string {
8893
extra = fmt.Sprintf("%s %s", r.Response["path"], r.Response["status"])
8994
case "fs_edit":
9095
extra = fmt.Sprintf("%s %s", r.Response["path"], r.Response["status"])
96+
case "exec":
97+
extra = fmt.Sprintf("%s %v\n--- stdout ---\n%s\n--- stderr ---\n%s\n--- end ---", r.Response["status"], r.Response["exitCode"], r.Response["stdout"], r.Response["stderr"])
9198
}
9299
body += fmt.Sprintf(" %s: %s\n", funcStyle.Render("┃"+r.Name), extra)
93100
}
101+
94102
}
95103
body = strings.TrimSuffix(body, "\n")
96104
}

lib/agent/models/gemini.go

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,40 @@ package models
22

33
import (
44
"context"
5+
"os"
56

67
"google.golang.org/adk/model"
78
"google.golang.org/adk/model/gemini"
89
"google.golang.org/genai"
910
)
1011

1112
func Gemini(ctx context.Context, model string) (model.LLM, error) {
12-
return gemini.NewModel(ctx, model, &genai.ClientConfig{
13-
// Project: "gen-lang-client-0911744172",
14-
// Location: "us-central1",
15-
// Backend: genai.BackendVertexAI,
16-
})
13+
14+
use := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI")
15+
if use != "" { // todo, be more truthy
16+
// creds file
17+
creds := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS")
18+
if creds != "" {
19+
return gemini.NewModel(ctx, model, &genai.ClientConfig{
20+
Backend: genai.BackendVertexAI,
21+
})
22+
}
23+
proj := os.Getenv("GOOGLE_CLOUD_PROJECT")
24+
loc := os.Getenv("GOOGLE_CLOUD_LOCATION")
25+
if proj != "" && loc != "" {
26+
return gemini.NewModel(ctx, model, &genai.ClientConfig{
27+
Project: proj,
28+
Location: loc,
29+
Backend: genai.BackendVertexAI,
30+
})
31+
}
32+
33+
// default inference (same as Go SDK) (typically a service account)
34+
return gemini.NewModel(ctx, model, &genai.ClientConfig{
35+
Backend: genai.BackendVertexAI,
36+
})
37+
}
38+
39+
// default inference (same as Go SDK)
40+
return gemini.NewModel(ctx, model, &genai.ClientConfig{})
1741
}

lib/agent/runtime/handlers/ws/chat.go

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,66 @@ import (
99

1010
"github.com/google/uuid"
1111
"github.com/hofstadter-io/hof/lib/agent/agents"
12-
"github.com/hofstadter-io/hof/lib/agent/runtime"
12+
aruntime "github.com/hofstadter-io/hof/lib/agent/runtime"
13+
"github.com/hofstadter-io/hof/lib/agent/runtime/handlers/common"
1314
"github.com/hofstadter-io/hof/lib/agent/services/environ"
15+
"github.com/hofstadter-io/hof/lib/runtime"
1416
"google.golang.org/adk/agent"
1517
"google.golang.org/adk/model"
1618
"google.golang.org/adk/runner"
1719
"google.golang.org/adk/session"
1820
"google.golang.org/genai"
1921
)
2022

21-
type ChatPayload struct {
22-
Text string `json:"text"`
23-
Sid string `json:"sid"`
24-
Agent string `json:"agent"`
25-
Model string `json:"model"`
26-
Environ string `json:"environ"`
23+
func makeChatUserMessageHandler(r *runtime.Runtime) aruntime.Handler {
24+
return func(ar *aruntime.Runtime, c *aruntime.Client, m *aruntime.Message) {
25+
var p common.ChatPayload
26+
if err := json.Unmarshal(m.Payload, &p); err != nil {
27+
c.Mail("chat.event.error", map[string]any{
28+
"agent": p.Agent,
29+
"error_message": fmt.Sprintf("Error unmarshaling 'chat' payload: %v", err),
30+
})
31+
return
32+
}
33+
34+
p.User = c.User
35+
36+
log.Printf("Chatting payload: %#+v", p)
37+
38+
s, err := common.SessionChat(r, ar, &p)
39+
if err != nil {
40+
log.Printf("chat.msg.error.SessionChat: %v", err)
41+
c.Mail("chat.event.error", map[string]any{
42+
"agent": p.Agent,
43+
"error_message": fmt.Sprintf("while chatting: %v", err),
44+
})
45+
return
46+
}
47+
48+
// every time we get an event...
49+
for e := range s.EventChan {
50+
// send the message
51+
c.Mail("chat.event", e)
52+
53+
// look for any errors
54+
select {
55+
case err := <-s.ErrorChan:
56+
log.Printf("chat.msg.error.SessionChat.loop: %v", err)
57+
c.Mail("chat.event.error", map[string]any{
58+
"agent": p.Agent,
59+
"error_message": fmt.Sprintf("while chatting: %v", err),
60+
})
61+
default:
62+
}
63+
}
64+
65+
}
66+
2767
}
2868

29-
func chatUserMessage(r *runtime.Runtime, c *runtime.Client, m *runtime.Message) {
69+
func chatUserMessage(r *aruntime.Runtime, c *aruntime.Client, m *aruntime.Message) {
3070

31-
var p ChatPayload
71+
var p common.ChatPayload
3272
if err := json.Unmarshal(m.Payload, &p); err != nil {
3373
c.Mail("chat.event.error", map[string]any{
3474
"agent": p.Agent,
@@ -125,7 +165,7 @@ func chatUserMessage(r *runtime.Runtime, c *runtime.Client, m *runtime.Message)
125165
// setup subcontext and wait group
126166
chatCtx, chatStop := context.WithCancel(r.Ctx)
127167

128-
r.SetSession(&runtime.Session{
168+
r.SetSession(&aruntime.Session{
129169
Sid: p.Sid,
130170
StopFunc: chatStop,
131171
})
@@ -151,7 +191,7 @@ func chatUserMessage(r *runtime.Runtime, c *runtime.Client, m *runtime.Message)
151191

152192
}
153193

154-
func sessionCancel(r *runtime.Runtime, c *runtime.Client, m *runtime.Message) {
194+
func sessionCancel(r *aruntime.Runtime, c *aruntime.Client, m *aruntime.Message) {
155195
var p SidRequest
156196
if err := json.Unmarshal(m.Payload, &p); err != nil {
157197
log.Printf("Error unmarshaling 'session.cancel' payload: %v", err)

lib/agent/runtime/handlers/ws/index.go

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,47 +5,45 @@ import (
55
"fmt"
66
"log"
77

8-
"github.com/hofstadter-io/hof/lib/agent/runtime"
8+
aruntime "github.com/hofstadter-io/hof/lib/agent/runtime"
99
"github.com/hofstadter-io/hof/lib/cuetils"
10+
"github.com/hofstadter-io/hof/lib/runtime"
1011
)
1112

12-
// TODO, we need a good list of message types
13-
// for both the frontend, backend, and where/how they are used
14-
15-
func SetupHandlers(r *runtime.Runtime) {
13+
func SetupHandlers(r *runtime.Runtime, ar *aruntime.Runtime) {
1614

1715
// standard fare
18-
r.Handlers["echo"] = echo
19-
r.Handlers["hello"] = hello
16+
ar.Handlers["echo"] = echo
17+
ar.Handlers["hello"] = hello
2018

2119
// informational handlers
22-
r.Handlers["requestSync"] = broadcastSync
23-
r.Handlers["config.reload"] = reloadEnvConfig
24-
r.Handlers["config.info"] = configInfo
25-
r.Handlers["models.list"] = modelsList
26-
r.Handlers["agents.list"] = agentsList
20+
ar.Handlers["requestSync"] = broadcastSync
21+
ar.Handlers["config.reload"] = reloadEnvConfig
22+
ar.Handlers["config.info"] = configInfo
23+
ar.Handlers["models.list"] = modelsList
24+
ar.Handlers["agents.list"] = agentsList
2725

2826
// chat
29-
r.Handlers["chat"] = chatUserMessage
30-
r.Handlers["chat.userMessage"] = chatUserMessage
31-
r.Handlers["session.cancel"] = sessionCancel
27+
ar.Handlers["chat"] = makeChatUserMessageHandler(r)
28+
ar.Handlers["chat.userMessage"] = makeChatUserMessageHandler(r)
29+
ar.Handlers["session.cancel"] = sessionCancel
3230

3331
// sessions
34-
r.Handlers["session.get"] = sessionGet
35-
r.Handlers["session.getList"] = sessionList
36-
r.Handlers["session.create"] = sessionCreate
37-
r.Handlers["session.delete"] = sessionDelete
38-
r.Handlers["session.getStateAll"] = sessionGetStateAll
39-
r.Handlers["session.state.get"] = sessionGetState
40-
r.Handlers["session.state.put"] = sessionPutState
41-
r.Handlers["session.state.del"] = sessionDelState
42-
43-
r.Handlers["session.merge"] = sessionMerge
44-
r.Handlers["session.tag"] = sessionTag
45-
r.Handlers["session.push"] = sessionPush
46-
r.Handlers["session.pull"] = sessionPull
47-
r.Handlers["session.clone"] = sessionClone
48-
r.Handlers["session.splice"] = sessionSplice
32+
ar.Handlers["session.get"] = sessionGet
33+
ar.Handlers["session.getList"] = sessionList
34+
ar.Handlers["session.create"] = sessionCreate
35+
ar.Handlers["session.delete"] = sessionDelete
36+
ar.Handlers["session.getStateAll"] = sessionGetStateAll
37+
ar.Handlers["session.state.get"] = sessionGetState
38+
ar.Handlers["session.state.put"] = sessionPutState
39+
ar.Handlers["session.state.del"] = sessionDelState
40+
41+
ar.Handlers["session.merge"] = sessionMerge
42+
ar.Handlers["session.tag"] = sessionTag
43+
ar.Handlers["session.push"] = sessionPush
44+
ar.Handlers["session.pull"] = sessionPull
45+
ar.Handlers["session.clone"] = sessionClone
46+
ar.Handlers["session.splice"] = sessionSplice
4947
// r.Handlers["session.environ.set"] = sessionEnvironSet
5048

5149
//
@@ -65,7 +63,7 @@ type EchoResponsePayload struct {
6563
ResponseText string `json:"responseText"`
6664
}
6765

68-
func echo(r *runtime.Runtime, c *runtime.Client, m *runtime.Message) {
66+
func echo(r *aruntime.Runtime, c *aruntime.Client, m *aruntime.Message) {
6967
var p EchoPayload
7068
if err := json.Unmarshal(m.Payload, &p); err != nil {
7169
log.Printf("Error unmarshaling 'echo' payload: %v", err)
@@ -83,7 +81,7 @@ type HelloPayload struct {
8381
Version string `json:"version"`
8482
}
8583

86-
func hello(r *runtime.Runtime, c *runtime.Client, m *runtime.Message) {
84+
func hello(ar *aruntime.Runtime, c *aruntime.Client, m *aruntime.Message) {
8785
var p HelloPayload
8886
if err := json.Unmarshal(m.Payload, &p); err != nil {
8987
log.Printf("Error unmarshaling 'hello' payload: %v", err)
@@ -92,28 +90,28 @@ func hello(r *runtime.Runtime, c *runtime.Client, m *runtime.Message) {
9290
log.Printf("Hello from client version: %s (Client: %p)", p.Version, c)
9391
}
9492

95-
func broadcastSync(r *runtime.Runtime, c *runtime.Client, m *runtime.Message) {
93+
func broadcastSync(ar *aruntime.Runtime, c *aruntime.Client, m *aruntime.Message) {
9694
// fmt.Println("broadcastSync")
9795
// reloadEnvConfig(r, c, m)
98-
sessionGet(r, c, m)
99-
sessionList(r, c, m)
96+
sessionGet(ar, c, m)
97+
sessionList(ar, c, m)
10098
// sessionFilesysDiff(r, c, m)
10199

102100
// runtime (runners?)
103101
// memory
104102
// artifacts
105103
}
106104

107-
func reloadEnvConfig(r *runtime.Runtime, c *runtime.Client, m *runtime.Message) {
105+
func reloadEnvConfig(ar *aruntime.Runtime, c *aruntime.Client, m *aruntime.Message) {
108106
// todo, this should happen on a per-client/user basis
109107
var err error
110-
err = r.ReadEnvConfig()
108+
err = ar.ReadEnvConfig()
111109
if err != nil {
112110
err = cuetils.ExpandCueError(err)
113111
c.Mail("config.reload.error", map[string]any{
114112
"status": "error",
115113
"error_message": fmt.Errorf("while reloading config: %w", err),
116114
})
117115
}
118-
configInfo(r, c, m)
116+
configInfo(ar, c, m)
119117
}

lib/extension/cmd.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func Run(args []string, rflags flags.RootPflagpole) error {
4848
ar.BackfillAgentic()
4949

5050
// fmt.Println("BR.Agentics:", len(ar.Agentics))
51-
ws.SetupHandlers(ar)
51+
ws.SetupHandlers(r, ar)
5252

5353
return ar.Run()
5454
}

0 commit comments

Comments
 (0)