Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
ChatBasePath: viper.GetString(FlagChatBasePath),
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
InitialPrompt: viper.GetString(FlagInitialPrompt),
})
if err != nil {
return xerrors.Errorf("failed to create server: %w", err)
Expand Down Expand Up @@ -174,6 +175,7 @@ const (
FlagAllowedHosts = "allowed-hosts"
FlagAllowedOrigins = "allowed-origins"
FlagExit = "exit"
FlagInitialPrompt = "initial-prompt"
)

func CreateServerCmd() *cobra.Command {
Expand Down Expand Up @@ -211,6 +213,7 @@ func CreateServerCmd() *cobra.Command {
{FlagAllowedHosts, "a", []string{"localhost", "127.0.0.1", "[::1]"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
{FlagInitialPrompt, "I", "", "Initial prompt for the agent (recommended only if the agent doesn't support initial prompt in interaction mode)", "string"},
}

for _, spec := range flagSpecs {
Expand Down
17 changes: 15 additions & 2 deletions lib/httpapi/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ type ServerConfig struct {
ChatBasePath string
AllowedHosts []string
AllowedOrigins []string
InitialPrompt string
}

// Validate allowed hosts don't contain whitespace, commas, schemes, or ports.
Expand Down Expand Up @@ -230,7 +231,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
SnapshotInterval: snapshotInterval,
ScreenStabilityLength: 2 * time.Second,
FormatMessage: formatMessage,
})
}, config.InitialPrompt)
emitter := NewEventEmitter(1024)
s := &Server{
router: router,
Expand Down Expand Up @@ -306,7 +307,19 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) {
s.conversation.StartSnapshotLoop(ctx)
go func() {
for {
s.emitter.UpdateStatusAndEmitChanges(s.conversation.Status())
currentStatus := s.conversation.Status()

// Send initial prompt when agent becomes stable for the first time
if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable {
if err := s.conversation.SendMessage(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil {
s.logger.Error("Failed to send initial prompt", "error", err)
} else {
s.conversation.InitialPromptSent = true
currentStatus = st.ConversationStatusChanging
s.logger.Info("Initial prompt sent successfully")
}
}
s.emitter.UpdateStatusAndEmitChanges(currentStatus)
s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages())
s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen())
time.Sleep(snapshotInterval)
Expand Down
8 changes: 7 additions & 1 deletion lib/screentracker/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ type Conversation struct {
messages []ConversationMessage
screenBeforeLastUserMessage string
lock sync.Mutex
// InitialPrompt is the initial prompt passed to the agent
InitialPrompt string
// InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents
InitialPromptSent bool
}

type ConversationStatus string
Expand All @@ -94,7 +98,7 @@ func getStableSnapshotsThreshold(cfg ConversationConfig) int {
return threshold + 1
}

func NewConversation(ctx context.Context, cfg ConversationConfig) *Conversation {
func NewConversation(ctx context.Context, cfg ConversationConfig, initialPrompt string) *Conversation {
threshold := getStableSnapshotsThreshold(cfg)
c := &Conversation{
cfg: cfg,
Expand All @@ -107,6 +111,8 @@ func NewConversation(ctx context.Context, cfg ConversationConfig) *Conversation
Time: cfg.GetTime(),
},
},
InitialPrompt: initialPrompt,
InitialPromptSent: len(initialPrompt) == 0,
}
return c
}
Expand Down
4 changes: 2 additions & 2 deletions lib/screentracker/conversation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func statusTest(t *testing.T, params statusTestParams) {
if params.cfg.GetTime == nil {
params.cfg.GetTime = func() time.Time { return time.Now() }
}
c := st.NewConversation(ctx, params.cfg)
c := st.NewConversation(ctx, params.cfg, "")
assert.Equal(t, st.ConversationStatusInitializing, c.Status())

for i, step := range params.steps {
Expand Down Expand Up @@ -147,7 +147,7 @@ func TestMessages(t *testing.T) {
for _, opt := range opts {
opt(&cfg)
}
return st.NewConversation(context.Background(), cfg)
return st.NewConversation(context.Background(), cfg, "")
}

t.Run("messages are copied", func(t *testing.T) {
Expand Down