Skip to content

Commit 5dd0f5e

Browse files
committed
De-duplicate code
In various places

Signed-off-by: Eric Curtin <[email protected]>
1 parent 7fdb650 commit 5dd0f5e

File tree

3 files changed

+14
-76
lines changed

3 files changed

+14
-76
lines changed

main.go

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ import (
88
"os"
99
"os/signal"
1010
"path/filepath"
11-
"strings"
1211
"syscall"
1312

1413
"github.com/docker/model-runner/pkg/distribution/transport/resumable"
1514
"github.com/docker/model-runner/pkg/gpuinfo"
1615
"github.com/docker/model-runner/pkg/inference"
1716
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
1817
"github.com/docker/model-runner/pkg/inference/backends/vllm"
18+
"github.com/docker/model-runner/pkg/inference/common"
1919
"github.com/docker/model-runner/pkg/inference/config"
2020
"github.com/docker/model-runner/pkg/inference/memory"
2121
"github.com/docker/model-runner/pkg/inference/models"
@@ -255,7 +255,7 @@ func createLlamaCppConfigFromEnv() config.BackendConfig {
255255
}
256256

257257
// Split the string by spaces, respecting quoted arguments
258-
args := splitArgs(argsStr)
258+
args := common.SplitArgs(argsStr)
259259

260260
// Check for disallowed arguments
261261
disallowedArgs := []string{"--model", "--host", "--embeddings", "--mmproj"}
@@ -273,29 +273,4 @@ func createLlamaCppConfigFromEnv() config.BackendConfig {
273273
}
274274
}
275275

276-
// splitArgs splits a string into arguments, respecting quoted arguments
277-
func splitArgs(s string) []string {
278-
var args []string
279-
var currentArg strings.Builder
280-
inQuotes := false
281-
282-
for _, r := range s {
283-
switch {
284-
case r == '"' || r == '\'':
285-
inQuotes = !inQuotes
286-
case r == ' ' && !inQuotes:
287-
if currentArg.Len() > 0 {
288-
args = append(args, currentArg.String())
289-
currentArg.Reset()
290-
}
291-
default:
292-
currentArg.WriteRune(r)
293-
}
294-
}
295276

296-
if currentArg.Len() > 0 {
297-
args = append(args, currentArg.String())
298-
}
299-
300-
return args
301-
}

pkg/inference/backends/llamacpp/llamacpp.go

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"errors"
88
"fmt"
99
"io"
10-
"io/fs"
1110
"net/http"
1211
"os"
1312
"os/exec"
@@ -23,9 +22,9 @@ import (
2322

2423
"github.com/docker/model-runner/pkg/diskusage"
2524
"github.com/docker/model-runner/pkg/inference"
25+
"github.com/docker/model-runner/pkg/inference/common"
2626
"github.com/docker/model-runner/pkg/inference/config"
2727
"github.com/docker/model-runner/pkg/inference/models"
28-
"github.com/docker/model-runner/pkg/internal/utils"
2928
"github.com/docker/model-runner/pkg/logging"
3029
"github.com/docker/model-runner/pkg/sandbox"
3130
"github.com/docker/model-runner/pkg/tailbuffer"
@@ -149,8 +148,8 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, _ string, mode
149148
}
150149
}
151150

152-
if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
153-
l.log.Warnf("failed to remove socket file %s: %w\n", socket, err)
151+
if err := common.HandleSocketCleanup(socket); err != nil {
152+
l.log.Warnf("failed to remove socket file %s: %v\n", socket, err)
154153
l.log.Warnln("llama.cpp may not be able to start")
155154
}
156155

@@ -177,12 +176,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, _ string, mode
177176
}
178177
}
179178

180-
// Sanitize args for safe logging
181-
sanitizedArgs := make([]string, len(args))
182-
for i, arg := range args {
183-
sanitizedArgs[i] = utils.SanitizeForLog(arg)
184-
}
185-
l.log.Infof("llamaCppArgs: %v", sanitizedArgs)
179+
common.SanitizedArgsLog(l.log, "llamaCppArgs", args)
186180
tailBuf := tailbuffer.NewTailBuffer(1024)
187181
serverLogStream := l.serverLog.Writer()
188182
out := io.MultiWriter(serverLogStream, tailBuf)
@@ -210,24 +204,11 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, _ string, mode
210204

211205
llamaCppErrors := make(chan error, 1)
212206
go func() {
213-
llamaCppErr := llamaCppSandbox.Command().Wait()
214-
serverLogStream.Close()
215-
216-
errOutput := new(strings.Builder)
217-
if _, err := io.Copy(errOutput, tailBuf); err != nil {
218-
l.log.Warnf("failed to read server output tail: %w", err)
219-
}
220-
221-
if len(errOutput.String()) != 0 {
222-
llamaCppErr = fmt.Errorf("llama.cpp exit status: %w\nwith output: %s", llamaCppErr, errOutput.String())
223-
} else {
224-
llamaCppErr = fmt.Errorf("llama.cpp exit status: %w", llamaCppErr)
225-
}
226-
207+
llamaCppErr := common.ProcessExitHandler(l.log, llamaCppSandbox, tailBuf, serverLogStream, socket)
227208
llamaCppErrors <- llamaCppErr
228209
close(llamaCppErrors)
229-
if err := os.Remove(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
230-
l.log.Warnf("failed to remove socket file %s on exit: %w\n", socket, err)
210+
if err := common.HandleSocketCleanup(socket); err != nil {
211+
l.log.Warnf("failed to remove socket file %s on exit: %v\n", socket, err)
231212
}
232213
}()
233214
defer func() {

pkg/inference/backends/vllm/vllm.go

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ import (
1717
"github.com/docker/model-runner/pkg/diskusage"
1818
"github.com/docker/model-runner/pkg/distribution/types"
1919
"github.com/docker/model-runner/pkg/inference"
20+
"github.com/docker/model-runner/pkg/inference/common"
2021
"github.com/docker/model-runner/pkg/inference/models"
2122
"github.com/docker/model-runner/pkg/inference/platform"
22-
"github.com/docker/model-runner/pkg/internal/utils"
2323
"github.com/docker/model-runner/pkg/logging"
2424
"github.com/docker/model-runner/pkg/sandbox"
2525
"github.com/docker/model-runner/pkg/tailbuffer"
@@ -118,7 +118,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m
118118
}
119119
}
120120

121-
if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
121+
if err := common.HandleSocketCleanup(socket); err != nil {
122122
v.log.Warnf("failed to remove socket file %s: %v\n", socket, err)
123123
v.log.Warnln("vLLM may not be able to start")
124124
}
@@ -151,12 +151,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m
151151

152152
args = append(args, "--served-model-name", model, modelRef)
153153

154-
// Sanitize args for safe logging
155-
sanitizedArgs := make([]string, len(args))
156-
for i, arg := range args {
157-
sanitizedArgs[i] = utils.SanitizeForLog(arg)
158-
}
159-
v.log.Infof("vLLM args: %v", sanitizedArgs)
154+
common.SanitizedArgsLog(v.log, "vLLM args", args)
160155
tailBuf := tailbuffer.NewTailBuffer(1024)
161156
serverLogStream := v.serverLog.Writer()
162157
out := io.MultiWriter(serverLogStream, tailBuf)
@@ -184,23 +179,10 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m
184179

185180
vllmErrors := make(chan error, 1)
186181
go func() {
187-
vllmErr := vllmSandbox.Command().Wait()
188-
serverLogStream.Close()
189-
190-
errOutput := new(strings.Builder)
191-
if _, err := io.Copy(errOutput, tailBuf); err != nil {
192-
v.log.Warnf("failed to read server output tail: %v", err)
193-
}
194-
195-
if len(errOutput.String()) != 0 {
196-
vllmErr = fmt.Errorf("vLLM exit status: %w\nwith output: %s", vllmErr, errOutput.String())
197-
} else {
198-
vllmErr = fmt.Errorf("vLLM exit status: %w", vllmErr)
199-
}
200-
182+
vllmErr := common.ProcessExitHandler(v.log, vllmSandbox, tailBuf, serverLogStream, socket)
201183
vllmErrors <- vllmErr
202184
close(vllmErrors)
203-
if err := os.Remove(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
185+
if err := common.HandleSocketCleanup(socket); err != nil {
204186
v.log.Warnf("failed to remove socket file %s on exit: %v\n", socket, err)
205187
}
206188
}()

0 commit comments

Comments (0)