Skip to content

Commit 9aad1e8

Browse files
MCP Prompts functionality
1 parent 400310f commit 9aad1e8

File tree

8 files changed

+408
-91
lines changed

8 files changed

+408
-91
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
### Added
1111

1212
- Add refresh title feature to refresh chat title from LLM
13+
- Implement MCP Prompts functionality
1314

1415
### Fixed
1516

cmd/server/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ func main() {
9393
mux.HandleFunc("/", m.HandleHome)
9494
mux.HandleFunc("/chats", m.HandleChats)
9595
mux.HandleFunc("/refresh-title", m.HandleRefreshTitle)
96+
mux.HandleFunc("/use-prompt", m.HandleUsePrompt)
9697
mux.HandleFunc("/sse/messages", m.HandleSSE)
9798
mux.HandleFunc("/sse/chats", m.HandleSSE)
9899

internal/handlers/chat.go

Lines changed: 198 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -50,35 +50,33 @@ func callToolError(err error) json.RawMessage {
5050
}
5151

5252
// HandleChats processes chat interactions through HTTP POST requests,
53-
// managing both new chat creation and message handling. It accepts user messages through form data,
54-
// creates appropriate chat contexts, and initiates asynchronous processing for AI responses and chat title generation.
53+
// managing both new chat creation and message handling. It supports two input methods:
54+
// 1. Regular messages via the "message" form field
55+
// 2. Predefined prompts via "prompt_name" and "prompt_args" form fields
5556
//
56-
// The handler expects a "message" form field and an optional "chat_id" field.
57-
// If no chat_id is provided, it creates a new chat session. The handler streams AI responses through
58-
// Server-Sent Events (SSE) and updates the UI accordingly through template rendering.
57+
// The handler expects an optional "chat_id" field. If no chat_id is provided,
58+
// it creates a new chat session. For new chats, it asynchronously generates a title
59+
// based on the first message or prompt.
60+
//
61+
// The function handles different rendering strategies based on whether it's a new chat
62+
// (complete chatbox template) or an existing chat (individual message templates). For
63+
// all chats, it adds messages to the database and initiates asynchronous AI response
64+
// generation that will be streamed via Server-Sent Events (SSE).
5965
//
6066
// The function returns appropriate HTTP error responses for invalid methods, missing required fields,
61-
// or internal processing errors. For successful requests, it renders either a complete chatbox template
62-
// for new chats or individual message templates for existing chats.
67+
// or internal processing errors. For successful requests, it renders the appropriate templates
68+
// with messages marked with correct streaming states.
6369
func (m Main) HandleChats(w http.ResponseWriter, r *http.Request) {
6470
if r.Method != http.MethodPost {
6571
m.logger.Error("Method not allowed", slog.String("method", r.Method))
6672
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
6773
return
6874
}
6975

70-
msg := r.FormValue("message")
71-
if msg == "" {
72-
m.logger.Error("Message is required")
73-
http.Error(w, "Message is required", http.StatusBadRequest)
74-
return
75-
}
76-
7776
var err error
78-
7977
chatID := r.FormValue("chat_id")
80-
// We track if this is a new chat to determine the appropriate template rendering strategy
8178
isNewChat := false
79+
8280
if chatID == "" {
8381
chatID, err = m.newChat()
8482
if err != nil {
@@ -95,25 +93,47 @@ func (m Main) HandleChats(w http.ResponseWriter, r *http.Request) {
9593
}
9694
}
9795

98-
// We create two messages: user's input and a placeholder for AI response
99-
um := models.Message{
100-
ID: uuid.New().String(),
101-
Role: models.RoleUser,
102-
Contents: []models.Content{
103-
{
104-
Type: models.ContentTypeText,
105-
Text: msg,
106-
},
107-
},
108-
Timestamp: time.Now(),
96+
var userMessages []models.Message
97+
var addedMessageIDs []string
98+
var firstMessageForTitle string
99+
100+
// Process input based on type (prompt or regular message)
101+
promptName := r.FormValue("prompt_name")
102+
if promptName != "" {
103+
// Handle prompt-based input
104+
promptArgs := r.FormValue("prompt_args")
105+
userMessages, firstMessageForTitle, err = m.processPromptInput(r.Context(), promptName, promptArgs)
106+
if err != nil {
107+
m.logger.Error("Failed to process prompt",
108+
slog.String("promptName", promptName),
109+
slog.String(errLoggerKey, err.Error()))
110+
http.Error(w, err.Error(), http.StatusInternalServerError)
111+
return
112+
}
113+
} else {
114+
// Handle regular message input
115+
msg := r.FormValue("message")
116+
if msg == "" {
117+
m.logger.Error("Message is required")
118+
http.Error(w, "Message is required", http.StatusBadRequest)
119+
return
120+
}
121+
122+
firstMessageForTitle = msg
123+
userMessages = []models.Message{m.processUserMessage(msg)}
109124
}
110-
userMsgID, err := m.store.AddMessage(r.Context(), chatID, um)
111-
if err != nil {
112-
m.logger.Error("Failed to add user message",
113-
slog.String("message", fmt.Sprintf("%+v", um)),
114-
slog.String(errLoggerKey, err.Error()))
115-
http.Error(w, err.Error(), http.StatusInternalServerError)
116-
return
125+
126+
// Add all user messages to the chat
127+
for _, msg := range userMessages {
128+
msgID, err := m.store.AddMessage(r.Context(), chatID, msg)
129+
if err != nil {
130+
m.logger.Error("Failed to add message",
131+
slog.String("message", fmt.Sprintf("%+v", msg)),
132+
slog.String(errLoggerKey, err.Error()))
133+
http.Error(w, err.Error(), http.StatusInternalServerError)
134+
return
135+
}
136+
addedMessageIDs = append(addedMessageIDs, msgID)
117137
}
118138

119139
// Initialize empty AI message to be streamed later
@@ -144,80 +164,172 @@ func (m Main) HandleChats(w http.ResponseWriter, r *http.Request) {
144164
go m.chat(chatID, messages)
145165

146166
if isNewChat {
147-
go m.generateChatTitle(chatID, msg)
167+
go m.generateChatTitle(chatID, firstMessageForTitle)
168+
m.renderNewChatResponse(w, chatID, messages, aiMsgID)
169+
return
170+
}
148171

149-
// For new chats, we prepare all messages with appropriate streaming states
150-
msgs := make([]message, len(messages))
151-
for i := range messages {
152-
// Mark only the AI message as "loading", others as "ended"
153-
streamingState := "ended"
154-
if messages[i].ID == aiMsgID {
155-
streamingState = "loading"
156-
}
157-
content, err := models.RenderContents(messages[i].Contents)
158-
if err != nil {
159-
m.logger.Error("Failed to render contents",
160-
slog.String("message", fmt.Sprintf("%+v", messages[i])),
161-
slog.String(errLoggerKey, err.Error()))
162-
http.Error(w, err.Error(), http.StatusInternalServerError)
163-
return
164-
}
165-
msgs[i] = message{
166-
ID: messages[i].ID,
167-
Role: string(messages[i].Role),
168-
Content: content,
169-
Timestamp: messages[i].Timestamp,
170-
StreamingState: streamingState,
171-
}
172+
// For existing chats, render each message separately
173+
m.renderExistingChatResponse(w, messages, addedMessageIDs, am, aiMsgID)
174+
}
175+
176+
// processPromptInput handles prompt-based inputs, extracting arguments and retrieving
177+
// prompt messages from the MCP client.
178+
func (m Main) processPromptInput(ctx context.Context, promptName, promptArgs string) ([]models.Message, string, error) {
179+
var args map[string]string
180+
if err := json.Unmarshal([]byte(promptArgs), &args); err != nil {
181+
return nil, "", fmt.Errorf("invalid prompt arguments: %w", err)
182+
}
183+
184+
// Get the prompt data directly from the server
185+
clientIdx, ok := m.promptsMap[promptName]
186+
if !ok {
187+
return nil, "", fmt.Errorf("prompt not found: %s", promptName)
188+
}
189+
190+
promptResult, err := m.mcpClients[clientIdx].GetPrompt(ctx, mcp.GetPromptParams{
191+
Name: promptName,
192+
Arguments: args,
193+
})
194+
if err != nil {
195+
return nil, "", fmt.Errorf("failed to get prompt: %w", err)
196+
}
197+
198+
// Convert prompt messages to our internal model format
199+
messages := make([]models.Message, 0, len(promptResult.Messages))
200+
firstMessageText := ""
201+
202+
for i, promptMsg := range promptResult.Messages {
203+
// For now, ignore non-text content
204+
if promptMsg.Content.Type != mcp.ContentTypeText {
205+
continue
172206
}
207+
content := promptMsg.Content.Text
208+
209+
// Save the first message for title generation
210+
if i == 0 {
211+
firstMessageText = content
212+
}
213+
214+
messages = append(messages, models.Message{
215+
ID: uuid.New().String(),
216+
Role: models.Role(promptMsg.Role),
217+
Contents: []models.Content{
218+
{
219+
Type: models.ContentTypeText,
220+
Text: content,
221+
},
222+
},
223+
Timestamp: time.Now(),
224+
})
225+
}
226+
227+
return messages, firstMessageText, nil
228+
}
229+
230+
// processUserMessage handles standard user message inputs.
231+
func (m Main) processUserMessage(message string) models.Message {
232+
return models.Message{
233+
ID: uuid.New().String(),
234+
Role: models.RoleUser,
235+
Contents: []models.Content{
236+
{
237+
Type: models.ContentTypeText,
238+
Text: message,
239+
},
240+
},
241+
Timestamp: time.Now(),
242+
}
243+
}
173244

174-
data := homePageData{
175-
CurrentChatID: chatID,
176-
Messages: msgs,
245+
// renderNewChatResponse renders the complete chatbox for new chats.
246+
func (m Main) renderNewChatResponse(w http.ResponseWriter, chatID string, messages []models.Message, aiMsgID string) {
247+
msgs := make([]message, len(messages))
248+
for i := range messages {
249+
// Mark only the AI message as "loading", others as "ended"
250+
streamingState := "ended"
251+
if messages[i].ID == aiMsgID {
252+
streamingState = "loading"
177253
}
178-
err = m.templates.ExecuteTemplate(w, "chatbox", data)
254+
content, err := models.RenderContents(messages[i].Contents)
179255
if err != nil {
256+
m.logger.Error("Failed to render contents",
257+
slog.String("message", fmt.Sprintf("%+v", messages[i])),
258+
slog.String(errLoggerKey, err.Error()))
180259
http.Error(w, err.Error(), http.StatusInternalServerError)
260+
return
261+
}
262+
msgs[i] = message{
263+
ID: messages[i].ID,
264+
Role: string(messages[i].Role),
265+
Content: content,
266+
Timestamp: messages[i].Timestamp,
267+
StreamingState: streamingState,
181268
}
182-
return
183269
}
184270

185-
userContent, err := models.RenderContents(um.Contents)
186-
if err != nil {
187-
m.logger.Error("Failed to render contents",
188-
slog.String("message", fmt.Sprintf("%+v", um)),
189-
slog.String(errLoggerKey, err.Error()))
190-
http.Error(w, err.Error(), http.StatusInternalServerError)
191-
return
271+
data := homePageData{
272+
CurrentChatID: chatID,
273+
Messages: msgs,
192274
}
193-
err = m.templates.ExecuteTemplate(w, "user_message", message{
194-
ID: userMsgID,
195-
Role: string(um.Role),
196-
Content: userContent,
197-
Timestamp: um.Timestamp,
198-
StreamingState: "ended",
199-
})
200-
if err != nil {
275+
if err := m.templates.ExecuteTemplate(w, "chatbox", data); err != nil {
201276
http.Error(w, err.Error(), http.StatusInternalServerError)
202-
return
277+
}
278+
}
279+
280+
// renderExistingChatResponse renders each message individually for existing chats.
281+
func (m Main) renderExistingChatResponse(w http.ResponseWriter, messages []models.Message, addedMessageIDs []string,
282+
aiMessage models.Message, aiMsgID string,
283+
) {
284+
for _, msgID := range addedMessageIDs {
285+
for i := range messages {
286+
if messages[i].ID == msgID {
287+
content, err := models.RenderContents(messages[i].Contents)
288+
if err != nil {
289+
m.logger.Error("Failed to render contents",
290+
slog.String("message", fmt.Sprintf("%+v", messages[i])),
291+
slog.String(errLoggerKey, err.Error()))
292+
http.Error(w, err.Error(), http.StatusInternalServerError)
293+
return
294+
}
295+
296+
templateName := "user_message"
297+
if messages[i].Role == models.RoleAssistant {
298+
templateName = "ai_message"
299+
}
300+
301+
if err := m.templates.ExecuteTemplate(w, templateName, message{
302+
ID: msgID,
303+
Role: string(messages[i].Role),
304+
Content: content,
305+
Timestamp: messages[i].Timestamp,
306+
StreamingState: "ended",
307+
}); err != nil {
308+
http.Error(w, err.Error(), http.StatusInternalServerError)
309+
return
310+
}
311+
break
312+
}
313+
}
203314
}
204315

205-
aiContent, err := models.RenderContents(am.Contents)
316+
// Render AI response message (always the last one added)
317+
aiContent, err := models.RenderContents(aiMessage.Contents)
206318
if err != nil {
207319
m.logger.Error("Failed to render contents",
208-
slog.String("message", fmt.Sprintf("%+v", am)),
320+
slog.String("message", fmt.Sprintf("%+v", aiMessage)),
209321
slog.String(errLoggerKey, err.Error()))
210322
http.Error(w, err.Error(), http.StatusInternalServerError)
211323
return
212324
}
213-
err = m.templates.ExecuteTemplate(w, "ai_message", message{
325+
326+
if err := m.templates.ExecuteTemplate(w, "ai_message", message{
214327
ID: aiMsgID,
215-
Role: string(am.Role),
328+
Role: string(aiMessage.Role),
216329
Content: aiContent,
217-
Timestamp: am.Timestamp,
330+
Timestamp: aiMessage.Timestamp,
218331
StreamingState: "loading",
219-
})
220-
if err != nil {
332+
}); err != nil {
221333
http.Error(w, err.Error(), http.StatusInternalServerError)
222334
}
223335
}

0 commit comments

Comments
 (0)