Skip to content

Commit c5162c8

Browse files
authored
feat: accept initial prompt (#112)
1 parent 517d0de commit c5162c8

File tree

4 files changed

+27
-5
lines changed

4 files changed

+27
-5
lines changed

cmd/server/server.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
112112
ChatBasePath: viper.GetString(FlagChatBasePath),
113113
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
114114
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
115+
InitialPrompt: viper.GetString(FlagInitialPrompt),
115116
})
116117
if err != nil {
117118
return xerrors.Errorf("failed to create server: %w", err)
@@ -174,6 +175,7 @@ const (
174175
FlagAllowedHosts = "allowed-hosts"
175176
FlagAllowedOrigins = "allowed-origins"
176177
FlagExit = "exit"
178+
FlagInitialPrompt = "initial-prompt"
177179
)
178180

179181
func CreateServerCmd() *cobra.Command {
@@ -211,6 +213,7 @@ func CreateServerCmd() *cobra.Command {
211213
{FlagAllowedHosts, "a", []string{"localhost", "127.0.0.1", "[::1]"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
212214
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
213215
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
216+
{FlagInitialPrompt, "I", "", "Initial prompt for the agent (recommended only if the agent doesn't support initial prompt in interaction mode)", "string"},
214217
}
215218

216219
for _, spec := range flagSpecs {

lib/httpapi/server.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ type ServerConfig struct {
9595
ChatBasePath string
9696
AllowedHosts []string
9797
AllowedOrigins []string
98+
InitialPrompt string
9899
}
99100

100101
// Validate allowed hosts don't contain whitespace, commas, schemes, or ports.
@@ -230,7 +231,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
230231
SnapshotInterval: snapshotInterval,
231232
ScreenStabilityLength: 2 * time.Second,
232233
FormatMessage: formatMessage,
233-
})
234+
}, config.InitialPrompt)
234235
emitter := NewEventEmitter(1024)
235236
s := &Server{
236237
router: router,
@@ -306,7 +307,19 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) {
306307
s.conversation.StartSnapshotLoop(ctx)
307308
go func() {
308309
for {
309-
s.emitter.UpdateStatusAndEmitChanges(s.conversation.Status())
310+
currentStatus := s.conversation.Status()
311+
312+
// Send initial prompt when agent becomes stable for the first time
313+
if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable {
314+
if err := s.conversation.SendMessage(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil {
315+
s.logger.Error("Failed to send initial prompt", "error", err)
316+
} else {
317+
s.conversation.InitialPromptSent = true
318+
currentStatus = st.ConversationStatusChanging
319+
s.logger.Info("Initial prompt sent successfully")
320+
}
321+
}
322+
s.emitter.UpdateStatusAndEmitChanges(currentStatus)
310323
s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages())
311324
s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen())
312325
time.Sleep(snapshotInterval)

lib/screentracker/conversation.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ type Conversation struct {
7474
messages []ConversationMessage
7575
screenBeforeLastUserMessage string
7676
lock sync.Mutex
77+
// InitialPrompt is the initial prompt passed to the agent
78+
InitialPrompt string
79+
// InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents
80+
InitialPromptSent bool
7781
}
7882

7983
type ConversationStatus string
@@ -94,7 +98,7 @@ func getStableSnapshotsThreshold(cfg ConversationConfig) int {
9498
return threshold + 1
9599
}
96100

97-
func NewConversation(ctx context.Context, cfg ConversationConfig) *Conversation {
101+
func NewConversation(ctx context.Context, cfg ConversationConfig, initialPrompt string) *Conversation {
98102
threshold := getStableSnapshotsThreshold(cfg)
99103
c := &Conversation{
100104
cfg: cfg,
@@ -107,6 +111,8 @@ func NewConversation(ctx context.Context, cfg ConversationConfig) *Conversation
107111
Time: cfg.GetTime(),
108112
},
109113
},
114+
InitialPrompt: initialPrompt,
115+
InitialPromptSent: len(initialPrompt) == 0,
110116
}
111117
return c
112118
}

lib/screentracker/conversation_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func statusTest(t *testing.T, params statusTestParams) {
4242
if params.cfg.GetTime == nil {
4343
params.cfg.GetTime = func() time.Time { return time.Now() }
4444
}
45-
c := st.NewConversation(ctx, params.cfg)
45+
c := st.NewConversation(ctx, params.cfg, "")
4646
assert.Equal(t, st.ConversationStatusInitializing, c.Status())
4747

4848
for i, step := range params.steps {
@@ -147,7 +147,7 @@ func TestMessages(t *testing.T) {
147147
for _, opt := range opts {
148148
opt(&cfg)
149149
}
150-
return st.NewConversation(context.Background(), cfg)
150+
return st.NewConversation(context.Background(), cfg, "")
151151
}
152152

153153
t.Run("messages are copied", func(t *testing.T) {

0 commit comments

Comments
 (0)