Skip to content

Commit 087acf8

Browse files
committed
feat: accept initial prompt
1 parent 517d0de commit 087acf8

File tree

2 files changed

+41
-21
lines changed

2 files changed

+41
-21
lines changed

cmd/server/server.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
112112
ChatBasePath: viper.GetString(FlagChatBasePath),
113113
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
114114
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
115+
InitialPrompt: viper.GetString(FlagInitialPrompt),
115116
})
116117
if err != nil {
117118
return xerrors.Errorf("failed to create server: %w", err)
@@ -174,6 +175,7 @@ const (
174175
FlagAllowedHosts = "allowed-hosts"
175176
FlagAllowedOrigins = "allowed-origins"
176177
FlagExit = "exit"
178+
FlagInitialPrompt = "initial-prompt"
177179
)
178180

179181
func CreateServerCmd() *cobra.Command {
@@ -211,6 +213,7 @@ func CreateServerCmd() *cobra.Command {
211213
{FlagAllowedHosts, "a", []string{"localhost", "127.0.0.1", "[::1]"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
212214
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
213215
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
216+
{FlagInitialPrompt, "I", "", "Initial prompt for the agent (recommended only if the agent doesn't support initial prompt in interaction mode)", "string"},
214217
}
215218

216219
for _, spec := range flagSpecs {

lib/httpapi/server.go

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,19 @@ import (
2929

3030
// Server represents the HTTP server
3131
type Server struct {
32-
router chi.Router
33-
api huma.API
34-
port int
35-
srv *http.Server
36-
mu sync.RWMutex
37-
logger *slog.Logger
38-
conversation *st.Conversation
39-
agentio *termexec.Process
40-
agentType mf.AgentType
41-
emitter *EventEmitter
42-
chatBasePath string
32+
router chi.Router
33+
api huma.API
34+
port int
35+
srv *http.Server
36+
mu sync.RWMutex
37+
logger *slog.Logger
38+
conversation *st.Conversation
39+
agentio *termexec.Process
40+
agentType mf.AgentType
41+
emitter *EventEmitter
42+
chatBasePath string
43+
initialPrompt string
44+
initialPromptSent bool
4345
}
4446

4547
func (s *Server) NormalizeSchema(schema any) any {
@@ -95,6 +97,7 @@ type ServerConfig struct {
9597
ChatBasePath string
9698
AllowedHosts []string
9799
AllowedOrigins []string
100+
InitialPrompt string
98101
}
99102

100103
// Validate allowed hosts don't contain whitespace, commas, schemes, or ports.
@@ -233,15 +236,17 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
233236
})
234237
emitter := NewEventEmitter(1024)
235238
s := &Server{
236-
router: router,
237-
api: api,
238-
port: config.Port,
239-
conversation: conversation,
240-
logger: logger,
241-
agentio: config.Process,
242-
agentType: config.AgentType,
243-
emitter: emitter,
244-
chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"),
239+
router: router,
240+
api: api,
241+
port: config.Port,
242+
conversation: conversation,
243+
logger: logger,
244+
agentio: config.Process,
245+
agentType: config.AgentType,
246+
emitter: emitter,
247+
chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"),
248+
initialPrompt: config.InitialPrompt,
249+
initialPromptSent: len(config.InitialPrompt) == 0,
245250
}
246251

247252
// Register API routes
@@ -306,7 +311,19 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) {
306311
s.conversation.StartSnapshotLoop(ctx)
307312
go func() {
308313
for {
309-
s.emitter.UpdateStatusAndEmitChanges(s.conversation.Status())
314+
currentStatus := s.conversation.Status()
315+
316+
// Send initial prompt when agent becomes stable for the first time
317+
if !s.initialPromptSent && convertStatus(currentStatus) == AgentStatusStable {
318+
if err := s.conversation.SendMessage(FormatMessage(s.agentType, s.initialPrompt)...); err != nil {
319+
s.logger.Error("Failed to send initial prompt", "error", err)
320+
} else {
321+
s.initialPromptSent = true
322+
currentStatus = st.ConversationStatusChanging
323+
s.logger.Info("Initial prompt sent successfully")
324+
}
325+
}
326+
s.emitter.UpdateStatusAndEmitChanges(currentStatus)
310327
s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages())
311328
s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen())
312329
time.Sleep(snapshotInterval)

0 commit comments

Comments
 (0)