Skip to content

Commit c733ed7

Browse files
authored
Merge pull request #184 from docker/fix-tests-on-arm
test: update tests to support ARM architecture
2 parents dafd755 + 3287ffe commit c733ed7

File tree

3 files changed

+23
-31
lines changed

3 files changed

+23
-31
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ model-runner
44
model-runner.sock
55
# Default MODELS_PATH in Makefile
66
models-store/
7+
# Default MODELS_PATH in mdltool
8+
model-store/
79
# Directory where we store the updated llama.cpp
810
updated-inference/
911
vendor/

pkg/inference/backends/llamacpp/llamacpp_config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func NewDefaultLlamaCppConfig() *Config {
2525
// Using a thread count equal to core count results in bad performance, and there seems to be little to no gain
2626
// in going beyond core_count/2.
2727
if !containsArg(args, "--threads") {
28-
nThreads := min(2, runtime.NumCPU()/2)
28+
nThreads := max(2, runtime.NumCPU()/2)
2929
args = append(args, "--threads", strconv.Itoa(nThreads))
3030
}
3131
}

pkg/inference/backends/llamacpp/llamacpp_config_test.go

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package llamacpp
22

33
import (
44
"runtime"
5+
"slices"
56
"strconv"
67
"testing"
78

@@ -72,6 +73,13 @@ func TestGetArgs(t *testing.T) {
7273
modelPath := "/path/to/model"
7374
socket := "unix:///tmp/socket"
7475

76+
// Build base expected args based on architecture
77+
baseArgs := []string{"--jinja", "-ngl", "999", "--metrics"}
78+
if runtime.GOARCH == "arm64" {
79+
nThreads := max(2, runtime.NumCPU()/2)
80+
baseArgs = append(baseArgs, "--threads", strconv.Itoa(nThreads))
81+
}
82+
7583
tests := []struct {
7684
name string
7785
bundle types.ModelBundle
@@ -85,30 +93,24 @@ func TestGetArgs(t *testing.T) {
8593
bundle: &fakeBundle{
8694
ggufPath: modelPath,
8795
},
88-
expected: []string{
89-
"--jinja",
90-
"-ngl", "999",
91-
"--metrics",
96+
expected: append(slices.Clone(baseArgs),
9297
"--model", modelPath,
9398
"--host", socket,
9499
"--ctx-size", "4096",
95-
},
100+
),
96101
},
97102
{
98103
name: "embedding mode",
99104
mode: inference.BackendModeEmbedding,
100105
bundle: &fakeBundle{
101106
ggufPath: modelPath,
102107
},
103-
expected: []string{
104-
"--jinja",
105-
"-ngl", "999",
106-
"--metrics",
108+
expected: append(slices.Clone(baseArgs),
107109
"--model", modelPath,
108110
"--host", socket,
109111
"--embeddings",
110112
"--ctx-size", "4096",
111-
},
113+
),
112114
},
113115
{
114116
name: "context size from backend config",
@@ -119,15 +121,12 @@ func TestGetArgs(t *testing.T) {
119121
config: &inference.BackendConfiguration{
120122
ContextSize: 1234,
121123
},
122-
expected: []string{
123-
"--jinja",
124-
"-ngl", "999",
125-
"--metrics",
124+
expected: append(slices.Clone(baseArgs),
126125
"--model", modelPath,
127126
"--host", socket,
128127
"--embeddings",
129128
"--ctx-size", "1234", // should add this flag
130-
},
129+
),
131130
},
132131
{
133132
name: "context size from model config",
@@ -141,15 +140,12 @@ func TestGetArgs(t *testing.T) {
141140
config: &inference.BackendConfiguration{
142141
ContextSize: 1234,
143142
},
144-
expected: []string{
145-
"--jinja",
146-
"-ngl", "999",
147-
"--metrics",
143+
expected: append(slices.Clone(baseArgs),
148144
"--model", modelPath,
149145
"--host", socket,
150146
"--embeddings",
151147
"--ctx-size", "2096", // model config takes precedence
152-
},
148+
),
153149
},
154150
{
155151
name: "chat template from model artifact",
@@ -158,15 +154,12 @@ func TestGetArgs(t *testing.T) {
158154
ggufPath: modelPath,
159155
templatePath: "/path/to/bundle/template.jinja",
160156
},
161-
expected: []string{
162-
"--jinja",
163-
"-ngl", "999",
164-
"--metrics",
157+
expected: append(slices.Clone(baseArgs),
165158
"--model", modelPath,
166159
"--host", socket,
167160
"--chat-template-file", "/path/to/bundle/template.jinja",
168161
"--ctx-size", "4096",
169-
},
162+
),
170163
},
171164
{
172165
name: "raw flags from backend config",
@@ -177,16 +170,13 @@ func TestGetArgs(t *testing.T) {
177170
config: &inference.BackendConfiguration{
178171
RuntimeFlags: []string{"--some", "flag"},
179172
},
180-
expected: []string{
181-
"--jinja",
182-
"-ngl", "999",
183-
"--metrics",
173+
expected: append(slices.Clone(baseArgs),
184174
"--model", modelPath,
185175
"--host", socket,
186176
"--embeddings",
187177
"--ctx-size", "4096",
188178
"--some", "flag", // model config takes precedence
189-
},
179+
),
190180
},
191181
}
192182

0 commit comments

Comments
 (0)