@@ -10,10 +10,10 @@ import (
1010 "os/exec"
1111 "path/filepath"
1212 "runtime"
13+ "strconv"
1314
1415 "github.com/docker/model-runner/pkg/diskusage"
1516 "github.com/docker/model-runner/pkg/inference"
16- "github.com/docker/model-runner/pkg/inference/config"
1717 "github.com/docker/model-runner/pkg/inference/models"
1818 "github.com/docker/model-runner/pkg/logging"
1919)
@@ -39,8 +39,6 @@ type llamaCpp struct {
 	updatedServerStoragePath string
 	// status is the state that the llama.cpp backend is in.
 	status string
-	// config is the configuration for the llama.cpp backend.
-	config config.BackendConfig
 }
 
 // New creates a new llama.cpp-based backend.
@@ -50,20 +48,13 @@ func New(
 	serverLog logging.Logger,
 	vendoredServerStoragePath string,
 	updatedServerStoragePath string,
-	conf config.BackendConfig,
 ) (inference.Backend, error) {
-	// If no config is provided, use the default configuration
-	if conf == nil {
-		conf = NewDefaultLlamaCppConfig()
-	}
-
 	return &llamaCpp{
 		log:                       log,
 		modelManager:              modelManager,
 		serverLog:                 serverLog,
 		vendoredServerStoragePath: vendoredServerStoragePath,
 		updatedServerStoragePath:  updatedServerStoragePath,
-		config:                    conf,
 	}, nil
 }
 
@@ -124,6 +115,11 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
 		return fmt.Errorf("failed to get model path: %w", err)
 	}
 
+	modelDesc, err := l.modelManager.GetModel(model)
+	if err != nil {
+		return fmt.Errorf("failed to get model: %w", err)
+	}
+
 	if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
 		l.log.Warnf("failed to remove socket file %s: %w\n", socket, err)
 		l.log.Warnln("llama.cpp may not be able to start")
@@ -133,13 +129,32 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
 	if l.updatedLlamaCpp {
 		binPath = l.updatedServerStoragePath
 	}
-
-	args := l.config.GetArgs(modelPath, socket, mode)
-	l.log.Infof("llamaCppArgs: %v", args)
+	llamaCppArgs := []string{"--model", modelPath, "--jinja", "--host", socket}
+	if mode == inference.BackendModeEmbedding {
+		llamaCppArgs = append(llamaCppArgs, "--embeddings")
+	}
+	if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" {
+		// Using a thread count equal to core count results in bad performance, and there seems to be little to no gain
+		// in going beyond core_count/2.
+		// TODO(p1-0tr): dig into why the defaults don't work well on windows/arm64
+		nThreads := max(2, runtime.NumCPU()/2)
+		llamaCppArgs = append(llamaCppArgs, "--threads", strconv.Itoa(nThreads))
+
+		modelConfig, err := modelDesc.Config()
+		if err != nil {
+			return fmt.Errorf("failed to get model config: %w", err)
+		}
+		// The Adreno OpenCL implementation currently only supports Q4_0
+		if modelConfig.Quantization == "Q4_0" {
+			llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
+		}
+	} else {
+		llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
+	}
 	llamaCppProcess := exec.CommandContext(
 		ctx,
 		filepath.Join(binPath, "com.docker.llama-server"),
-		args...,
+		llamaCppArgs...,
 	)
 	llamaCppProcess.Cancel = func() error {
 		if runtime.GOOS == "windows" {
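
A minimal, hypothetical sketch of the flag construction introduced above, pulled out into a pure helper so the windows/arm64 thread heuristic and the Q4_0/Adreno branch can be exercised in isolation. buildLlamaCppArgs, the string stand-in for inference.BackendModeEmbedding, and the example paths are assumptions for illustration, not the repository's API.

// args_sketch.go - illustrative only; not part of the change above.
package main

import (
	"fmt"
	"runtime"
	"strconv"
)

// embeddingMode stands in for inference.BackendModeEmbedding.
const embeddingMode = "embedding"

// buildLlamaCppArgs mirrors the branching in Run: base flags, an optional
// --embeddings flag, a reduced thread count on windows/arm64, and GPU layer
// offload (-ngl 100) everywhere except non-Q4_0 models on windows/arm64.
func buildLlamaCppArgs(modelPath, socket, mode, goos, goarch, quantization string, numCPU int) []string {
	args := []string{"--model", modelPath, "--jinja", "--host", socket}
	if mode == embeddingMode {
		args = append(args, "--embeddings")
	}
	if goos == "windows" && goarch == "arm64" {
		// Half the core count, but never fewer than two worker threads
		// (max is the Go 1.21+ built-in).
		nThreads := max(2, numCPU/2)
		args = append(args, "--threads", strconv.Itoa(nThreads))
		// The Adreno OpenCL backend currently only handles Q4_0, so only
		// offload layers to the GPU for Q4_0 models.
		if quantization == "Q4_0" {
			args = append(args, "-ngl", "100")
		}
	} else {
		args = append(args, "-ngl", "100")
	}
	return args
}

func main() {
	// Flags for a hypothetical 12-core windows/arm64 machine running a Q4_0 model.
	fmt.Println(buildLlamaCppArgs("/models/llama.gguf", "/tmp/llama.sock", "completion",
		"windows", "arm64", "Q4_0", 12))

	// Flags for the current platform, in embedding mode.
	fmt.Println(buildLlamaCppArgs("/models/llama.gguf", "/tmp/llama.sock", embeddingMode,
		runtime.GOOS, runtime.GOARCH, "Q4_0", runtime.NumCPU()))
}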