Skip to content

Commit a61768d

Browse files
consolidate printing to generate
1 parent d7d54ca commit a61768d

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

src/pkg/agent/agent.go

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,6 @@ func (a *Agent) startSession() error {
175175
}
176176

177177
func (a *Agent) handleToolRequest(req *ai.ToolRequest) (*ai.ToolResponse, error) {
178-
inputs, err := json.Marshal(req.Input)
179-
if err != nil {
180-
return nil, fmt.Errorf("error marshaling tool request input: %w", err)
181-
}
182-
a.Printf("* %s(%s)\n", req.Name, inputs)
183178
tool := genkit.LookupTool(a.g, req.Name)
184179
if tool == nil {
185180
return nil, fmt.Errorf("tool %q not found", req.Name)
@@ -229,9 +224,9 @@ func (a *Agent) handleToolCalls(requests []*ai.ToolRequest) ([]*ai.Message, erro
229224
}
230225

231226
func (a *Agent) streamingCallback(ctx context.Context, chunk *ai.ModelResponseChunk) error {
232-
for _, part := range chunk.Content {
233-
a.Printf("%s", part.Text)
234-
}
227+
// for _, part := range chunk.Content {
228+
// a.Printf("%s", part.Text)
229+
// }
235230
return nil
236231
}
237232

@@ -242,7 +237,7 @@ func (a *Agent) handleUserMessage(msg string) error {
242237
}
243238

244239
func (a *Agent) generateLoop() error {
245-
a.Printf("* Thinking...\r* ")
240+
a.Printf("* Thinking...\r")
246241

247242
for range a.maxTurns {
248243
resp, err := a.generate()
@@ -271,10 +266,26 @@ func (a *Agent) generate() (*ai.ModelResponse, error) {
271266
ai.WithReturnToolRequests(true),
272267
ai.WithStreaming(a.streamingCallback),
273268
)
274-
a.Println("")
275269
if err != nil {
276270
return nil, err
277271
}
272+
for _, part := range resp.Message.Content {
273+
if part.Kind == ai.PartText {
274+
a.Printf("%s", part.Text)
275+
}
276+
if part.Kind == ai.PartToolRequest {
277+
req := part.ToolRequest
278+
inputs, err := json.Marshal(req.Input)
279+
if err != nil {
280+
return nil, fmt.Errorf("error marshaling tool request input: %w", err)
281+
}
282+
a.Printf("* %s(%s)\n", req.Name, inputs)
283+
}
284+
if part.Kind == ai.PartReasoning {
285+
a.Printf("_%s_\n", part.Text)
286+
}
287+
}
288+
a.Println("")
278289

279290
a.msgs = append(a.msgs, resp.Message)
280291
return resp, nil

0 commit comments

Comments
 (0)