Skip to content
This repository was archived by the owner on Oct 6, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
381 changes: 300 additions & 81 deletions commands/run.go

Large diffs are not rendered by default.

176 changes: 132 additions & 44 deletions commands/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package commands

import (
"bufio"
"errors"
"io"
"strings"
"testing"

"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestReadMultilineInput(t *testing.T) {
Expand All @@ -15,40 +18,34 @@ func TestReadMultilineInput(t *testing.T) {
expected string
wantErr bool
}{
{
name: "single line input",
input: "hello world",
expected: "hello world",
wantErr: false,
},
{
name: "single line with triple quotes",
input: `"""hello world"""`,
expected: `"""hello world"""`,
expected: `hello world`,
wantErr: false,
},
{
name: "multiline input with double quotes",
input: `"""tell
me
a
joke"""`,
expected: `"""tell
me
a
joke"""`,
me
a
joke"""`,
expected: `tell
me
a
joke`,
wantErr: false,
},
{
name: "multiline input with single quotes",
input: `'''tell
me
a
joke'''`,
expected: `'''tell
me
a
joke'''`,
me
a
joke'''`,
expected: `tell
me
a
joke`,
wantErr: false,
},
{
Expand All @@ -61,27 +58,124 @@ joke'''`,
name: "multiline with empty lines",
input: `"""first line

third line"""`,
expected: `"""first line
third line"""`,
expected: `first line

third line"""`,
third line`,
wantErr: false,
},
{
name: "multiline with spaces and closing quotes on new line",
input: `"""first line
second line
third line
"""`,
expected: `first line
second line
third line`, // this will intentionally trim the last newline
wantErr: false,
},
{
name: "multiline with closing quotes and trailing spaces",
input: `"""first line
second line
third line """`,
expected: `first line
second line
third line `,
wantErr: false,
},
{
name: "single quotes with spaces",
input: `'''foo bar'''`,
expected: `foo bar`,
wantErr: false,
},
{
name: "triple quotes only",
input: `""""""`,
expected: "",
wantErr: false,
},
{
name: "single quotes only",
input: `''''''`,
expected: "",
wantErr: false,
},
{
name: "closing quotes in middle of line",
input: `"""foo"""bar"""`,
expected: `foo"""bar`,
wantErr: false,
},
{
name: "no closing quotes",
input: `"""foo
bar
baz`,
expected: "",
wantErr: true,
},
{
name: "invalid prefix",
input: `"foo bar"`,
expected: "",
wantErr: true,
},
{
name: "prefix but no content",
input: `"""`,
expected: "",
wantErr: true,
},
{
name: "prefix and newline only",
input: `"""
"""`,
expected: "",
wantErr: false,
},
{
name: "multiline with only whitespace",
input: `"""

"""`,
expected: `

`,
wantErr: false,
},
}

// Strings can be read either as the first line followed by the rest of the text,
// or as a single block of text. Make sure we test both scenarios.

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a mock command for testing
cmd := &cobra.Command{}
r := bufio.NewReader(strings.NewReader(tt.input))
result, err := readMultilineString(t.Context(), r, "")

// Create a scanner from the test input
scanner := bufio.NewScanner(strings.NewReader(tt.input))
if (err != nil) != tt.wantErr {
t.Errorf("readMultilineInput() error = %v, wantErr %v", err, tt.wantErr)
return
}

// Capture output to avoid printing during tests
var output strings.Builder
cmd.SetOut(&output)
if result != tt.expected {
t.Errorf("readMultilineInput() = %q, want %q", result, tt.expected)
}
})

result, err := readMultilineInput(cmd, scanner)
t.Run(tt.name+"_chunked", func(t *testing.T) {
r := bufio.NewReader(strings.NewReader(tt.input))
firstLine, err := r.ReadString('\n') // Simulate reading the first line
if errors.Is(err, io.EOF) {
// Some test cases are single line, EOF is ok here
firstLine = tt.input
} else {
require.NoError(t, err)
}
result, err := readMultilineString(t.Context(), r, firstLine)

if (err != nil) != tt.wantErr {
t.Errorf("readMultilineInput() error = %v, wantErr %v", err, tt.wantErr)
Expand All @@ -98,18 +192,12 @@ third line"""`,
func TestReadMultilineInputUnclosed(t *testing.T) {
// Test unclosed multiline input (should return error)
input := `"""unclosed multiline`
cmd := &cobra.Command{}
var output strings.Builder
cmd.SetOut(&output)

scanner := bufio.NewScanner(strings.NewReader(input))

_, err := readMultilineInput(cmd, scanner)
_, err := readMultilineString(t.Context(), strings.NewReader(input), "")
if err == nil {
t.Error("readMultilineInput() should return error for unclosed multiline input")
t.Fatal("readMultilineInput() should return an error for unclosed multiline input")
}

if !strings.Contains(err.Error(), "unclosed multiline input") {
t.Errorf("readMultilineInput() error should mention unclosed multiline input, got: %v", err)
}
assert.Contains(t, err.Error(), "unclosed multiline input", "error should mention unclosed multiline input")
// Error should also be io.EOF
assert.True(t, errors.Is(err, io.EOF), "error should be io.EOF")
}
18 changes: 11 additions & 7 deletions desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,9 @@ const (
chatPrinterReasoning
)

func (c *Client) Chat(backend, model, prompt, apiKey string) error {
// Chat sends a chat message to the model, prints the response to the standard output
// and then returns it as a slice of the received chunks.
func (c *Client) Chat(backend, model, prompt, apiKey string) ([]string, error) {
model = normalizeHuggingFaceModelName(model)
if !strings.Contains(strings.Trim(model, "/"), "/") {
// Do an extra API call to check if the model parameter isn't a model ID.
Expand All @@ -393,7 +395,7 @@ func (c *Client) Chat(backend, model, prompt, apiKey string) error {

jsonData, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("error marshaling request: %w", err)
return nil, fmt.Errorf("error marshaling request: %w", err)
}

var completionsPath string
Expand All @@ -411,18 +413,19 @@ func (c *Client) Chat(backend, model, prompt, apiKey string) error {
apiKey,
)
if err != nil {
return c.handleQueryError(err, completionsPath)
return nil, c.handleQueryError(err, completionsPath)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("error response: status=%d body=%s", resp.StatusCode, body)
return nil, fmt.Errorf("error response: status=%d body=%s", resp.StatusCode, body)
}

printerState := chatPrinterNone
reasoningFmt := color.New(color.FgWhite).Add(color.Italic)
scanner := bufio.NewScanner(resp.Body)
var chunks []string
for scanner.Scan() {
line := scanner.Text()
if line == "" {
Expand All @@ -441,7 +444,7 @@ func (c *Client) Chat(backend, model, prompt, apiKey string) error {

var streamResp OpenAIChatResponse
if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
return fmt.Errorf("error parsing stream response: %w", err)
return nil, fmt.Errorf("error parsing stream response: %w", err)
}

if len(streamResp.Choices) > 0 {
Expand All @@ -463,15 +466,16 @@ func (c *Client) Chat(backend, model, prompt, apiKey string) error {
}
printerState = chatPrinterContent
fmt.Print(chunk)
chunks = append(chunks, chunk)
}
}
}

if err := scanner.Err(); err != nil {
return fmt.Errorf("error reading response stream: %w", err)
return nil, fmt.Errorf("error reading response stream: %w", err)
}

return nil
return chunks, nil
}

func (c *Client) Remove(models []string, force bool) (string, error) {
Expand Down
3 changes: 2 additions & 1 deletion desktop/desktop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ func TestChatHuggingFaceModel(t *testing.T) {
Body: io.NopCloser(bytes.NewBufferString("data: {\"choices\":[{\"delta\":{\"content\":\"Hello there!\"}}]}\n")),
}, nil)

err := client.Chat("", modelName, prompt, "")
resp, err := client.Chat("", modelName, prompt, "")
assert.NoError(t, err)
assert.Equal(t, []string{"Hello there!"}, resp)
}

func TestInspectHuggingFaceModel(t *testing.T) {
Expand Down
21 changes: 21 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ go 1.24
toolchain go1.24.4

require (
github.com/charmbracelet/bubbles v0.21.0
github.com/charmbracelet/bubbletea v1.3.6
github.com/containerd/errdefs v1.0.0
github.com/docker/cli v28.3.0+incompatible
github.com/docker/cli-docs-tool v0.10.0
Expand All @@ -24,14 +26,23 @@ require (
github.com/stretchr/testify v1.10.0
go.opentelemetry.io/otel v1.37.0
go.uber.org/mock v0.5.0
golang.design/x/clipboard v0.7.1
golang.org/x/sync v0.15.0
golang.org/x/term v0.32.0
)

require (
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/StackExchange/wmi v1.2.1 // indirect
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/lipgloss v1.1.0 // indirect
github.com/charmbracelet/x/ansi v0.9.3 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/containerd/containerd/v2 v2.1.3 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
Expand All @@ -46,6 +57,7 @@ require (
github.com/docker/docker-credential-helpers v0.9.3 // indirect
github.com/elastic/go-sysinfo v1.15.3 // indirect
github.com/elastic/go-windows v1.0.2 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/fvbommel/sortorder v1.1.0 // indirect
Expand All @@ -63,7 +75,9 @@ require (
github.com/jaypipes/pcidb v1.0.1 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/mattn/go-shellwords v1.0.12 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
Expand All @@ -75,6 +89,9 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
Expand All @@ -89,6 +106,7 @@ require (
github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d // indirect
github.com/theupdateframework/notary v0.7.1-0.20210315103452-bf96a202a09a // indirect
github.com/vbatts/tar-split v0.12.1 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.34.0 // indirect
Expand All @@ -101,6 +119,9 @@ require (
go.opentelemetry.io/proto/otlp v1.5.0 // indirect
golang.org/x/crypto v0.39.0 // indirect
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect
golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476 // indirect
golang.org/x/image v0.28.0 // indirect
golang.org/x/mobile v0.0.0-20250606033058-a2a15c67f36f // indirect
golang.org/x/mod v0.25.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/sys v0.33.0 // indirect
Expand Down
Loading
Loading