diff --git a/cmd/server/server.go b/cmd/server/server.go index 56b07ea..3afe050 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -112,6 +112,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er ChatBasePath: viper.GetString(FlagChatBasePath), AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), + InitialPrompt: viper.GetString(FlagInitialPrompt), }) if err != nil { return xerrors.Errorf("failed to create server: %w", err) @@ -174,6 +175,7 @@ const ( FlagAllowedHosts = "allowed-hosts" FlagAllowedOrigins = "allowed-origins" FlagExit = "exit" + FlagInitialPrompt = "initial-prompt" ) func CreateServerCmd() *cobra.Command { @@ -211,6 +213,7 @@ func CreateServerCmd() *cobra.Command { {FlagAllowedHosts, "a", []string{"localhost", "127.0.0.1", "[::1]"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"}, // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, + {FlagInitialPrompt, "I", "", "Initial prompt for the agent (recommended only if the agent doesn't support initial prompt in interaction mode)", "string"}, } for _, spec := range flagSpecs { diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 08d92c4..11994f0 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -95,6 +95,7 @@ type ServerConfig struct { ChatBasePath string AllowedHosts []string AllowedOrigins []string + InitialPrompt string } // Validate allowed hosts don't contain whitespace, commas, schemes, or ports. @@ -230,7 +231,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { SnapshotInterval: snapshotInterval, ScreenStabilityLength: 2 * time.Second, FormatMessage: formatMessage, - }) + }, config.InitialPrompt) emitter := NewEventEmitter(1024) s := &Server{ router: router, @@ -306,7 +307,19 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { s.conversation.StartSnapshotLoop(ctx) go func() { for { - s.emitter.UpdateStatusAndEmitChanges(s.conversation.Status()) + currentStatus := s.conversation.Status() + + // Send initial prompt when agent becomes stable for the first time + if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable { + if err := s.conversation.SendMessage(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { + s.logger.Error("Failed to send initial prompt", "error", err) + } else { + s.conversation.InitialPromptSent = true + currentStatus = st.ConversationStatusChanging + s.logger.Info("Initial prompt sent successfully") + } + } + s.emitter.UpdateStatusAndEmitChanges(currentStatus) s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages()) s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen()) time.Sleep(snapshotInterval) diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 7777e04..4617e8e 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -74,6 +74,10 @@ type Conversation struct { messages []ConversationMessage screenBeforeLastUserMessage string lock sync.Mutex + // InitialPrompt is the initial prompt passed to the agent + InitialPrompt string + // InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents + InitialPromptSent bool } type ConversationStatus string @@ -94,7 +98,7 @@ func getStableSnapshotsThreshold(cfg ConversationConfig) int { return threshold + 1 } -func NewConversation(ctx context.Context, cfg ConversationConfig) *Conversation { +func NewConversation(ctx context.Context, cfg ConversationConfig, initialPrompt string) *Conversation { threshold := getStableSnapshotsThreshold(cfg) c := &Conversation{ cfg: cfg, @@ -107,6 +111,8 @@ func NewConversation(ctx context.Context, cfg ConversationConfig) *Conversation Time: cfg.GetTime(), }, }, + InitialPrompt: initialPrompt, + InitialPromptSent: len(initialPrompt) == 0, } return c } diff --git a/lib/screentracker/conversation_test.go b/lib/screentracker/conversation_test.go index 53c77fd..92fe5ac 100644 --- a/lib/screentracker/conversation_test.go +++ b/lib/screentracker/conversation_test.go @@ -42,7 +42,7 @@ func statusTest(t *testing.T, params statusTestParams) { if params.cfg.GetTime == nil { params.cfg.GetTime = func() time.Time { return time.Now() } } - c := st.NewConversation(ctx, params.cfg) + c := st.NewConversation(ctx, params.cfg, "") assert.Equal(t, st.ConversationStatusInitializing, c.Status()) for i, step := range params.steps { @@ -147,7 +147,7 @@ func TestMessages(t *testing.T) { for _, opt := range opts { opt(&cfg) } - return st.NewConversation(context.Background(), cfg) + return st.NewConversation(context.Background(), cfg, "") } t.Run("messages are copied", func(t *testing.T) {