Skip to content

Commit 527c410

Browse files
authored
Conditionally include --jinja (#201)
* refactor: update LlamaCppConfig to conditionally include --jinja argument * refactor: simplify expected arguments in LlamaCppConfig tests
1 parent 45f9314 commit 527c410

File tree

2 files changed

+32
-10
lines changed

2 files changed

+32
-10
lines changed

pkg/inference/backends/llamacpp/llamacpp_config.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"strconv"
77

88
"github.com/docker/model-runner/pkg/distribution/types"
9-
109
"github.com/docker/model-runner/pkg/inference"
1110
)
1211

@@ -18,7 +17,7 @@ type Config struct {
1817

1918
// NewDefaultLlamaCppConfig creates a new LlamaCppConfig with default values.
2019
func NewDefaultLlamaCppConfig() *Config {
21-
args := append([]string{"--jinja", "-ngl", "999", "--metrics"})
20+
args := append([]string{"-ngl", "999", "--metrics"})
2221

2322
// Special case for ARM64
2423
if runtime.GOARCH == "arm64" {
@@ -69,9 +68,11 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
6968
args = append(args, config.RuntimeFlags...)
7069
}
7170

72-
// Add arguments for Multimodal projector
71+
// Add arguments for Multimodal projector or jinja (they are mutually exclusive)
7372
if path := bundle.MMPROJPath(); path != "" {
7473
args = append(args, "--mmproj", path)
74+
} else {
75+
args = append(args, "--jinja")
7576
}
7677

7778
return args, nil

pkg/inference/backends/llamacpp/llamacpp_config_test.go

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ import (
1414
func TestNewDefaultLlamaCppConfig(t *testing.T) {
1515
config := NewDefaultLlamaCppConfig()
1616

17-
// Test default arguments
18-
if !containsArg(config.Args, "--jinja") {
19-
t.Error("Expected --jinja argument to be present")
17+
// Test that --jinja is NOT in default args (it will be added conditionally in GetArgs)
18+
if containsArg(config.Args, "--jinja") {
19+
t.Error("Did not expect --jinja argument in default config (it should be added conditionally)")
2020
}
2121

2222
// Test -ngl argument and its value
@@ -74,7 +74,7 @@ func TestGetArgs(t *testing.T) {
7474
socket := "unix:///tmp/socket"
7575

7676
// Build base expected args based on architecture
77-
baseArgs := []string{"--jinja", "-ngl", "999", "--metrics"}
77+
baseArgs := []string{"-ngl", "999", "--metrics"}
7878
if runtime.GOARCH == "arm64" {
7979
nThreads := max(2, runtime.NumCPU()/2)
8080
baseArgs = append(baseArgs, "--threads", strconv.Itoa(nThreads))
@@ -97,6 +97,7 @@ func TestGetArgs(t *testing.T) {
9797
"--model", modelPath,
9898
"--host", socket,
9999
"--ctx-size", "4096",
100+
"--jinja",
100101
),
101102
},
102103
{
@@ -110,6 +111,7 @@ func TestGetArgs(t *testing.T) {
110111
"--host", socket,
111112
"--embeddings",
112113
"--ctx-size", "4096",
114+
"--jinja",
113115
),
114116
},
115117
{
@@ -125,7 +127,8 @@ func TestGetArgs(t *testing.T) {
125127
"--model", modelPath,
126128
"--host", socket,
127129
"--embeddings",
128-
"--ctx-size", "1234", // should add this flag
130+
"--ctx-size", "1234",
131+
"--jinja",
129132
),
130133
},
131134
{
@@ -145,6 +148,7 @@ func TestGetArgs(t *testing.T) {
145148
"--host", socket,
146149
"--embeddings",
147150
"--ctx-size", "2096", // model config takes precedence
151+
"--jinja",
148152
),
149153
},
150154
{
@@ -159,6 +163,7 @@ func TestGetArgs(t *testing.T) {
159163
"--host", socket,
160164
"--chat-template-file", "/path/to/bundle/template.jinja",
161165
"--ctx-size", "4096",
166+
"--jinja",
162167
),
163168
},
164169
{
@@ -175,7 +180,22 @@ func TestGetArgs(t *testing.T) {
175180
"--host", socket,
176181
"--embeddings",
177182
"--ctx-size", "4096",
178-
"--some", "flag", // model config takes precedence
183+
"--some", "flag",
184+
"--jinja",
185+
),
186+
},
187+
{
188+
name: "multimodal projector removes jinja",
189+
mode: inference.BackendModeCompletion,
190+
bundle: &fakeBundle{
191+
ggufPath: modelPath,
192+
mmprojPath: "/path/to/model.mmproj",
193+
},
194+
expected: append(slices.Clone(baseArgs),
195+
"--model", modelPath,
196+
"--host", socket,
197+
"--ctx-size", "4096",
198+
"--mmproj", "/path/to/model.mmproj",
179199
),
180200
},
181201
}
@@ -261,6 +281,7 @@ type fakeBundle struct {
261281
ggufPath string
262282
config types.Config
263283
templatePath string
284+
mmprojPath string
264285
}
265286

266287
func (f *fakeBundle) ChatTemplatePath() string {
@@ -276,7 +297,7 @@ func (f *fakeBundle) GGUFPath() string {
276297
}
277298

278299
func (f *fakeBundle) MMPROJPath() string {
279-
return ""
300+
return f.mmprojPath
280301
}
281302

282303
func (f *fakeBundle) SafetensorsPath() string {

0 commit comments

Comments
 (0)