Commit 442f049

Merge pull request #51 from docker/revert-50-revert-41-configure-llamacpp-args
Revert "Revert "configure backend args""
2 parents 25eefbf + 54f2b98 commit 442f049

File tree

7 files changed: +439 −29 lines changed


Makefile

Lines changed: 10 additions & 0 deletions
@@ -7,6 +7,7 @@ BASE_IMAGE := ubuntu:24.04
 DOCKER_IMAGE := docker/model-runner:latest
 PORT := 8080
 MODELS_PATH := $(shell pwd)/models
+LLAMA_ARGS ?=
 
 # Main targets
 .PHONY: build run clean test docker-build docker-run help
@@ -20,6 +21,7 @@ build:
 
 # Run the application locally
 run: build
+	LLAMA_ARGS="$(LLAMA_ARGS)" \
 	./$(APP_NAME)
 
 # Clean build artifacts
@@ -55,6 +57,7 @@ docker-run: docker-build
 	-e MODEL_RUNNER_PORT=$(PORT) \
 	-e LLAMA_SERVER_PATH=/app/bin \
 	-e MODELS_PATH=/models \
+	-e LLAMA_ARGS="$(LLAMA_ARGS)" \
 	$(DOCKER_IMAGE)
 
 # Show help
@@ -67,3 +70,10 @@ help:
 	@echo " docker-build - Build Docker image"
 	@echo " docker-run - Run in Docker container with TCP port access and mounted model storage"
 	@echo " help - Show this help message"
+	@echo ""
+	@echo "Backend configuration options:"
+	@echo " LLAMA_ARGS - Arguments for llama.cpp (e.g., \"--verbose --jinja -ngl 100 --ctx-size 2048\")"
+	@echo ""
+	@echo "Example usage:"
+	@echo " make run LLAMA_ARGS=\"--verbose --jinja -ngl 100 --ctx-size 2048\""
+	@echo " make docker-run LLAMA_ARGS=\"--verbose --jinja -ngl 100 --threads 4 --ctx-size 2048\""

main.go

Lines changed: 62 additions & 0 deletions
@@ -7,10 +7,12 @@ import (
 	"os"
 	"os/signal"
 	"path/filepath"
+	"strings"
 	"syscall"
 
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
+	"github.com/docker/model-runner/pkg/inference/config"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/inference/scheduling"
 	"github.com/docker/model-runner/pkg/routing"
@@ -50,6 +52,9 @@ func main() {
 
 	log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)
 
+	// Create llama.cpp configuration from environment variables
+	llamaCppConfig := createLlamaCppConfigFromEnv()
+
 	llamaCppBackend, err := llamacpp.New(
 		log,
 		modelManager,
@@ -61,6 +66,7 @@ func main() {
 			_ = os.MkdirAll(d, 0o755)
 			return d
 		}(),
+		llamaCppConfig,
 	)
 	if err != nil {
 		log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
@@ -134,3 +140,59 @@ func main() {
 	}
 	log.Infoln("Docker Model Runner stopped")
 }
+
+// createLlamaCppConfigFromEnv creates a LlamaCppConfig from environment variables
+func createLlamaCppConfigFromEnv() config.BackendConfig {
+	// Check if any configuration environment variables are set
+	argsStr := os.Getenv("LLAMA_ARGS")
+
+	// If no environment variables are set, use default configuration
+	if argsStr == "" {
+		return nil // nil will cause the backend to use its default configuration
+	}
+
+	// Split the string by spaces, respecting quoted arguments
+	args := splitArgs(argsStr)
+
+	// Check for disallowed arguments
+	disallowedArgs := []string{"--model", "--host", "--embeddings", "--mmproj"}
+	for _, arg := range args {
+		for _, disallowed := range disallowedArgs {
+			if arg == disallowed {
+				log.Fatalf("LLAMA_ARGS cannot override the %s argument as it is controlled by the model runner", disallowed)
+			}
+		}
+	}
+
+	log.Infof("Using custom arguments: %v", args)
+	return &llamacpp.Config{
+		Args: args,
+	}
+}
+
+// splitArgs splits a string into arguments, respecting quoted arguments
+func splitArgs(s string) []string {
+	var args []string
+	var currentArg strings.Builder
+	inQuotes := false
+
+	for _, r := range s {
+		switch {
+		case r == '"' || r == '\'':
+			inQuotes = !inQuotes
+		case r == ' ' && !inQuotes:
+			if currentArg.Len() > 0 {
+				args = append(args, currentArg.String())
+				currentArg.Reset()
+			}
+		default:
+			currentArg.WriteRune(r)
+		}
+	}
+
+	if currentArg.Len() > 0 {
+		args = append(args, currentArg.String())
+	}
+
+	return args
+}
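
For illustration, a minimal standalone sketch of the quote-aware splitting that splitArgs performs. The helper below mirrors the logic added above (it is not the repository's function itself), and the sample input is hypothetical — it matches the "quoted args" case in main_test.go:

package main

import (
	"fmt"
	"strings"
)

// split mirrors the quote-aware splitArgs added in main.go: spaces separate
// arguments unless they appear inside single or double quotes, and the quote
// characters themselves are dropped from the resulting arguments.
func split(s string) []string {
	var args []string
	var current strings.Builder
	inQuotes := false
	for _, r := range s {
		switch {
		case r == '"' || r == '\'':
			inQuotes = !inQuotes
		case r == ' ' && !inQuotes:
			if current.Len() > 0 {
				args = append(args, current.String())
				current.Reset()
			}
		default:
			current.WriteRune(r)
		}
	}
	if current.Len() > 0 {
		args = append(args, current.String())
	}
	return args
}

func main() {
	// Hypothetical input, matching the "quoted args" test case.
	fmt.Printf("%q\n", split(`--prompt "Hello, world!" --threads 4`))
	// Prints: ["--prompt" "Hello, world!" "--threads" "4"]
}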

main_test.go

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+package main
+
+import (
+	"os"
+	"testing"
+
+	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
+	"github.com/sirupsen/logrus"
+)
+
+func TestCreateLlamaCppConfigFromEnv(t *testing.T) {
+	tests := []struct {
+		name      string
+		llamaArgs string
+		wantErr   bool
+	}{
+		{
+			name:      "empty args",
+			llamaArgs: "",
+			wantErr:   false,
+		},
+		{
+			name:      "valid args",
+			llamaArgs: "--threads 4 --ctx-size 2048",
+			wantErr:   false,
+		},
+		{
+			name:      "disallowed model arg",
+			llamaArgs: "--model test.gguf",
+			wantErr:   true,
+		},
+		{
+			name:      "disallowed host arg",
+			llamaArgs: "--host localhost:8080",
+			wantErr:   true,
+		},
+		{
+			name:      "disallowed embeddings arg",
+			llamaArgs: "--embeddings",
+			wantErr:   true,
+		},
+		{
+			name:      "disallowed mmproj arg",
+			llamaArgs: "--mmproj test.mmproj",
+			wantErr:   true,
+		},
+		{
+			name:      "multiple disallowed args",
+			llamaArgs: "--model test.gguf --host localhost:8080",
+			wantErr:   true,
+		},
+		{
+			name:      "quoted args",
+			llamaArgs: "--prompt \"Hello, world!\" --threads 4",
+			wantErr:   false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Set up environment
+			if tt.llamaArgs != "" {
+				os.Setenv("LLAMA_ARGS", tt.llamaArgs)
+				defer os.Unsetenv("LLAMA_ARGS")
+			}
+
+			// Create a test logger that captures fatal errors
+			originalLog := log
+			defer func() { log = originalLog }()
+
+			// Create a new logger that will exit with a special exit code
+			testLog := logrus.New()
+			var exitCode int
+			testLog.ExitFunc = func(code int) {
+				exitCode = code
+			}
+			log = testLog
+
+			config := createLlamaCppConfigFromEnv()
+
+			if tt.wantErr {
+				if exitCode != 1 {
+					t.Errorf("Expected exit code 1, got %d", exitCode)
+				}
+			} else {
+				if exitCode != 0 {
+					t.Errorf("Expected exit code 0, got %d", exitCode)
+				}
+				if tt.llamaArgs == "" {
+					if config != nil {
+						t.Error("Expected nil config for empty args")
+					}
+				} else {
+					llamaConfig, ok := config.(*llamacpp.Config)
+					if !ok {
+						t.Errorf("Expected *llamacpp.Config, got %T", config)
+					}
+					if llamaConfig == nil {
+						t.Error("Expected non-nil config")
+					}
+					if len(llamaConfig.Args) == 0 {
+						t.Error("Expected non-empty args")
+					}
+				}
+			}
+		})
+	}
+}

pkg/inference/backends/llamacpp/llamacpp.go

Lines changed: 14 additions & 29 deletions
@@ -10,10 +10,10 @@ import (
 	"os/exec"
 	"path/filepath"
 	"runtime"
-	"strconv"
 
 	"github.com/docker/model-runner/pkg/diskusage"
 	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/config"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/logging"
 )
@@ -39,6 +39,8 @@ type llamaCpp struct {
 	updatedServerStoragePath string
 	// status is the state in which the llama.cpp backend is in.
 	status string
+	// config is the configuration for the llama.cpp backend.
+	config config.BackendConfig
 }
 
 // New creates a new llama.cpp-based backend.
@@ -48,13 +50,20 @@ func New(
 	serverLog logging.Logger,
 	vendoredServerStoragePath string,
 	updatedServerStoragePath string,
+	conf config.BackendConfig,
 ) (inference.Backend, error) {
+	// If no config is provided, use the default configuration
+	if conf == nil {
+		conf = NewDefaultLlamaCppConfig()
+	}
+
 	return &llamaCpp{
 		log:                       log,
 		modelManager:              modelManager,
 		serverLog:                 serverLog,
 		vendoredServerStoragePath: vendoredServerStoragePath,
 		updatedServerStoragePath:  updatedServerStoragePath,
+		config:                    conf,
 	}, nil
 }
 
@@ -115,11 +124,6 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
 		return fmt.Errorf("failed to get model path: %w", err)
 	}
 
-	modelDesc, err := l.modelManager.GetModel(model)
-	if err != nil {
-		return fmt.Errorf("failed to get model: %w", err)
-	}
-
 	if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
 		l.log.Warnf("failed to remove socket file %s: %w\n", socket, err)
 		l.log.Warnln("llama.cpp may not be able to start")
@@ -129,32 +133,13 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
 	if l.updatedLlamaCpp {
 		binPath = l.updatedServerStoragePath
 	}
-	llamaCppArgs := []string{"--model", modelPath, "--jinja", "--host", socket}
-	if mode == inference.BackendModeEmbedding {
-		llamaCppArgs = append(llamaCppArgs, "--embeddings")
-	}
-	if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" {
-		// Using a thread count equal to core count results in bad performance, and there seems to be little to no gain
-		// in going beyond core_count/2.
-		// TODO(p1-0tr): dig into why the defaults don't work well on windows/arm64
-		nThreads := min(2, runtime.NumCPU()/2)
-		llamaCppArgs = append(llamaCppArgs, "--threads", strconv.Itoa(nThreads))
-
-		modelConfig, err := modelDesc.Config()
-		if err != nil {
-			return fmt.Errorf("failed to get model config: %w", err)
-		}
-		// The Adreno OpenCL implementation currently only supports Q4_0
-		if modelConfig.Quantization == "Q4_0" {
-			llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
-		}
-	} else {
-		llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
-	}
+
+	args := l.config.GetArgs(modelPath, socket, mode)
+	l.log.Infof("llamaCppArgs: %v", args)
 	llamaCppProcess := exec.CommandContext(
 		ctx,
 		filepath.Join(binPath, "com.docker.llama-server"),
-		llamaCppArgs...,
+		args...,
 	)
 	llamaCppProcess.Cancel = func() error {
 		if runtime.GOOS == "windows" {
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+package llamacpp
+
+import (
+	"runtime"
+	"strconv"
+
+	"github.com/docker/model-runner/pkg/inference"
+)
+
+// Config is the configuration for the llama.cpp backend.
+type Config struct {
+	// Args are the base arguments that are always included.
+	Args []string
+}
+
+// NewDefaultLlamaCppConfig creates a new LlamaCppConfig with default values.
+func NewDefaultLlamaCppConfig() *Config {
+	args := append([]string{"--jinja", "-ngl", "100"})
+
+	// Special case for Windows ARM64
+	if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" {
+		// Using a thread count equal to core count results in bad performance, and there seems to be little to no gain
+		// in going beyond core_count/2.
+		if !containsArg(args, "--threads") {
+			nThreads := min(2, runtime.NumCPU()/2)
+			args = append(args, "--threads", strconv.Itoa(nThreads))
+		}
+	}
+
+	return &Config{
+		Args: args,
+	}
+}
+
+// GetArgs implements BackendConfig.GetArgs.
+func (c *Config) GetArgs(modelPath, socket string, mode inference.BackendMode) []string {
+	// Start with the arguments from LlamaCppConfig
+	args := append([]string{}, c.Args...)
+
+	// Add model and socket arguments
+	args = append(args, "--model", modelPath, "--host", socket)
+
+	// Add mode-specific arguments
+	if mode == inference.BackendModeEmbedding {
+		args = append(args, "--embeddings")
+	}
+
+	return args
+}
+
+// containsArg checks if the given argument is already in the args slice.
+func containsArg(args []string, arg string) bool {
+	for _, a := range args {
+		if a == arg {
+			return true
+		}
+	}
+	return false
+}
