Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 2 additions & 27 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ import (
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"

"github.com/docker/model-runner/pkg/distribution/transport/resumable"
"github.com/docker/model-runner/pkg/gpuinfo"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
"github.com/docker/model-runner/pkg/inference/backends/vllm"
"github.com/docker/model-runner/pkg/inference/common"
"github.com/docker/model-runner/pkg/inference/config"
"github.com/docker/model-runner/pkg/inference/memory"
"github.com/docker/model-runner/pkg/inference/models"
Expand Down Expand Up @@ -255,7 +255,7 @@ func createLlamaCppConfigFromEnv() config.BackendConfig {
}

// Split the string by spaces, respecting quoted arguments
args := splitArgs(argsStr)
args := common.SplitArgs(argsStr)

// Check for disallowed arguments
disallowedArgs := []string{"--model", "--host", "--embeddings", "--mmproj"}
Expand All @@ -273,29 +273,4 @@ func createLlamaCppConfigFromEnv() config.BackendConfig {
}
}

// splitArgs splits a string into arguments, respecting quoted arguments
func splitArgs(s string) []string {
var args []string
var currentArg strings.Builder
inQuotes := false

for _, r := range s {
switch {
case r == '"' || r == '\'':
inQuotes = !inQuotes
case r == ' ' && !inQuotes:
if currentArg.Len() > 0 {
args = append(args, currentArg.String())
currentArg.Reset()
}
default:
currentArg.WriteRune(r)
}
}

if currentArg.Len() > 0 {
args = append(args, currentArg.String())
}

return args
}
33 changes: 7 additions & 26 deletions pkg/inference/backends/llamacpp/llamacpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"os"
"os/exec"
Expand All @@ -23,9 +22,9 @@ import (

"github.com/docker/model-runner/pkg/diskusage"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/common"
"github.com/docker/model-runner/pkg/inference/config"
"github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/internal/utils"
"github.com/docker/model-runner/pkg/logging"
"github.com/docker/model-runner/pkg/sandbox"
"github.com/docker/model-runner/pkg/tailbuffer"
Expand Down Expand Up @@ -149,8 +148,8 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, _ string, mode
}
}

if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
l.log.Warnf("failed to remove socket file %s: %w\n", socket, err)
if err := common.HandleSocketCleanup(socket); err != nil {
l.log.Warnf("failed to remove socket file %s: %v\n", socket, err)
l.log.Warnln("llama.cpp may not be able to start")
}

Expand All @@ -177,12 +176,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, _ string, mode
}
}

// Sanitize args for safe logging
sanitizedArgs := make([]string, len(args))
for i, arg := range args {
sanitizedArgs[i] = utils.SanitizeForLog(arg)
}
l.log.Infof("llamaCppArgs: %v", sanitizedArgs)
common.SanitizedArgsLog(l.log, "llamaCppArgs", args)
tailBuf := tailbuffer.NewTailBuffer(1024)
serverLogStream := l.serverLog.Writer()
out := io.MultiWriter(serverLogStream, tailBuf)
Expand Down Expand Up @@ -210,24 +204,11 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, _ string, mode

llamaCppErrors := make(chan error, 1)
go func() {
llamaCppErr := llamaCppSandbox.Command().Wait()
serverLogStream.Close()

errOutput := new(strings.Builder)
if _, err := io.Copy(errOutput, tailBuf); err != nil {
l.log.Warnf("failed to read server output tail: %w", err)
}

if len(errOutput.String()) != 0 {
llamaCppErr = fmt.Errorf("llama.cpp exit status: %w\nwith output: %s", llamaCppErr, errOutput.String())
} else {
llamaCppErr = fmt.Errorf("llama.cpp exit status: %w", llamaCppErr)
}

llamaCppErr := common.ProcessExitHandler(l.log, llamaCppSandbox, tailBuf, serverLogStream, socket)
llamaCppErrors <- llamaCppErr
close(llamaCppErrors)
if err := os.Remove(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
l.log.Warnf("failed to remove socket file %s on exit: %w\n", socket, err)
if err := common.HandleSocketCleanup(socket); err != nil {
l.log.Warnf("failed to remove socket file %s on exit: %v\n", socket, err)
}
}()
Comment on lines 206 to 213
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The call to common.HandleSocketCleanup(socket) on lines 210-212 appears to be redundant, as the socket variable is already passed to common.ProcessExitHandler on line 207. This implies that ProcessExitHandler is responsible for the socket cleanup.

To avoid duplicate operations and improve clarity, the explicit cleanup call should be removed. If ProcessExitHandler does not handle cleanup, its signature should be changed to not accept a socket parameter.

 	go func() {
 		llamaCppErr := common.ProcessExitHandler(l.log, llamaCppSandbox, tailBuf, serverLogStream, socket)
 		llamaCppErrors <- llamaCppErr
 		close(llamaCppErrors)
 	}()

defer func() {
Expand Down
28 changes: 5 additions & 23 deletions pkg/inference/backends/vllm/vllm.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ import (
"github.com/docker/model-runner/pkg/diskusage"
"github.com/docker/model-runner/pkg/distribution/types"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/common"
"github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/inference/platform"
"github.com/docker/model-runner/pkg/internal/utils"
"github.com/docker/model-runner/pkg/logging"
"github.com/docker/model-runner/pkg/sandbox"
"github.com/docker/model-runner/pkg/tailbuffer"
Expand Down Expand Up @@ -118,7 +118,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m
}
}

if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
if err := common.HandleSocketCleanup(socket); err != nil {
v.log.Warnf("failed to remove socket file %s: %v\n", socket, err)
v.log.Warnln("vLLM may not be able to start")
}
Expand Down Expand Up @@ -151,12 +151,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m

args = append(args, "--served-model-name", model, modelRef)

// Sanitize args for safe logging
sanitizedArgs := make([]string, len(args))
for i, arg := range args {
sanitizedArgs[i] = utils.SanitizeForLog(arg)
}
v.log.Infof("vLLM args: %v", sanitizedArgs)
common.SanitizedArgsLog(v.log, "vLLM args", args)
tailBuf := tailbuffer.NewTailBuffer(1024)
serverLogStream := v.serverLog.Writer()
out := io.MultiWriter(serverLogStream, tailBuf)
Expand Down Expand Up @@ -184,23 +179,10 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m

vllmErrors := make(chan error, 1)
go func() {
vllmErr := vllmSandbox.Command().Wait()
serverLogStream.Close()

errOutput := new(strings.Builder)
if _, err := io.Copy(errOutput, tailBuf); err != nil {
v.log.Warnf("failed to read server output tail: %v", err)
}

if len(errOutput.String()) != 0 {
vllmErr = fmt.Errorf("vLLM exit status: %w\nwith output: %s", vllmErr, errOutput.String())
} else {
vllmErr = fmt.Errorf("vLLM exit status: %w", vllmErr)
}

vllmErr := common.ProcessExitHandler(v.log, vllmSandbox, tailBuf, serverLogStream, socket)
vllmErrors <- vllmErr
close(vllmErrors)
if err := os.Remove(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
if err := common.HandleSocketCleanup(socket); err != nil {
v.log.Warnf("failed to remove socket file %s on exit: %v\n", socket, err)
}
}()
Comment on lines 181 to 188
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The call to common.HandleSocketCleanup(socket) on lines 185-187 appears to be redundant, as the socket variable is already passed to common.ProcessExitHandler on line 182. This implies that ProcessExitHandler is responsible for the socket cleanup.

To avoid duplicate operations and improve clarity, the explicit cleanup call should be removed. If ProcessExitHandler does not handle cleanup, its signature should be changed to not accept a socket parameter.

 	go func() {
 		vllmErr := common.ProcessExitHandler(v.log, vllmSandbox, tailBuf, serverLogStream, socket)
 		vllmErrors <- vllmErr
 		close(vllmErrors)
 	}()

Expand Down
96 changes: 96 additions & 0 deletions pkg/inference/common/backend_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package common

import (
"errors"
"io"
"os"
"strings"

"github.com/docker/model-runner/pkg/internal/utils"
"github.com/docker/model-runner/pkg/logging"
"github.com/docker/model-runner/pkg/sandbox"
)

// SanitizedArgsLog logs command arguments with sanitization for safe logging
func SanitizedArgsLog(log logging.Logger, label string, args []string) {
sanitizedArgs := make([]string, len(args))
for i, arg := range args {
sanitizedArgs[i] = utils.SanitizeForLog(arg)
}
log.Infof("%s: %v", label, sanitizedArgs)
}

// HandleSocketCleanup removes the socket file at the given path, ignoring if it doesn't exist
func HandleSocketCleanup(socket string) error {
if err := os.RemoveAll(socket); err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
return nil
}

// ProcessExitHandler handles the exit of a backend process and captures its output
func ProcessExitHandler(
log logging.Logger,
sandboxInstance sandbox.Sandbox,
tailBuf io.ReadWriter,
serverLogStream io.Closer,
socket string,
) error {
serverLogStream.Close()

errOutput := new(strings.Builder)
if _, err := io.Copy(errOutput, tailBuf); err != nil {
log.Warnf("failed to read server output tail: %v", err)
}

cmdErr := sandboxInstance.Command().Wait()
outputStr := errOutput.String()
if len(outputStr) != 0 {
return &BackendExitError{
Err: cmdErr,
Output: outputStr,
}
} else {
return cmdErr
}
}

// BackendExitError represents an error when a backend process exits
type BackendExitError struct {
Err error
Output string
}

func (e *BackendExitError) Error() string {
if e.Output != "" {
return e.Err.Error() + "\nwith output: " + e.Output
}
return e.Err.Error()
}

// SplitArgs splits a string into arguments, respecting quoted arguments
func SplitArgs(s string) []string {
var args []string
var currentArg strings.Builder
inQuotes := false

for _, r := range s {
switch {
case r == '"' || r == '\'':
inQuotes = !inQuotes
case r == ' ' && !inQuotes:
if currentArg.Len() > 0 {
args = append(args, currentArg.String())
currentArg.Reset()
}
default:
currentArg.WriteRune(r)
}
}

if currentArg.Len() > 0 {
args = append(args, currentArg.String())
}

return args
}
Loading