Skip to content

Commit d8f01f7

Browse files
authored
Merge pull request #1734 from dgageot/cmdroot
Small improvements to cmd/root
2 parents 7445ef8 + 1eb0073 commit d8f01f7

File tree

10 files changed

+51
-99
lines changed

10 files changed

+51
-99
lines changed

cmd/root/a2a.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
package root
22

33
import (
4-
"fmt"
5-
64
"github.com/spf13/cobra"
75

86
"github.com/docker/cagent/pkg/a2a"
97
"github.com/docker/cagent/pkg/cli"
108
"github.com/docker/cagent/pkg/config"
11-
"github.com/docker/cagent/pkg/server"
129
"github.com/docker/cagent/pkg/telemetry"
1310
)
1411

@@ -46,15 +43,10 @@ func (f *a2aFlags) runA2ACommand(cmd *cobra.Command, args []string) error {
4643
out := cli.NewPrinter(cmd.OutOrStdout())
4744
agentFilename := args[0]
4845

49-
// Listen as early as possible
50-
ln, err := server.Listen(ctx, f.listenAddr)
46+
ln, err := listenAndCloseOnCancel(ctx, f.listenAddr)
5147
if err != nil {
52-
return fmt.Errorf("failed to listen on %s: %w", f.listenAddr, err)
48+
return err
5349
}
54-
go func() {
55-
<-ctx.Done()
56-
_ = ln.Close()
57-
}()
5850

5951
out.Println("Listening on", ln.Addr().String())
6052
return a2a.Run(ctx, agentFilename, f.agentName, &f.runConfig, ln)

cmd/root/api.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) error {
123123
// Start recording proxy if --record is specified
124124
if _, recordCleanup, err := setupRecordingProxy(f.recordPath, &f.runConfig); err != nil {
125125
return err
126-
} else if recordCleanup != nil {
126+
} else {
127127
defer func() {
128128
if err := recordCleanup(); err != nil {
129129
slog.Error("Failed to cleanup recording proxy", "error", err)
@@ -135,14 +135,10 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) error {
135135
return fmt.Errorf("--pull-interval flag can only be used with OCI references, not local files")
136136
}
137137

138-
ln, err := server.Listen(ctx, f.listenAddr)
138+
ln, err := listenAndCloseOnCancel(ctx, f.listenAddr)
139139
if err != nil {
140-
return fmt.Errorf("failed to listen on %s: %w", f.listenAddr, err)
140+
return err
141141
}
142-
go func() {
143-
<-ctx.Done()
144-
_ = ln.Close()
145-
}()
146142

147143
out.Println("Listening on", ln.Addr().String())
148144

cmd/root/catalog.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,8 @@ func newCatalogListCmd() *cobra.Command {
4040
func runCatalogListCommand(cmd *cobra.Command, args []string) error {
4141
telemetry.TrackCommand("catalog", append([]string{"list"}, args...))
4242

43-
var org string
44-
if len(args) == 0 {
45-
org = "agentcatalog"
46-
} else {
43+
org := "agentcatalog"
44+
if len(args) > 0 {
4745
org = args[0]
4846
}
4947

cmd/root/debug.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,10 @@ func newDebugCmd() *cobra.Command {
5656
return cmd
5757
}
5858

59-
// resolveSource resolves an agent file reference to a config source.
60-
func (f *debugFlags) resolveSource(agentFilename string) (config.Source, error) {
61-
return config.Resolve(agentFilename, f.runConfig.EnvProvider())
62-
}
63-
6459
// loadTeam loads an agent team from the given agent file and returns
6560
// a cleanup function that must be deferred by the caller.
6661
func (f *debugFlags) loadTeam(ctx context.Context, agentFilename string, opts ...teamloader.Opt) (*team.Team, func(), error) {
67-
agentSource, err := f.resolveSource(agentFilename)
62+
agentSource, err := config.Resolve(agentFilename, f.runConfig.EnvProvider())
6863
if err != nil {
6964
return nil, nil, err
7065
}
@@ -86,7 +81,7 @@ func (f *debugFlags) loadTeam(ctx context.Context, agentFilename string, opts ..
8681
func (f *debugFlags) runDebugConfigCommand(cmd *cobra.Command, args []string) error {
8782
telemetry.TrackCommand("debug", append([]string{"config"}, args...))
8883

89-
agentSource, err := f.resolveSource(args[0])
84+
agentSource, err := config.Resolve(args[0], f.runConfig.EnvProvider())
9085
if err != nil {
9186
return err
9287
}

cmd/root/exec.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ func newExecCmd() *cobra.Command {
3535
func (f *runExecFlags) runExecCommand(cmd *cobra.Command, args []string) error {
3636
telemetry.TrackCommand("exec", args)
3737

38-
ctx := cmd.Context()
3938
out := cli.NewPrinter(cmd.OutOrStdout())
4039

41-
tui := false
42-
return f.runOrExec(ctx, out, args, tui)
40+
return f.runOrExec(cmd.Context(), out, args, false)
4341
}

cmd/root/flags.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package root
22

33
import (
4+
"context"
45
"fmt"
56
"log/slog"
7+
"net"
68
"os"
79
"path/filepath"
810
"strings"
@@ -11,6 +13,7 @@ import (
1113

1214
"github.com/docker/cagent/pkg/config"
1315
"github.com/docker/cagent/pkg/config/latest"
16+
"github.com/docker/cagent/pkg/server"
1417
"github.com/docker/cagent/pkg/userconfig"
1518
)
1619

@@ -113,3 +116,17 @@ func parseModelShorthand(s string) *latest.ModelConfig {
113116
}
114117
return nil
115118
}
119+
120+
// listenAndCloseOnCancel starts a listener and spawns a goroutine
121+
// that closes it when the context is cancelled.
122+
func listenAndCloseOnCancel(ctx context.Context, addr string) (net.Listener, error) {
123+
ln, err := server.Listen(ctx, addr)
124+
if err != nil {
125+
return nil, fmt.Errorf("failed to listen on %s: %w", addr, err)
126+
}
127+
go func() {
128+
<-ctx.Done()
129+
_ = ln.Close()
130+
}()
131+
return ln, nil
132+
}

cmd/root/mcp.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
package root
22

33
import (
4-
"fmt"
5-
64
"github.com/spf13/cobra"
75

86
"github.com/docker/cagent/pkg/config"
97
"github.com/docker/cagent/pkg/mcp"
10-
"github.com/docker/cagent/pkg/server"
118
"github.com/docker/cagent/pkg/telemetry"
129
)
1310

@@ -52,15 +49,10 @@ func (f *mcpFlags) runMCPCommand(cmd *cobra.Command, args []string) error {
5249
return mcp.StartMCPServer(ctx, agentFilename, f.agentName, &f.runConfig)
5350
}
5451

55-
// Listen as early as possible
56-
ln, err := server.Listen(ctx, f.listenAddr)
52+
ln, err := listenAndCloseOnCancel(ctx, f.listenAddr)
5753
if err != nil {
58-
return fmt.Errorf("failed to listen on %s: %w", f.listenAddr, err)
54+
return err
5955
}
60-
go func() {
61-
<-ctx.Done()
62-
_ = ln.Close()
63-
}()
6456

6557
return mcp.StartHTTPServer(ctx, agentFilename, f.agentName, &f.runConfig, ln)
6658
}

cmd/root/otel.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ func initOTelSDK(ctx context.Context) (err error) {
4444
}
4545

4646
// Configure tracer provider
47-
var tracerProviderOpts []trace.TracerProviderOption
48-
tracerProviderOpts = append(tracerProviderOpts, trace.WithResource(res))
47+
tracerProviderOpts := []trace.TracerProviderOption{
48+
trace.WithResource(res),
49+
}
4950

5051
if traceExporter != nil {
5152
tracerProviderOpts = append(tracerProviderOpts,

cmd/root/pull.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"os"
77
"strings"
88

9-
"github.com/google/go-containerregistry/pkg/crane"
109
"github.com/spf13/cobra"
1110

1211
"github.com/docker/cagent/pkg/cli"
@@ -46,8 +45,7 @@ func (f *pullFlags) runPullCommand(cmd *cobra.Command, args []string) error {
4645

4746
out.Println("Pulling agent", registryRef)
4847

49-
var opts []crane.Option
50-
_, err := remote.Pull(ctx, registryRef, f.force, opts...)
48+
_, err := remote.Pull(ctx, registryRef, f.force)
5149
if err != nil {
5250
return fmt.Errorf("failed to pull artifact: %w", err)
5351
}

cmd/root/run.go

Lines changed: 18 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,14 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s
234234
return err
235235
}
236236

237-
loadResult, err := f.loadAgentFrom(ctx, agentSource)
237+
opts := []teamloader.Opt{
238+
teamloader.WithModelOverrides(f.modelOverrides),
239+
}
240+
if len(f.promptFiles) > 0 {
241+
opts = append(opts, teamloader.WithPromptFiles(f.promptFiles))
242+
}
243+
244+
loadResult, err := teamloader.LoadWithConfig(ctx, agentSource, &f.runConfig, opts...)
238245
if err != nil {
239246
return err
240247
}
@@ -275,58 +282,16 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s
275282
return f.handleRunMode(ctx, rt, sess, args)
276283
}
277284

278-
func (f *runExecFlags) loadAgentFrom(ctx context.Context, agentSource config.Source) (*teamloader.LoadResult, error) {
279-
opts := []teamloader.Opt{
280-
teamloader.WithModelOverrides(f.modelOverrides),
281-
}
282-
if len(f.promptFiles) > 0 {
283-
opts = append(opts, teamloader.WithPromptFiles(f.promptFiles))
284-
}
285-
286-
result, err := teamloader.LoadWithConfig(ctx, agentSource, &f.runConfig, opts...)
287-
if err != nil {
288-
return nil, err
289-
}
290-
291-
return result, nil
292-
}
293-
294285
func (f *runExecFlags) createRemoteRuntimeAndSession(ctx context.Context, originalFilename string) (runtime.Runtime, *session.Session, error) {
295-
if f.connectRPC {
296-
return f.createConnectRPCRuntimeAndSession(ctx, originalFilename)
297-
}
298-
return f.createHTTPRuntimeAndSession(ctx, originalFilename)
299-
}
300-
301-
func (f *runExecFlags) createConnectRPCRuntimeAndSession(ctx context.Context, originalFilename string) (runtime.Runtime, *session.Session, error) {
302-
connectClient, err := runtime.NewConnectRPCClient(f.remoteAddress)
303-
if err != nil {
304-
return nil, nil, fmt.Errorf("failed to create connect-rpc client: %w", err)
305-
}
306-
307-
sessTemplate := session.New(
308-
session.WithToolsApproved(f.autoApprove),
309-
)
310-
311-
sess, err := connectClient.CreateSession(ctx, sessTemplate)
312-
if err != nil {
313-
return nil, nil, err
314-
}
315-
316-
remoteRt, err := runtime.NewRemoteRuntime(connectClient,
317-
runtime.WithRemoteCurrentAgent(f.agentName),
318-
runtime.WithRemoteAgentFilename(originalFilename),
286+
var (
287+
client runtime.RemoteClient
288+
err error
319289
)
320-
if err != nil {
321-
return nil, nil, fmt.Errorf("failed to create connect-rpc remote runtime: %w", err)
290+
if f.connectRPC {
291+
client, err = runtime.NewConnectRPCClient(f.remoteAddress)
292+
} else {
293+
client, err = runtime.NewClient(f.remoteAddress)
322294
}
323-
324-
slog.Debug("Using connect-rpc remote runtime", "address", f.remoteAddress, "agent", f.agentName)
325-
return remoteRt, sess, nil
326-
}
327-
328-
func (f *runExecFlags) createHTTPRuntimeAndSession(ctx context.Context, originalFilename string) (runtime.Runtime, *session.Session, error) {
329-
remoteClient, err := runtime.NewClient(f.remoteAddress)
330295
if err != nil {
331296
return nil, nil, fmt.Errorf("failed to create remote client: %w", err)
332297
}
@@ -335,20 +300,20 @@ func (f *runExecFlags) createHTTPRuntimeAndSession(ctx context.Context, original
335300
session.WithToolsApproved(f.autoApprove),
336301
)
337302

338-
sess, err := remoteClient.CreateSession(ctx, sessTemplate)
303+
sess, err := client.CreateSession(ctx, sessTemplate)
339304
if err != nil {
340305
return nil, nil, err
341306
}
342307

343-
remoteRt, err := runtime.NewRemoteRuntime(remoteClient,
308+
remoteRt, err := runtime.NewRemoteRuntime(client,
344309
runtime.WithRemoteCurrentAgent(f.agentName),
345310
runtime.WithRemoteAgentFilename(originalFilename),
346311
)
347312
if err != nil {
348313
return nil, nil, fmt.Errorf("failed to create remote runtime: %w", err)
349314
}
350315

351-
slog.Debug("Using remote runtime", "address", f.remoteAddress, "agent", f.agentName)
316+
slog.Debug("Using remote runtime", "address", f.remoteAddress, "agent", f.agentName, "connect_rpc", f.connectRPC)
352317
return remoteRt, sess, nil
353318
}
354319

0 commit comments

Comments
 (0)