Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ func Execute() {
}

func init() {
rootCmd.AddCommand(server.ServerCmd)
rootCmd.AddCommand(server.CreateServerCmd())
rootCmd.AddCommand(attach.AttachCmd)
}
108 changes: 77 additions & 31 deletions cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"

"github.com/spf13/cobra"
"github.com/spf13/viper"
"golang.org/x/xerrors"

"github.com/coder/agentapi/lib/httpapi"
Expand All @@ -19,15 +20,6 @@ import (
"github.com/coder/agentapi/lib/termexec"
)

var (
agentTypeVar string
port int
printOpenAPI bool
chatBasePath string
termWidth uint16
termHeight uint16
)

type AgentType = msgfmt.AgentType

const (
Expand Down Expand Up @@ -68,11 +60,15 @@ func parseAgentType(firstArg string, agentTypeVar string) (AgentType, error) {

func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) error {
agent := argsToPass[0]
agentType, err := parseAgentType(agent, agentTypeVar)
agentTypeValue := viper.GetString(FlagType)
agentType, err := parseAgentType(agent, agentTypeValue)
if err != nil {
return xerrors.Errorf("failed to parse agent type: %w", err)
}

termWidth := viper.GetUint16(FlagTermWidth)
termHeight := viper.GetUint16(FlagTermHeight)

if termWidth < 10 {
return xerrors.Errorf("term width must be at least 10")
}
Expand All @@ -83,6 +79,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
termHeight = 930 // codex has a bug where the TUI distorts the screen if the height is too large, see: https://github.com/openai/codex/issues/1608
}

printOpenAPI := viper.GetBool(FlagPrintOpenAPI)
var process *termexec.Process
if printOpenAPI {
process = nil
Expand All @@ -97,7 +94,13 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
return xerrors.Errorf("failed to setup process: %w", err)
}
}
srv := httpapi.NewServer(ctx, agentType, process, port, chatBasePath)
port := viper.GetInt(FlagPort)
srv := httpapi.NewServer(ctx, httpapi.ServerConfig{
AgentType: agentType,
Process: process,
Port: port,
ChatBasePath: viper.GetString(FlagChatBasePath),
})
if printOpenAPI {
fmt.Println(srv.GetOpenAPI())
return nil
Expand Down Expand Up @@ -138,26 +141,69 @@ var agentNames = (func() []string {
return names
})()

var ServerCmd = &cobra.Command{
Use: "server [agent]",
Short: "Run the server",
Long: fmt.Sprintf("Run the server with the specified agent (one of: %s)", strings.Join(agentNames, ", ")),
Args: cobra.MinimumNArgs(1),
Run: func(cmd *cobra.Command, args []string) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := logctx.WithLogger(context.Background(), logger)
if err := runServer(ctx, logger, cmd.Flags().Args()); err != nil {
fmt.Fprintf(os.Stderr, "%+v\n", err)
os.Exit(1)
}
},
type flagSpec struct {
name string
shorthand string
defaultValue any
usage string
flagType string
}

func init() {
ServerCmd.Flags().StringVarP(&agentTypeVar, "type", "t", "", fmt.Sprintf("Override the agent type (one of: %s, custom)", strings.Join(agentNames, ", ")))
ServerCmd.Flags().IntVarP(&port, "port", "p", 3284, "Port to run the server on")
ServerCmd.Flags().BoolVarP(&printOpenAPI, "print-openapi", "P", false, "Print the OpenAPI schema to stdout and exit")
ServerCmd.Flags().StringVarP(&chatBasePath, "chat-base-path", "c", "/chat", "Base path for assets and routes used in the static files of the chat interface")
ServerCmd.Flags().Uint16VarP(&termWidth, "term-width", "W", 80, "Width of the emulated terminal")
ServerCmd.Flags().Uint16VarP(&termHeight, "term-height", "H", 1000, "Height of the emulated terminal")
const (
FlagType = "type"
FlagPort = "port"
FlagPrintOpenAPI = "print-openapi"
FlagChatBasePath = "chat-base-path"
FlagTermWidth = "term-width"
FlagTermHeight = "term-height"
)

func CreateServerCmd() *cobra.Command {
serverCmd := &cobra.Command{
Use: "server [agent]",
Short: "Run the server",
Long: fmt.Sprintf("Run the server with the specified agent (one of: %s)", strings.Join(agentNames, ", ")),
Args: cobra.MinimumNArgs(1),
Run: func(cmd *cobra.Command, args []string) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := logctx.WithLogger(context.Background(), logger)
if err := runServer(ctx, logger, cmd.Flags().Args()); err != nil {
fmt.Fprintf(os.Stderr, "%+v\n", err)
os.Exit(1)
}
},
}

flagSpecs := []flagSpec{
{FlagType, "t", "", fmt.Sprintf("Override the agent type (one of: %s, custom)", strings.Join(agentNames, ", ")), "string"},
{FlagPort, "p", 3284, "Port to run the server on", "int"},
{FlagPrintOpenAPI, "P", false, "Print the OpenAPI schema to stdout and exit", "bool"},
{FlagChatBasePath, "c", "/chat", "Base path for assets and routes used in the static files of the chat interface", "string"},
{FlagTermWidth, "W", uint16(80), "Width of the emulated terminal", "uint16"},
{FlagTermHeight, "H", uint16(1000), "Height of the emulated terminal", "uint16"},
}

for _, spec := range flagSpecs {
switch spec.flagType {
case "string":
serverCmd.Flags().StringP(spec.name, spec.shorthand, spec.defaultValue.(string), spec.usage)
case "int":
serverCmd.Flags().IntP(spec.name, spec.shorthand, spec.defaultValue.(int), spec.usage)
case "bool":
serverCmd.Flags().BoolP(spec.name, spec.shorthand, spec.defaultValue.(bool), spec.usage)
case "uint16":
serverCmd.Flags().Uint16P(spec.name, spec.shorthand, spec.defaultValue.(uint16), spec.usage)
default:
panic(fmt.Sprintf("unknown flag type: %s", spec.flagType))
}
if err := viper.BindPFlag(spec.name, serverCmd.Flags().Lookup(spec.name)); err != nil {
panic(fmt.Sprintf("failed to bind flag %s: %v", spec.name, err))
}
}

viper.SetEnvPrefix("AGENTAPI")
viper.AutomaticEnv()
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))

return serverCmd
}
200 changes: 200 additions & 0 deletions cmd/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@ package server

import (
"fmt"
"os"
"strings"
"testing"

"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -83,3 +88,198 @@ func TestParseAgentType(t *testing.T) {
require.Error(t, err)
})
}

// Test helper to isolate viper config between tests
func isolateViper(t *testing.T) {
// Save current state
oldConfig := viper.AllSettings()

// Reset viper
viper.Reset()

// Clear AGENTAPI_ env vars
var agentapiEnvs []string
for _, env := range os.Environ() {
if strings.HasPrefix(env, "AGENTAPI_") {
parts := strings.SplitN(env, "=", 2)
agentapiEnvs = append(agentapiEnvs, parts[0])
os.Unsetenv(parts[0])
}
}

t.Cleanup(func() {
// Restore state
viper.Reset()
for key, value := range oldConfig {
viper.Set(key, value)
}

// Restore env vars
for _, key := range agentapiEnvs {
if val := os.Getenv(key); val != "" {
os.Setenv(key, val)
}
}
})
}

// Test configuration values via ServerCmd execution
func TestServerCmd_AllArgs_Defaults(t *testing.T) {
tests := []struct {
name string
flag string
expected any
getter func() any
}{
{"type default", FlagType, "", func() any { return viper.GetString(FlagType) }},
{"port default", FlagPort, 3284, func() any { return viper.GetInt(FlagPort) }},
{"print-openapi default", FlagPrintOpenAPI, false, func() any { return viper.GetBool(FlagPrintOpenAPI) }},
{"chat-base-path default", FlagChatBasePath, "/chat", func() any { return viper.GetString(FlagChatBasePath) }},
{"term-width default", FlagTermWidth, uint16(80), func() any { return viper.GetUint16(FlagTermWidth) }},
{"term-height default", FlagTermHeight, uint16(1000), func() any { return viper.GetUint16(FlagTermHeight) }},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isolateViper(t)
serverCmd := CreateServerCmd()
cmd := &cobra.Command{}
cmd.AddCommand(serverCmd)

// Execute with no args to get defaults
serverCmd.SetArgs([]string{"--help"}) // Use help to avoid actual execution
serverCmd.Execute()

assert.Equal(t, tt.expected, tt.getter())
})
}
}

func TestServerCmd_AllEnvVars(t *testing.T) {
tests := []struct {
name string
envVar string
envValue string
expected any
getter func() any
}{
{"AGENTAPI_TYPE", "AGENTAPI_TYPE", "claude", "claude", func() any { return viper.GetString(FlagType) }},
{"AGENTAPI_PORT", "AGENTAPI_PORT", "8080", 8080, func() any { return viper.GetInt(FlagPort) }},
{"AGENTAPI_PRINT_OPENAPI", "AGENTAPI_PRINT_OPENAPI", "true", true, func() any { return viper.GetBool(FlagPrintOpenAPI) }},
{"AGENTAPI_CHAT_BASE_PATH", "AGENTAPI_CHAT_BASE_PATH", "/api", "/api", func() any { return viper.GetString(FlagChatBasePath) }},
{"AGENTAPI_TERM_WIDTH", "AGENTAPI_TERM_WIDTH", "120", uint16(120), func() any { return viper.GetUint16(FlagTermWidth) }},
{"AGENTAPI_TERM_HEIGHT", "AGENTAPI_TERM_HEIGHT", "500", uint16(500), func() any { return viper.GetUint16(FlagTermHeight) }},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isolateViper(t)
os.Setenv(tt.envVar, tt.envValue)
defer os.Unsetenv(tt.envVar)

serverCmd := CreateServerCmd()
cmd := &cobra.Command{}
cmd.AddCommand(serverCmd)

serverCmd.SetArgs([]string{"--help"})
serverCmd.Execute()

assert.Equal(t, tt.expected, tt.getter())
})
}
}

func TestServerCmd_ArgsPrecedenceOverEnv(t *testing.T) {
tests := []struct {
name string
envVar string
envValue string
args []string
expected any
getter func() any
}{
{
"type: CLI overrides env",
"AGENTAPI_TYPE", "goose",
[]string{"--type", "claude"},
"claude",
func() any { return viper.GetString(FlagType) },
},
{
"port: CLI overrides env",
"AGENTAPI_PORT", "8080",
[]string{"--port", "9090"},
9090,
func() any { return viper.GetInt(FlagPort) },
},
{
"print-openapi: CLI overrides env",
"AGENTAPI_PRINT_OPENAPI", "false",
[]string{"--print-openapi"},
true,
func() any { return viper.GetBool(FlagPrintOpenAPI) },
},
{
"chat-base-path: CLI overrides env",
"AGENTAPI_CHAT_BASE_PATH", "/env-path",
[]string{"--chat-base-path", "/cli-path"},
"/cli-path",
func() any { return viper.GetString(FlagChatBasePath) },
},
{
"term-width: CLI overrides env",
"AGENTAPI_TERM_WIDTH", "100",
[]string{"--term-width", "150"},
uint16(150),
func() any { return viper.GetUint16(FlagTermWidth) },
},
{
"term-height: CLI overrides env",
"AGENTAPI_TERM_HEIGHT", "500",
[]string{"--term-height", "600"},
uint16(600),
func() any { return viper.GetUint16(FlagTermHeight) },
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isolateViper(t)
os.Setenv(tt.envVar, tt.envValue)
defer os.Unsetenv(tt.envVar)

// Mock execution to test arg parsing without running server
args := append(tt.args, "--help")
serverCmd := CreateServerCmd()
serverCmd.SetArgs(args)
serverCmd.Execute()

assert.Equal(t, tt.expected, tt.getter())
})
}
}

func TestMixed_ConfigurationScenarios(t *testing.T) {
t.Run("some env, some cli, some defaults", func(t *testing.T) {
isolateViper(t)

// Set some env vars
os.Setenv("AGENTAPI_TYPE", "goose")
os.Setenv("AGENTAPI_TERM_WIDTH", "120")
defer os.Unsetenv("AGENTAPI_TYPE")
defer os.Unsetenv("AGENTAPI_TERM_WIDTH")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: t.Setenv


// Set some CLI args
serverCmd := CreateServerCmd()
serverCmd.SetArgs([]string{"--port", "9999", "--print-openapi", "--help"})
serverCmd.Execute()

// Verify mixed configuration
assert.Equal(t, "goose", viper.GetString(FlagType)) // from env
assert.Equal(t, 9999, viper.GetInt(FlagPort)) // from CLI
assert.Equal(t, true, viper.GetBool(FlagPrintOpenAPI)) // from CLI
assert.Equal(t, "/chat", viper.GetString(FlagChatBasePath)) // default
assert.Equal(t, uint16(120), viper.GetUint16(FlagTermWidth)) // from env
assert.Equal(t, uint16(1000), viper.GetUint16(FlagTermHeight)) // default
})
}
Loading