diff --git a/cmd/cli/README.md b/cmd/cli/README.md index a3f8d628a..0c1ad3a8e 100644 --- a/cmd/cli/README.md +++ b/cmd/cli/README.md @@ -52,7 +52,6 @@ Run `./model --help` to see all commands and options. Or enter chat mode: ```bash ./model run llama.cpp -Interactive chat mode started. Type '/bye' to exit. > """ Tell me a joke. """ diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go index a3e672f7e..38030fabd 100644 --- a/cmd/cli/commands/run.go +++ b/cmd/cli/commands/run.go @@ -11,6 +11,7 @@ import ( "github.com/charmbracelet/glamour" "github.com/docker/model-runner/cmd/cli/commands/completion" "github.com/docker/model-runner/cmd/cli/desktop" + "github.com/docker/model-runner/cmd/cli/readline" "github.com/fatih/color" "github.com/spf13/cobra" "golang.org/x/term" @@ -81,6 +82,167 @@ func readMultilineInput(cmd *cobra.Command, scanner *bufio.Scanner) (string, err return multilineInput.String(), nil } +// generateInteractiveWithReadline provides an enhanced interactive mode with readline support +func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.Client, backend, model, apiKey string) error { + usage := func() { + fmt.Fprintln(os.Stderr, "Available Commands:") + fmt.Fprintln(os.Stderr, " /bye Exit") + fmt.Fprintln(os.Stderr, " /?, /help Help for a command") + fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") + fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(os.Stderr, `Use """ to begin a multi-line message.`) + fmt.Fprintln(os.Stderr, "") + } + + usageShortcuts := func() { + fmt.Fprintln(os.Stderr, "Available keyboard shortcuts:") + fmt.Fprintln(os.Stderr, " Ctrl + a Move to the beginning of the line (Home)") + fmt.Fprintln(os.Stderr, " Ctrl + e Move to the end of the line (End)") + fmt.Fprintln(os.Stderr, " Alt + b Move back (left) one word") + fmt.Fprintln(os.Stderr, " Alt + f Move forward (right) one word") + fmt.Fprintln(os.Stderr, " Ctrl + k Delete the sentence after the cursor") + fmt.Fprintln(os.Stderr, " Ctrl + u Delete the sentence before the cursor") + fmt.Fprintln(os.Stderr, " Ctrl + w Delete the word before the cursor") + fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(os.Stderr, " Ctrl + l Clear the screen") + fmt.Fprintln(os.Stderr, " Ctrl + c Stop the model from responding") + fmt.Fprintln(os.Stderr, " Ctrl + d Exit (/bye)") + fmt.Fprintln(os.Stderr, "") + } + + scanner, err := readline.New(readline.Prompt{ + Prompt: "> ", + AltPrompt: "... ", + Placeholder: "Send a message (/? for help)", + AltPlaceholder: `Use """ to end multi-line input`, + }) + if err != nil { + // Fall back to basic input mode if readline initialization fails + return generateInteractiveBasic(cmd, desktopClient, backend, model, apiKey) + } + + // Disable history if the environment variable is set + if os.Getenv("DOCKER_MODEL_NOHISTORY") != "" { + scanner.HistoryDisable() + } + + fmt.Print(readline.StartBracketedPaste) + defer fmt.Printf(readline.EndBracketedPaste) + + var sb strings.Builder + var multiline bool + + for { + line, err := scanner.Readline() + switch { + case errors.Is(err, io.EOF): + fmt.Println() + return nil + case errors.Is(err, readline.ErrInterrupt): + if line == "" { + fmt.Println("\nUse Ctrl + d or /bye to exit.") + } + + scanner.Prompt.UseAlt = false + sb.Reset() + + continue + case err != nil: + return err + } + + switch { + case multiline: + // check if there's a multiline terminating string + before, ok := strings.CutSuffix(line, `"""`) + sb.WriteString(before) + if !ok { + fmt.Fprintln(&sb) + continue + } + + multiline = false + scanner.Prompt.UseAlt = false + case strings.HasPrefix(line, `"""`): + line := strings.TrimPrefix(line, `"""`) + line, ok := strings.CutSuffix(line, `"""`) + sb.WriteString(line) + if !ok { + // no multiline terminating string; need more input + fmt.Fprintln(&sb) + multiline = true + scanner.Prompt.UseAlt = true + } + case scanner.Pasting: + fmt.Fprintln(&sb, line) + continue + case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"): + args := strings.Fields(line) + if len(args) > 1 { + switch args[1] { + case "shortcut", "shortcuts": + usageShortcuts() + default: + usage() + } + } else { + usage() + } + continue + case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"): + return nil + case strings.HasPrefix(line, "/"): + fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0]) + continue + default: + sb.WriteString(line) + } + + if sb.Len() > 0 && !multiline { + userInput := sb.String() + + if err := chatWithMarkdown(cmd, desktopClient, backend, model, userInput, apiKey); err != nil { + cmd.PrintErr(handleClientError(err, "Failed to generate a response")) + sb.Reset() + continue + } + + cmd.Println() + sb.Reset() + } + } +} + +// generateInteractiveBasic provides a basic interactive mode (fallback) +func generateInteractiveBasic(cmd *cobra.Command, desktopClient *desktop.Client, backend, model, apiKey string) error { + scanner := bufio.NewScanner(os.Stdin) + for { + userInput, err := readMultilineInput(cmd, scanner) + if err != nil { + if err.Error() == "EOF" { + break + } + return fmt.Errorf("Error reading input: %v", err) + } + + if strings.ToLower(strings.TrimSpace(userInput)) == "/bye" { + break + } + + if strings.TrimSpace(userInput) == "" { + continue + } + + if err := chatWithMarkdown(cmd, desktopClient, backend, model, userInput, apiKey); err != nil { + cmd.PrintErr(handleClientError(err, "Failed to generate a response")) + continue + } + + cmd.Println() + } + return nil +} + var ( markdownRenderer *glamour.TermRenderer lastWidth int @@ -389,36 +551,13 @@ func newRunCmd() *cobra.Command { return nil } - scanner := bufio.NewScanner(os.Stdin) - cmd.Println("Interactive chat mode started. Type '/bye' to exit.") - - for { - userInput, err := readMultilineInput(cmd, scanner) - if err != nil { - if err.Error() == "EOF" { - cmd.Println("\nChat session ended.") - break - } - return fmt.Errorf("Error reading input: %v", err) - } - - if strings.ToLower(strings.TrimSpace(userInput)) == "/bye" { - cmd.Println("Chat session ended.") - break - } - - if strings.TrimSpace(userInput) == "" { - continue - } - - if err := chatWithMarkdown(cmd, desktopClient, backend, model, userInput, apiKey); err != nil { - cmd.PrintErr(handleClientError(err, "Failed to generate a response")) - continue - } - - cmd.Println() + // Use enhanced readline-based interactive mode when terminal is available + if term.IsTerminal(int(os.Stdin.Fd())) { + return generateInteractiveWithReadline(cmd, desktopClient, backend, model, apiKey) } - return nil + + // Fall back to basic mode if not a terminal + return generateInteractiveBasic(cmd, desktopClient, backend, model, apiKey) }, ValidArgsFunction: completion.ModelNames(getDesktopClient, 1), } diff --git a/cmd/cli/docs/reference/docker_model_run.yaml b/cmd/cli/docs/reference/docker_model_run.yaml index 44b1340fd..a24424b33 100644 --- a/cmd/cli/docs/reference/docker_model_run.yaml +++ b/cmd/cli/docs/reference/docker_model_run.yaml @@ -72,11 +72,9 @@ examples: |- Output: ```console - Interactive chat mode started. Type '/bye' to exit. > Hi Hi there! It's SmolLM, AI assistant. How can I help you today? > /bye - Chat session ended. ``` deprecated: false hidden: false diff --git a/cmd/cli/docs/reference/model_run.md b/cmd/cli/docs/reference/model_run.md index 6b0c3cc6a..444ffe0a7 100644 --- a/cmd/cli/docs/reference/model_run.md +++ b/cmd/cli/docs/reference/model_run.md @@ -45,9 +45,7 @@ docker model run ai/smollm2 Output: ```console -Interactive chat mode started. Type '/bye' to exit. > Hi Hi there! It's SmolLM, AI assistant. How can I help you today? > /bye -Chat session ended. ``` diff --git a/cmd/cli/go.mod b/cmd/cli/go.mod index 63b603150..966839de8 100644 --- a/cmd/cli/go.mod +++ b/cmd/cli/go.mod @@ -11,9 +11,11 @@ require ( github.com/docker/go-connections v0.5.0 github.com/docker/go-units v0.5.0 github.com/docker/model-runner v0.0.0 + github.com/emirpasic/gods/v2 v2.0.0-alpha github.com/fatih/color v1.18.0 github.com/google/go-containerregistry v0.20.6 github.com/mattn/go-isatty v0.0.20 + github.com/mattn/go-runewidth v0.0.16 github.com/nxadm/tail v1.4.8 github.com/olekukonko/tablewriter v0.0.5 github.com/pkg/errors v0.9.1 @@ -23,6 +25,7 @@ require ( go.opentelemetry.io/otel v1.37.0 go.uber.org/mock v0.5.0 golang.org/x/sync v0.15.0 + golang.org/x/sys v0.35.0 golang.org/x/term v0.32.0 ) @@ -76,7 +79,6 @@ require ( github.com/kolesnikovae/go-winjob v1.0.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mattn/go-shellwords v1.0.12 // indirect github.com/microcosm-cc/bluemonday v1.0.27 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect @@ -120,7 +122,6 @@ require ( golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect golang.org/x/mod v0.25.0 // indirect golang.org/x/net v0.41.0 // indirect - golang.org/x/sys v0.35.0 // indirect golang.org/x/text v0.26.0 // indirect golang.org/x/time v0.9.0 // indirect golang.org/x/tools v0.34.0 // indirect diff --git a/cmd/cli/go.sum b/cmd/cli/go.sum index 71a7638e0..065b11341 100644 --- a/cmd/cli/go.sum +++ b/cmd/cli/go.sum @@ -115,6 +115,8 @@ github.com/elastic/go-sysinfo v1.15.3 h1:W+RnmhKFkqPTCRoFq2VCTmsT4p/fwpo+3gKNQsn github.com/elastic/go-sysinfo v1.15.3/go.mod h1:K/cNrqYTDrSoMh2oDkYEMS2+a72GRxMvNP+GC+vRIlo= github.com/elastic/go-windows v1.0.2 h1:yoLLsAsV5cfg9FLhZ9EXZ2n2sQFKeDYrHenkcivY4vI= github.com/elastic/go-windows v1.0.2/go.mod h1:bGcDpBzXgYSqM0Gx3DM4+UxFj300SZLixie9u9ixLM8= +github.com/emirpasic/gods/v2 v2.0.0-alpha h1:dwFlh8pBg1VMOXWGipNMRt8v96dKAIvBehtCt6OtunU= +github.com/emirpasic/gods/v2 v2.0.0-alpha/go.mod h1:W0y4M2dtBB9U5z3YlghmpuUhiaZT2h6yoeE+C1sCp6A= github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= diff --git a/cmd/cli/readline/buffer.go b/cmd/cli/readline/buffer.go new file mode 100644 index 000000000..52dc70526 --- /dev/null +++ b/cmd/cli/readline/buffer.go @@ -0,0 +1,527 @@ +package readline + +import ( + "fmt" + "os" + + "github.com/emirpasic/gods/v2/lists/arraylist" + "github.com/mattn/go-runewidth" + "golang.org/x/term" +) + +type Buffer struct { + DisplayPos int + Pos int + Buf *arraylist.List[rune] + // LineHasSpace is an arraylist of bools to keep track of whether a line has a space at the end + LineHasSpace *arraylist.List[bool] + Prompt *Prompt + LineWidth int + Width int + Height int +} + +func NewBuffer(prompt *Prompt) (*Buffer, error) { + fd := int(os.Stdout.Fd()) + width, height := 80, 24 + if termWidth, termHeight, err := term.GetSize(fd); err == nil { + width, height = termWidth, termHeight + } + + lwidth := width - len(prompt.prompt()) + + b := &Buffer{ + DisplayPos: 0, + Pos: 0, + Buf: arraylist.New[rune](), + LineHasSpace: arraylist.New[bool](), + Prompt: prompt, + Width: width, + Height: height, + LineWidth: lwidth, + } + + return b, nil +} + +func (b *Buffer) GetLineSpacing(line int) bool { + hasSpace, _ := b.LineHasSpace.Get(line) + return hasSpace +} + +func (b *Buffer) MoveLeft() { + if b.Pos > 0 { + // asserts that we retrieve a rune + if r, ok := b.Buf.Get(b.Pos - 1); ok { + rLength := runewidth.RuneWidth(r) + + if b.DisplayPos%b.LineWidth == 0 { + fmt.Print(CursorUp + CursorBOL + CursorRightN(b.Width)) + if rLength == 2 { + fmt.Print(CursorLeft) + } + + line := b.DisplayPos/b.LineWidth - 1 + hasSpace := b.GetLineSpacing(line) + if hasSpace { + b.DisplayPos -= 1 + fmt.Print(CursorLeft) + } + } else { + fmt.Print(CursorLeftN(rLength)) + } + + b.Pos -= 1 + b.DisplayPos -= rLength + } + } +} + +func (b *Buffer) MoveLeftWord() { + if b.Pos > 0 { + var foundNonspace bool + for { + v, _ := b.Buf.Get(b.Pos - 1) + if v == ' ' { + if foundNonspace { + break + } + } else { + foundNonspace = true + } + b.MoveLeft() + + if b.Pos == 0 { + break + } + } + } +} + +func (b *Buffer) MoveRight() { + if b.Pos < b.Buf.Size() { + if r, ok := b.Buf.Get(b.Pos); ok { + rLength := runewidth.RuneWidth(r) + b.Pos += 1 + hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth) + b.DisplayPos += rLength + + if b.DisplayPos%b.LineWidth == 0 { + fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt()))) + } else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace { + fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt())+rLength)) + b.DisplayPos += 1 + } else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace { + fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt()))) + b.DisplayPos += 1 + } else { + fmt.Print(CursorRightN(rLength)) + } + } + } +} + +func (b *Buffer) MoveRightWord() { + if b.Pos < b.Buf.Size() { + for { + b.MoveRight() + v, _ := b.Buf.Get(b.Pos) + if v == ' ' { + break + } + + if b.Pos == b.Buf.Size() { + break + } + } + } +} + +func (b *Buffer) MoveToStart() { + if b.Pos > 0 { + currLine := b.DisplayPos / b.LineWidth + if currLine > 0 { + for range currLine { + fmt.Print(CursorUp) + } + } + fmt.Print(CursorBOL + CursorRightN(len(b.Prompt.prompt()))) + b.Pos = 0 + b.DisplayPos = 0 + } +} + +func (b *Buffer) MoveToEnd() { + if b.Pos < b.Buf.Size() { + currLine := b.DisplayPos / b.LineWidth + totalLines := b.DisplaySize() / b.LineWidth + if currLine < totalLines { + for range totalLines - currLine { + fmt.Print(CursorDown) + } + remainder := b.DisplaySize() % b.LineWidth + fmt.Print(CursorBOL + CursorRightN(len(b.Prompt.prompt())+remainder)) + } else { + fmt.Print(CursorRightN(b.DisplaySize() - b.DisplayPos)) + } + + b.Pos = b.Buf.Size() + b.DisplayPos = b.DisplaySize() + } +} + +func (b *Buffer) DisplaySize() int { + sum := 0 + for i := range b.Buf.Size() { + if r, ok := b.Buf.Get(i); ok { + sum += runewidth.RuneWidth(r) + } + } + + return sum +} + +func (b *Buffer) Add(r rune) { + if b.Pos == b.Buf.Size() { + b.AddChar(r, false) + } else { + b.AddChar(r, true) + } +} + +func (b *Buffer) AddChar(r rune, insert bool) { + rLength := runewidth.RuneWidth(r) + b.DisplayPos += rLength + + if b.Pos > 0 { + if b.DisplayPos%b.LineWidth == 0 { + fmt.Printf("%c", r) + fmt.Printf("\n%s", b.Prompt.AltPrompt) + + if insert { + b.LineHasSpace.Set(b.DisplayPos/b.LineWidth-1, false) + } else { + b.LineHasSpace.Add(false) + } + + // this case occurs when a double-width rune crosses the line boundary + } else if b.DisplayPos%b.LineWidth < (b.DisplayPos-rLength)%b.LineWidth { + if insert { + fmt.Print(ClearToEOL) + } + fmt.Printf("\n%s", b.Prompt.AltPrompt) + b.DisplayPos += 1 + fmt.Printf("%c", r) + + if insert { + b.LineHasSpace.Set(b.DisplayPos/b.LineWidth-1, true) + } else { + b.LineHasSpace.Add(true) + } + } else { + fmt.Printf("%c", r) + } + } else { + fmt.Printf("%c", r) + } + + if insert { + b.Buf.Insert(b.Pos, r) + } else { + b.Buf.Add(r) + } + + b.Pos += 1 + + if insert { + b.drawRemaining() + } +} + +func (b *Buffer) countRemainingLineWidth(place int) int { + var sum int + counter := -1 + var prevLen int + + for place <= b.LineWidth { + counter += 1 + sum += prevLen + if r, ok := b.Buf.Get(b.Pos + counter); ok { + place += runewidth.RuneWidth(r) + prevLen = len(string(r)) + } else { + break + } + } + + return sum +} + +func (b *Buffer) drawRemaining() { + var place int + remainingText := b.StringN(b.Pos) + if b.Pos > 0 { + place = b.DisplayPos % b.LineWidth + } + fmt.Print(CursorHide) + + // render the rest of the current line + currLineLength := b.countRemainingLineWidth(place) + + currLine := remainingText[:min(currLineLength, len(remainingText))] + currLineSpace := runewidth.StringWidth(currLine) + remLength := runewidth.StringWidth(remainingText) + + if len(currLine) > 0 { + fmt.Print(ClearToEOL + currLine + CursorLeftN(currLineSpace)) + } else { + fmt.Print(ClearToEOL) + } + + if currLineSpace != b.LineWidth-place && currLineSpace != remLength { + b.LineHasSpace.Set(b.DisplayPos/b.LineWidth, true) + } else if currLineSpace != b.LineWidth-place { + b.LineHasSpace.Remove(b.DisplayPos / b.LineWidth) + } else { + b.LineHasSpace.Set(b.DisplayPos/b.LineWidth, false) + } + + if (b.DisplayPos+currLineSpace)%b.LineWidth == 0 && currLine == remainingText { + fmt.Print(CursorRightN(currLineSpace)) + fmt.Printf("\n%s", b.Prompt.AltPrompt) + fmt.Print(CursorUp + CursorBOL + CursorRightN(b.Width-currLineSpace)) + } + + // render the other lines + if remLength > currLineSpace { + remaining := (remainingText[len(currLine):]) + var totalLines int + var displayLength int + var lineLength int = currLineSpace + + for _, c := range remaining { + if displayLength == 0 || (displayLength+runewidth.RuneWidth(c))%b.LineWidth < displayLength%b.LineWidth { + fmt.Printf("\n%s", b.Prompt.AltPrompt) + totalLines += 1 + + if displayLength != 0 { + if lineLength == b.LineWidth { + b.LineHasSpace.Set(b.DisplayPos/b.LineWidth+totalLines-1, false) + } else { + b.LineHasSpace.Set(b.DisplayPos/b.LineWidth+totalLines-1, true) + } + } + + lineLength = 0 + } + + displayLength += runewidth.RuneWidth(c) + lineLength += runewidth.RuneWidth(c) + fmt.Printf("%c", c) + } + fmt.Print(ClearToEOL + CursorUpN(totalLines) + CursorBOL + CursorRightN(b.Width-currLineSpace)) + + hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth) + + if hasSpace && b.DisplayPos%b.LineWidth != b.LineWidth-1 { + fmt.Print(CursorLeft) + } + } + + fmt.Print(CursorShow) +} + +func (b *Buffer) Remove() { + if b.Buf.Size() > 0 && b.Pos > 0 { + if r, ok := b.Buf.Get(b.Pos - 1); ok { + rLength := runewidth.RuneWidth(r) + hasSpace := b.GetLineSpacing(b.DisplayPos/b.LineWidth - 1) + + if b.DisplayPos%b.LineWidth == 0 { + // if the user backspaces over the word boundary, do this magic to clear the line + // and move to the end of the previous line + fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + CursorRightN(b.Width)) + + if b.DisplaySize()%b.LineWidth < (b.DisplaySize()-rLength)%b.LineWidth { + b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1) + } + + if hasSpace { + b.DisplayPos -= 1 + fmt.Print(CursorLeft) + } + + if rLength == 2 { + fmt.Print(CursorLeft + " " + CursorLeftN(2)) + } else { + fmt.Print(" " + CursorLeft) + } + } else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace { + fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + CursorRightN(b.Width)) + + if b.Pos == b.Buf.Size() { + b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1) + } + b.DisplayPos -= 1 + } else { + fmt.Print(CursorLeftN(rLength)) + for range rLength { + fmt.Print(" ") + } + fmt.Print(CursorLeftN(rLength)) + } + + var eraseExtraLine bool + if (b.DisplaySize()-1)%b.LineWidth == 0 || (rLength == 2 && ((b.DisplaySize()-2)%b.LineWidth == 0)) || b.DisplaySize()%b.LineWidth == 0 { + eraseExtraLine = true + } + + b.Pos -= 1 + b.DisplayPos -= rLength + b.Buf.Remove(b.Pos) + + if b.Pos < b.Buf.Size() { + b.drawRemaining() + // this erases a line which is left over when backspacing in the middle of a line and there + // are trailing characters which go over the line width boundary + if eraseExtraLine { + remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth + fmt.Print(CursorDownN(remainingLines+1) + CursorBOL + ClearToEOL) + place := b.DisplayPos % b.LineWidth + fmt.Print(CursorUpN(remainingLines+1) + CursorRightN(place+len(b.Prompt.prompt()))) + } + } + } + } +} + +func (b *Buffer) Delete() { + if b.Buf.Size() > 0 && b.Pos < b.Buf.Size() { + b.Buf.Remove(b.Pos) + b.drawRemaining() + if b.DisplaySize()%b.LineWidth == 0 { + if b.DisplayPos != b.DisplaySize() { + remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth + fmt.Print(CursorDownN(remainingLines) + CursorBOL + ClearToEOL) + place := b.DisplayPos % b.LineWidth + fmt.Print(CursorUpN(remainingLines) + CursorRightN(place+len(b.Prompt.prompt()))) + } + } + } +} + +func (b *Buffer) DeleteBefore() { + if b.Pos > 0 { + for cnt := b.Pos - 1; cnt >= 0; cnt-- { + b.Remove() + } + } +} + +func (b *Buffer) DeleteRemaining() { + if b.DisplaySize() > 0 && b.Pos < b.DisplaySize() { + charsToDel := b.Buf.Size() - b.Pos + for range charsToDel { + b.Delete() + } + } +} + +func (b *Buffer) DeleteWord() { + if b.Buf.Size() > 0 && b.Pos > 0 { + var foundNonspace bool + for { + v, _ := b.Buf.Get(b.Pos - 1) + if v == ' ' { + if !foundNonspace { + b.Remove() + } else { + break + } + } else { + foundNonspace = true + b.Remove() + } + + if b.Pos == 0 { + break + } + } + } +} + +func (b *Buffer) ClearScreen() { + fmt.Print(ClearScreen + CursorReset + b.Prompt.prompt()) + if b.IsEmpty() { + ph := b.Prompt.placeholder() + fmt.Print(ColorGrey + ph + CursorLeftN(len(ph)) + ColorDefault) + } else { + currPos := b.DisplayPos + currIndex := b.Pos + b.Pos = 0 + b.DisplayPos = 0 + b.drawRemaining() + fmt.Print(CursorReset + CursorRightN(len(b.Prompt.prompt()))) + if currPos > 0 { + targetLine := currPos / b.LineWidth + if targetLine > 0 { + for range targetLine { + fmt.Print(CursorDown) + } + } + remainder := currPos % b.LineWidth + if remainder > 0 { + fmt.Print(CursorRightN(remainder)) + } + if currPos%b.LineWidth == 0 { + fmt.Print(CursorBOL + b.Prompt.AltPrompt) + } + } + b.Pos = currIndex + b.DisplayPos = currPos + } +} + +func (b *Buffer) IsEmpty() bool { + return b.Buf.Empty() +} + +func (b *Buffer) Replace(r []rune) { + b.DisplayPos = 0 + b.Pos = 0 + lineNums := b.DisplaySize() / b.LineWidth + + b.Buf.Clear() + + fmt.Print(CursorBOL + ClearToEOL) + + for range lineNums { + fmt.Print(CursorUp + CursorBOL + ClearToEOL) + } + + fmt.Print(CursorBOL + b.Prompt.prompt()) + + for _, c := range r { + b.Add(c) + } +} + +func (b *Buffer) String() string { + return b.StringN(0) +} + +func (b *Buffer) StringN(n int) string { + return b.StringNM(n, 0) +} + +func (b *Buffer) StringNM(n, m int) string { + var s string + if m == 0 { + m = b.Buf.Size() + } + for cnt := n; cnt < m; cnt++ { + c, _ := b.Buf.Get(cnt) + s += string(c) + } + return s +} diff --git a/cmd/cli/readline/errors.go b/cmd/cli/readline/errors.go new file mode 100644 index 000000000..bb3fbd473 --- /dev/null +++ b/cmd/cli/readline/errors.go @@ -0,0 +1,15 @@ +package readline + +import ( + "errors" +) + +var ErrInterrupt = errors.New("Interrupt") + +type InterruptError struct { + Line []rune +} + +func (*InterruptError) Error() string { + return "Interrupted" +} diff --git a/cmd/cli/readline/history.go b/cmd/cli/readline/history.go new file mode 100644 index 000000000..f4784f5e7 --- /dev/null +++ b/cmd/cli/readline/history.go @@ -0,0 +1,151 @@ +package readline + +import ( + "bufio" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/emirpasic/gods/v2/lists/arraylist" +) + +type History struct { + Buf *arraylist.List[string] + Autosave bool + Pos int + Limit int + Filename string + Enabled bool +} + +func NewHistory() (*History, error) { + h := &History{ + Buf: arraylist.New[string](), + Limit: 100, // resizeme + Autosave: true, + Enabled: true, + } + + err := h.Init() + if err != nil { + return nil, err + } + + return h, nil +} + +func (h *History) Init() error { + home, err := os.UserHomeDir() + if err != nil { + return err + } + + path := filepath.Join(home, ".docker", "model", "history") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + + h.Filename = path + + f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0o600) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + defer f.Close() + + r := bufio.NewReader(f) + for { + line, err := r.ReadString('\n') + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return err + } + + line = strings.TrimSpace(line) + if len(line) == 0 { + continue + } + + h.Add(line) + } + + return nil +} + +func (h *History) Add(s string) { + h.Buf.Add(s) + h.Compact() + h.Pos = h.Size() + if h.Autosave { + _ = h.Save() + } +} + +func (h *History) Compact() { + s := h.Buf.Size() + if s > h.Limit { + for range s - h.Limit { + h.Buf.Remove(0) + } + } +} + +func (h *History) Clear() { + h.Buf.Clear() +} + +func (h *History) Prev() (line string) { + if h.Pos > 0 { + h.Pos -= 1 + } + line, _ = h.Buf.Get(h.Pos) + return line +} + +func (h *History) Next() (line string) { + if h.Pos < h.Buf.Size() { + h.Pos += 1 + line, _ = h.Buf.Get(h.Pos) + } + return line +} + +func (h *History) Size() int { + return h.Buf.Size() +} + +func (h *History) Save() error { + if !h.Enabled { + return nil + } + + tmpFile := h.Filename + ".tmp" + + f, err := os.OpenFile(tmpFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC|os.O_APPEND, 0o600) + if err != nil { + return err + } + defer f.Close() + + buf := bufio.NewWriter(f) + for cnt := range h.Size() { + line, _ := h.Buf.Get(cnt) + fmt.Fprintln(buf, line) + } + buf.Flush() + f.Close() + + if err = os.Rename(tmpFile, h.Filename); err != nil { + return err + } + + return nil +} diff --git a/cmd/cli/readline/readline.go b/cmd/cli/readline/readline.go new file mode 100644 index 000000000..9252f3253 --- /dev/null +++ b/cmd/cli/readline/readline.go @@ -0,0 +1,299 @@ +package readline + +import ( + "bufio" + "fmt" + "io" + "os" +) + +type Prompt struct { + Prompt string + AltPrompt string + Placeholder string + AltPlaceholder string + UseAlt bool +} + +func (p *Prompt) prompt() string { + if p.UseAlt { + return p.AltPrompt + } + return p.Prompt +} + +func (p *Prompt) placeholder() string { + if p.UseAlt { + return p.AltPlaceholder + } + return p.Placeholder +} + +type Terminal struct { + outchan chan rune + rawmode bool + termios any +} + +type Instance struct { + Prompt *Prompt + Terminal *Terminal + History *History + Pasting bool +} + +func New(prompt Prompt) (*Instance, error) { + term, err := NewTerminal() + if err != nil { + return nil, err + } + + history, err := NewHistory() + if err != nil { + return nil, err + } + + return &Instance{ + Prompt: &prompt, + Terminal: term, + History: history, + }, nil +} + +func (i *Instance) Readline() (string, error) { + if !i.Terminal.rawmode { + fd := os.Stdin.Fd() + termios, err := SetRawMode(fd) + if err != nil { + return "", err + } + i.Terminal.rawmode = true + i.Terminal.termios = termios + } + + prompt := i.Prompt.prompt() + if i.Pasting { + // force alt prompt when pasting + prompt = i.Prompt.AltPrompt + } + fmt.Print(prompt) + + defer func() { + fd := os.Stdin.Fd() + //nolint:errcheck + UnsetRawMode(fd, i.Terminal.termios) + i.Terminal.rawmode = false + }() + + buf, _ := NewBuffer(i.Prompt) + + var esc bool + var escex bool + var metaDel bool + + var currentLineBuf []rune + + for { + // don't show placeholder when pasting unless we're in multiline mode + showPlaceholder := !i.Pasting || i.Prompt.UseAlt + if buf.IsEmpty() && showPlaceholder { + ph := i.Prompt.placeholder() + fmt.Print(ColorGrey + ph + CursorLeftN(len(ph)) + ColorDefault) + } + + r, err := i.Terminal.Read() + + if buf.IsEmpty() { + fmt.Print(ClearToEOL) + } + + if err != nil { + return "", io.EOF + } + + if escex { + escex = false + + switch r { + case KeyUp: + i.historyPrev(buf, ¤tLineBuf) + case KeyDown: + i.historyNext(buf, ¤tLineBuf) + case KeyLeft: + buf.MoveLeft() + case KeyRight: + buf.MoveRight() + case CharBracketedPaste: + var code string + for range 3 { + r, err = i.Terminal.Read() + if err != nil { + return "", io.EOF + } + + code += string(r) + } + if code == CharBracketedPasteStart { + i.Pasting = true + } else if code == CharBracketedPasteEnd { + i.Pasting = false + } + case KeyDel: + if buf.DisplaySize() > 0 { + buf.Delete() + } + metaDel = true + case MetaStart: + buf.MoveToStart() + case MetaEnd: + buf.MoveToEnd() + default: + // skip any keys we don't know about + continue + } + continue + } else if esc { + esc = false + + switch r { + case 'b': + buf.MoveLeftWord() + case 'f': + buf.MoveRightWord() + case CharBackspace: + buf.DeleteWord() + case CharEscapeEx: + escex = true + } + continue + } + + switch r { + case CharNull: + continue + case CharEsc: + esc = true + case CharInterrupt: + return "", ErrInterrupt + case CharPrev: + i.historyPrev(buf, ¤tLineBuf) + case CharNext: + i.historyNext(buf, ¤tLineBuf) + case CharLineStart: + buf.MoveToStart() + case CharLineEnd: + buf.MoveToEnd() + case CharBackward: + buf.MoveLeft() + case CharForward: + buf.MoveRight() + case CharBackspace, CharCtrlH: + buf.Remove() + case CharTab: + // todo: convert back to real tabs + for range 8 { + buf.Add(' ') + } + case CharDelete: + if buf.DisplaySize() > 0 { + buf.Delete() + } else { + return "", io.EOF + } + case CharKill: + buf.DeleteRemaining() + case CharCtrlU: + buf.DeleteBefore() + case CharCtrlL: + buf.ClearScreen() + case CharCtrlW: + buf.DeleteWord() + case CharCtrlZ: + fd := os.Stdin.Fd() + return handleCharCtrlZ(fd, i.Terminal.termios) + case CharEnter, CharCtrlJ: + output := buf.String() + if output != "" { + i.History.Add(output) + } + buf.MoveToEnd() + fmt.Println() + + return output, nil + default: + if metaDel { + metaDel = false + continue + } + if r >= CharSpace || r == CharEnter || r == CharCtrlJ { + buf.Add(r) + } + } + } +} + +func (i *Instance) HistoryEnable() { + i.History.Enabled = true +} + +func (i *Instance) HistoryDisable() { + i.History.Enabled = false +} + +func (i *Instance) historyPrev(buf *Buffer, currentLineBuf *[]rune) { + if i.History.Pos > 0 { + if i.History.Pos == i.History.Size() { + *currentLineBuf = []rune(buf.String()) + } + buf.Replace([]rune(i.History.Prev())) + } +} + +func (i *Instance) historyNext(buf *Buffer, currentLineBuf *[]rune) { + if i.History.Pos < i.History.Size() { + buf.Replace([]rune(i.History.Next())) + if i.History.Pos == i.History.Size() { + buf.Replace(*currentLineBuf) + } + } +} + +func NewTerminal() (*Terminal, error) { + fd := os.Stdin.Fd() + termios, err := SetRawMode(fd) + if err != nil { + return nil, err + } + + t := &Terminal{ + outchan: make(chan rune), + rawmode: true, + termios: termios, + } + + go t.ioloop() + + return t, nil +} + +func (t *Terminal) ioloop() { + buf := bufio.NewReader(os.Stdin) + + for { + r, _, err := buf.ReadRune() + if err != nil { + close(t.outchan) + break + } + t.outchan <- r + } +} + +func (t *Terminal) Read() (rune, error) { + r, ok := <-t.outchan + if !ok { + return 0, io.EOF + } + + return r, nil +} diff --git a/cmd/cli/readline/readline_unix.go b/cmd/cli/readline/readline_unix.go new file mode 100644 index 000000000..d48b91769 --- /dev/null +++ b/cmd/cli/readline/readline_unix.go @@ -0,0 +1,19 @@ +//go:build !windows + +package readline + +import ( + "syscall" +) + +func handleCharCtrlZ(fd uintptr, termios any) (string, error) { + t := termios.(*Termios) + if err := UnsetRawMode(fd, t); err != nil { + return "", err + } + + _ = syscall.Kill(0, syscall.SIGSTOP) + + // on resume... + return "", nil +} diff --git a/cmd/cli/readline/readline_windows.go b/cmd/cli/readline/readline_windows.go new file mode 100644 index 000000000..a131d0ef7 --- /dev/null +++ b/cmd/cli/readline/readline_windows.go @@ -0,0 +1,6 @@ +package readline + +func handleCharCtrlZ(fd uintptr, state any) (string, error) { + // not supported + return "", nil +} diff --git a/cmd/cli/readline/term.go b/cmd/cli/readline/term.go new file mode 100644 index 000000000..5584cd257 --- /dev/null +++ b/cmd/cli/readline/term.go @@ -0,0 +1,37 @@ +//go:build aix || darwin || dragonfly || freebsd || (linux && !appengine) || netbsd || openbsd || os400 || solaris + +package readline + +import ( + "syscall" +) + +type Termios syscall.Termios + +func SetRawMode(fd uintptr) (*Termios, error) { + termios, err := getTermios(fd) + if err != nil { + return nil, err + } + + newTermios := *termios + newTermios.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON + newTermios.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN + newTermios.Cflag &^= syscall.CSIZE | syscall.PARENB + newTermios.Cflag |= syscall.CS8 + newTermios.Cc[syscall.VMIN] = 1 + newTermios.Cc[syscall.VTIME] = 0 + + return termios, setTermios(fd, &newTermios) +} + +func UnsetRawMode(fd uintptr, termios any) error { + t := termios.(*Termios) + return setTermios(fd, t) +} + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal(fd uintptr) bool { + _, err := getTermios(fd) + return err == nil +} diff --git a/cmd/cli/readline/term_bsd.go b/cmd/cli/readline/term_bsd.go new file mode 100644 index 000000000..80bee6b3e --- /dev/null +++ b/cmd/cli/readline/term_bsd.go @@ -0,0 +1,25 @@ +//go:build darwin || freebsd || netbsd || openbsd + +package readline + +import ( + "syscall" + "unsafe" +) + +func getTermios(fd uintptr) (*Termios, error) { + termios := new(Termios) + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) + if err != 0 { + return nil, err + } + return termios, nil +} + +func setTermios(fd uintptr, termios *Termios) error { + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) + if err != 0 { + return err + } + return nil +} diff --git a/cmd/cli/readline/term_linux.go b/cmd/cli/readline/term_linux.go new file mode 100644 index 000000000..e9e36da99 --- /dev/null +++ b/cmd/cli/readline/term_linux.go @@ -0,0 +1,30 @@ +//go:build linux || solaris + +package readline + +import ( + "syscall" + "unsafe" +) + +const ( + tcgets = 0x5401 + tcsets = 0x5402 +) + +func getTermios(fd uintptr) (*Termios, error) { + termios := new(Termios) + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) + if err != 0 { + return nil, err + } + return termios, nil +} + +func setTermios(fd uintptr, termios *Termios) error { + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) + if err != 0 { + return err + } + return nil +} diff --git a/cmd/cli/readline/term_windows.go b/cmd/cli/readline/term_windows.go new file mode 100644 index 000000000..3b35149b8 --- /dev/null +++ b/cmd/cli/readline/term_windows.go @@ -0,0 +1,38 @@ +package readline + +import ( + "golang.org/x/sys/windows" +) + +type State struct { + mode uint32 +} + +// IsTerminal checks if the given file descriptor is associated with a terminal +func IsTerminal(fd uintptr) bool { + var st uint32 + err := windows.GetConsoleMode(windows.Handle(fd), &st) + return err == nil +} + +func SetRawMode(fd uintptr) (*State, error) { + var st uint32 + if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil { + return nil, err + } + + // this enables raw mode by turning off various flags in the console mode: https://pkg.go.dev/golang.org/x/sys/windows#pkg-constants + raw := st &^ (windows.ENABLE_ECHO_INPUT | windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT) + + // turn on ENABLE_VIRTUAL_TERMINAL_INPUT to enable escape sequences + raw |= windows.ENABLE_VIRTUAL_TERMINAL_INPUT + if err := windows.SetConsoleMode(windows.Handle(fd), raw); err != nil { + return nil, err + } + return &State{st}, nil +} + +func UnsetRawMode(fd uintptr, state any) error { + s := state.(*State) + return windows.SetConsoleMode(windows.Handle(fd), s.mode) +} diff --git a/cmd/cli/readline/types.go b/cmd/cli/readline/types.go new file mode 100644 index 000000000..f4efa8d92 --- /dev/null +++ b/cmd/cli/readline/types.go @@ -0,0 +1,97 @@ +package readline + +import "strconv" + +const ( + CharNull = 0 + CharLineStart = 1 + CharBackward = 2 + CharInterrupt = 3 + CharDelete = 4 + CharLineEnd = 5 + CharForward = 6 + CharBell = 7 + CharCtrlH = 8 + CharTab = 9 + CharCtrlJ = 10 + CharKill = 11 + CharCtrlL = 12 + CharEnter = 13 + CharNext = 14 + CharPrev = 16 + CharBckSearch = 18 + CharFwdSearch = 19 + CharTranspose = 20 + CharCtrlU = 21 + CharCtrlW = 23 + CharCtrlY = 25 + CharCtrlZ = 26 + CharEsc = 27 + CharSpace = 32 + CharEscapeEx = 91 + CharBackspace = 127 +) + +const ( + KeyDel = 51 + KeyUp = 65 + KeyDown = 66 + KeyRight = 67 + KeyLeft = 68 + MetaEnd = 70 + MetaStart = 72 +) + +const ( + Esc = "\x1b" + + CursorSave = Esc + "[s" + CursorRestore = Esc + "[u" + + CursorEOL = Esc + "[E" + CursorBOL = Esc + "[1G" + CursorHide = Esc + "[?25l" + CursorShow = Esc + "[?25h" + + ClearToEOL = Esc + "[K" + ClearLine = Esc + "[2K" + ClearScreen = Esc + "[2J" + CursorReset = Esc + "[0;0f" + + ColorGrey = Esc + "[38;5;245m" + ColorDefault = Esc + "[0m" + + ColorBold = Esc + "[1m" + + StartBracketedPaste = Esc + "[?2004h" + EndBracketedPaste = Esc + "[?2004l" +) + +func CursorUpN(n int) string { + return Esc + "[" + strconv.Itoa(n) + "A" +} + +func CursorDownN(n int) string { + return Esc + "[" + strconv.Itoa(n) + "B" +} + +func CursorRightN(n int) string { + return Esc + "[" + strconv.Itoa(n) + "C" +} + +func CursorLeftN(n int) string { + return Esc + "[" + strconv.Itoa(n) + "D" +} + +var ( + CursorUp = CursorUpN(1) + CursorDown = CursorDownN(1) + CursorRight = CursorRightN(1) + CursorLeft = CursorLeftN(1) +) + +const ( + CharBracketedPaste = 50 + CharBracketedPasteStart = "00~" + CharBracketedPasteEnd = "01~" +)