@@ -2,6 +2,7 @@ package llamacpp
 
 import (
 	"runtime"
+	"slices"
 	"strconv"
 	"testing"
 
@@ -72,6 +73,13 @@ func TestGetArgs(t *testing.T) {
 	modelPath := "/path/to/model"
 	socket := "unix:///tmp/socket"
 
+	// Build base expected args based on architecture
+	baseArgs := []string{"--jinja", "-ngl", "999", "--metrics"}
+	if runtime.GOARCH == "arm64" {
+		nThreads := max(2, runtime.NumCPU()/2)
+		baseArgs = append(baseArgs, "--threads", strconv.Itoa(nThreads))
+	}
+
 	tests := []struct {
 		name     string
 		bundle   types.ModelBundle
@@ -85,30 +93,24 @@ func TestGetArgs(t *testing.T) {
 			bundle: &fakeBundle{
 				ggufPath: modelPath,
 			},
-			expected: []string{
-				"--jinja",
-				"-ngl", "999",
-				"--metrics",
+			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
 				"--host", socket,
 				"--ctx-size", "4096",
-			},
+			),
 		},
 		{
 			name: "embedding mode",
 			mode: inference.BackendModeEmbedding,
 			bundle: &fakeBundle{
 				ggufPath: modelPath,
 			},
-			expected: []string{
-				"--jinja",
-				"-ngl", "999",
-				"--metrics",
+			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
 				"--host", socket,
 				"--embeddings",
 				"--ctx-size", "4096",
-			},
+			),
 		},
 		{
 			name: "context size from backend config",
@@ -119,15 +121,12 @@ func TestGetArgs(t *testing.T) {
 			config: &inference.BackendConfiguration{
 				ContextSize: 1234,
 			},
-			expected: []string{
-				"--jinja",
-				"-ngl", "999",
-				"--metrics",
+			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
 				"--host", socket,
 				"--embeddings",
 				"--ctx-size", "1234", // should add this flag
-			},
+			),
 		},
 		{
 			name: "context size from model config",
@@ -141,15 +140,12 @@ func TestGetArgs(t *testing.T) {
 			config: &inference.BackendConfiguration{
 				ContextSize: 1234,
 			},
-			expected: []string{
-				"--jinja",
-				"-ngl", "999",
-				"--metrics",
+			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
 				"--host", socket,
 				"--embeddings",
 				"--ctx-size", "2096", // model config takes precedence
-			},
+			),
 		},
 		{
 			name: "chat template from model artifact",
@@ -158,15 +154,12 @@ func TestGetArgs(t *testing.T) {
 				ggufPath:     modelPath,
 				templatePath: "/path/to/bundle/template.jinja",
 			},
-			expected: []string{
-				"--jinja",
-				"-ngl", "999",
-				"--metrics",
+			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
 				"--host", socket,
 				"--chat-template-file", "/path/to/bundle/template.jinja",
 				"--ctx-size", "4096",
-			},
+			),
 		},
 		{
 			name: "raw flags from backend config",
@@ -177,16 +170,13 @@ func TestGetArgs(t *testing.T) {
 			config: &inference.BackendConfiguration{
 				RuntimeFlags: []string{"--some", "flag"},
 			},
-			expected: []string{
-				"--jinja",
-				"-ngl", "999",
-				"--metrics",
+			expected: append(slices.Clone(baseArgs),
 				"--model", modelPath,
 				"--host", socket,
 				"--embeddings",
 				"--ctx-size", "4096",
188178 "--some" , "flag" , // model config takes precedence
-			},
+			),
 		},
 	}
 
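A note on the append(slices.Clone(baseArgs), ...) pattern this change introduces: cloning before appending gives each test case's expected slice its own backing array. Appending to the shared baseArgs directly could reuse its spare capacity, so one case's append might silently overwrite elements another case reads. A minimal standalone sketch of that pitfall (the names base, a, b, c, d are illustrative, not from this change):

	package main

	import (
		"fmt"
		"slices"
	)

	func main() {
		// A shared prefix slice with spare capacity, standing in for baseArgs.
		base := make([]string, 0, 8)
		base = append(base, "--jinja", "-ngl")

		// Appending to the shared slice directly reuses its backing array:
		a := append(base, "--model") // writes into base's spare capacity
		b := append(base, "--host")  // overwrites the element a just wrote
		fmt.Println(a[2], b[2])      // --host --host: a was silently clobbered

		// Cloning first gives each result an independent backing array:
		c := append(slices.Clone(base), "--model")
		d := append(slices.Clone(base), "--host")
		fmt.Println(c[2], d[2]) // --model --host, as intended
	}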