Skip to content

Commit 97a5f3e

Browse files
authored
Merge pull request #239 from docker/implement-exit-behavior
Add context cancellation support for Ctrl+C during model response
2 parents a9e4461 + 8b98ec5 commit 97a5f3e

File tree

2 files changed

+89
-8
lines changed

2 files changed

+89
-8
lines changed

cmd/cli/commands/run.go

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@ package commands
22

33
import (
44
"bufio"
5+
"context"
56
"errors"
67
"fmt"
78
"io"
89
"os"
10+
"os/signal"
911
"strings"
12+
"syscall"
1013

1114
"github.com/charmbracelet/glamour"
1215
"github.com/docker/model-runner/cmd/cli/commands/completion"
@@ -201,8 +204,36 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.
201204
if sb.Len() > 0 && !multiline {
202205
userInput := sb.String()
203206

204-
if err := chatWithMarkdown(cmd, desktopClient, backend, model, userInput, apiKey); err != nil {
205-
cmd.PrintErr(handleClientError(err, "Failed to generate a response"))
207+
// Create a cancellable context for the chat request
208+
// This allows us to cancel the request if the user presses Ctrl+C during response generation
209+
chatCtx, cancelChat := context.WithCancel(cmd.Context())
210+
211+
// Set up signal handler to cancel the context on Ctrl+C
212+
sigChan := make(chan os.Signal, 1)
213+
signal.Notify(sigChan, syscall.SIGINT)
214+
go func() {
215+
select {
216+
case <-sigChan:
217+
cancelChat()
218+
case <-chatCtx.Done():
219+
// Context cancelled, exit goroutine
220+
}
221+
}()
222+
223+
err := chatWithMarkdownContext(chatCtx, cmd, desktopClient, backend, model, userInput, apiKey)
224+
225+
// Clean up signal handler
226+
signal.Stop(sigChan)
227+
// Do not close sigChan to avoid race condition
228+
cancelChat()
229+
230+
if err != nil {
231+
// Check if the error is due to context cancellation (Ctrl+C during response)
232+
if errors.Is(err, context.Canceled) {
233+
cmd.Println()
234+
} else {
235+
cmd.PrintErr(handleClientError(err, "Failed to generate a response"))
236+
}
206237
sb.Reset()
207238
continue
208239
}
@@ -233,8 +264,36 @@ func generateInteractiveBasic(cmd *cobra.Command, desktopClient *desktop.Client,
233264
continue
234265
}
235266

236-
if err := chatWithMarkdown(cmd, desktopClient, backend, model, userInput, apiKey); err != nil {
237-
cmd.PrintErr(handleClientError(err, "Failed to generate a response"))
267+
// Create a cancellable context for the chat request
268+
// This allows us to cancel the request if the user presses Ctrl+C during response generation
269+
chatCtx, cancelChat := context.WithCancel(cmd.Context())
270+
271+
// Set up signal handler to cancel the context on Ctrl+C
272+
sigChan := make(chan os.Signal, 1)
273+
signal.Notify(sigChan, syscall.SIGINT)
274+
go func() {
275+
select {
276+
case <-sigChan:
277+
cancelChat()
278+
case <-chatCtx.Done():
279+
// Context cancelled, exit goroutine
280+
// Context cancelled, exit goroutine
281+
}
282+
}()
283+
284+
err = chatWithMarkdownContext(chatCtx, cmd, desktopClient, backend, model, userInput, apiKey)
285+
286+
cancelChat()
287+
signal.Stop(sigChan)
288+
cancelChat()
289+
290+
if err != nil {
291+
// Check if the error is due to context cancellation (Ctrl+C during response)
292+
if errors.Is(err, context.Canceled) {
293+
fmt.Println("\nUse Ctrl + d or /bye to exit.")
294+
} else {
295+
cmd.PrintErr(handleClientError(err, "Failed to generate a response"))
296+
}
238297
continue
239298
}
240299

@@ -425,21 +484,26 @@ func renderMarkdown(content string) (string, error) {
425484

426485
// chatWithMarkdown performs chat and streams the response with selective markdown rendering.
427486
func chatWithMarkdown(cmd *cobra.Command, client *desktop.Client, backend, model, prompt, apiKey string) error {
487+
return chatWithMarkdownContext(cmd.Context(), cmd, client, backend, model, prompt, apiKey)
488+
}
489+
490+
// chatWithMarkdownContext performs chat with context support and streams the response with selective markdown rendering.
491+
func chatWithMarkdownContext(ctx context.Context, cmd *cobra.Command, client *desktop.Client, backend, model, prompt, apiKey string) error {
428492
colorMode, _ := cmd.Flags().GetString("color")
429493
useMarkdown := shouldUseMarkdown(colorMode)
430494
debug, _ := cmd.Flags().GetBool("debug")
431495

432496
if !useMarkdown {
433497
// Simple case: just stream as plain text
434-
return client.Chat(backend, model, prompt, apiKey, func(content string) {
498+
return client.ChatWithContext(ctx, backend, model, prompt, apiKey, func(content string) {
435499
cmd.Print(content)
436500
}, false)
437501
}
438502

439503
// For markdown: use streaming buffer to render code blocks as they complete
440504
markdownBuffer := NewStreamingMarkdownBuffer()
441505

442-
err := client.Chat(backend, model, prompt, apiKey, func(content string) {
506+
err := client.ChatWithContext(ctx, backend, model, prompt, apiKey, func(content string) {
443507
// Use the streaming markdown buffer to intelligently render content
444508
rendered, err := markdownBuffer.AddContent(content, true)
445509
if err != nil {

cmd/cli/desktop/desktop.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,11 @@ func (c *Client) fullModelID(id string) (string, error) {
366366

367367
// Chat performs a chat request and streams the response content with selective markdown rendering.
368368
func (c *Client) Chat(backend, model, prompt, apiKey string, outputFunc func(string), shouldUseMarkdown bool) error {
369+
return c.ChatWithContext(context.Background(), backend, model, prompt, apiKey, outputFunc, shouldUseMarkdown)
370+
}
371+
372+
// ChatWithContext performs a chat request with context support for cancellation and streams the response content with selective markdown rendering.
373+
func (c *Client) ChatWithContext(ctx context.Context, backend, model, prompt, apiKey string, outputFunc func(string), shouldUseMarkdown bool) error {
369374
model = normalizeHuggingFaceModelName(model)
370375
if !strings.Contains(strings.Trim(model, "/"), "/") {
371376
// Do an extra API call to check if the model parameter isn't a model ID.
@@ -397,7 +402,8 @@ func (c *Client) Chat(backend, model, prompt, apiKey string, outputFunc func(str
397402
completionsPath = inference.InferencePrefix + "/v1/chat/completions"
398403
}
399404

400-
resp, err := c.doRequestWithAuth(
405+
resp, err := c.doRequestWithAuthContext(
406+
ctx,
401407
http.MethodPost,
402408
completionsPath,
403409
bytes.NewReader(jsonData),
@@ -432,6 +438,13 @@ func (c *Client) Chat(backend, model, prompt, apiKey string, outputFunc func(str
432438

433439
scanner := bufio.NewScanner(resp.Body)
434440
for scanner.Scan() {
441+
// Check if context was cancelled
442+
select {
443+
case <-ctx.Done():
444+
return ctx.Err()
445+
default:
446+
}
447+
435448
line := scanner.Text()
436449
if line == "" {
437450
continue
@@ -755,7 +768,11 @@ func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response,
755768

756769
// doRequestWithAuth is a helper function that performs HTTP requests with optional authentication
757770
func (c *Client) doRequestWithAuth(method, path string, body io.Reader, backend, apiKey string) (*http.Response, error) {
758-
req, err := http.NewRequest(method, c.modelRunner.URL(path), body)
771+
return c.doRequestWithAuthContext(context.Background(), method, path, body, backend, apiKey)
772+
}
773+
774+
func (c *Client) doRequestWithAuthContext(ctx context.Context, method, path string, body io.Reader, backend, apiKey string) (*http.Response, error) {
775+
req, err := http.NewRequestWithContext(ctx, method, c.modelRunner.URL(path), body)
759776
if err != nil {
760777
return nil, fmt.Errorf("error creating request: %w", err)
761778
}

0 commit comments

Comments
 (0)