diff --git a/main.go b/main.go
index e5c30ea98..5bf70e63e 100644
--- a/main.go
+++ b/main.go
@@ -8,7 +8,6 @@ import (
 	"os"
 	"os/signal"
 	"path/filepath"
-	"strings"
 	"syscall"
 
 	"github.com/docker/model-runner/pkg/distribution/transport/resumable"
@@ -16,6 +15,7 @@ import (
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
 	"github.com/docker/model-runner/pkg/inference/backends/vllm"
+	"github.com/docker/model-runner/pkg/inference/common"
 	"github.com/docker/model-runner/pkg/inference/config"
 	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/docker/model-runner/pkg/inference/models"
@@ -255,7 +255,7 @@ func createLlamaCppConfigFromEnv() config.BackendConfig {
 	}
 
 	// Split the string by spaces, respecting quoted arguments
-	args := splitArgs(argsStr)
+	args := common.SplitArgs(argsStr)
 
 	// Check for disallowed arguments
 	disallowedArgs := []string{"--model", "--host", "--embeddings", "--mmproj"}
@@ -273,29 +273,4 @@ func createLlamaCppConfigFromEnv() config.BackendConfig {
 
 	}
 }
 
-// splitArgs splits a string into arguments, respecting quoted arguments
-func splitArgs(s string) []string {
-	var args []string
-	var currentArg strings.Builder
-	inQuotes := false
-
-	for _, r := range s {
-		switch {
-		case r == '"' || r == '\'':
-			inQuotes = !inQuotes
-		case r == ' ' && !inQuotes:
-			if currentArg.Len() > 0 {
-				args = append(args, currentArg.String())
-				currentArg.Reset()
-			}
-		default:
-			currentArg.WriteRune(r)
-		}
-	}
-	if currentArg.Len() > 0 {
-		args = append(args, currentArg.String())
-	}
-
-	return args
-}
diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go
index 129e1734b..f3d453dde 100644
--- a/pkg/inference/backends/llamacpp/llamacpp.go
+++ b/pkg/inference/backends/llamacpp/llamacpp.go
@@ -7,7 +7,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"io/fs"
 	"net/http"
 	"os"
 	"os/exec"
@@ -23,9 +22,9 @@ import (
 	"github.com/docker/model-runner/pkg/diskusage"
 	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/common"
 	"github.com/docker/model-runner/pkg/inference/config"
 	"github.com/docker/model-runner/pkg/inference/models"
-	"github.com/docker/model-runner/pkg/internal/utils"
 	"github.com/docker/model-runner/pkg/logging"
 	"github.com/docker/model-runner/pkg/sandbox"
 	"github.com/docker/model-runner/pkg/tailbuffer"
)
@@ -149,8 +148,8 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, _ string, mode
 		}
 	}
 
-	if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
-		l.log.Warnf("failed to remove socket file %s: %w\n", socket, err)
+	if err := common.HandleSocketCleanup(socket); err != nil {
+		l.log.Warnf("failed to remove socket file %s: %v\n", socket, err)
 		l.log.Warnln("llama.cpp may not be able to start")
 	}
 
@@ -177,12 +176,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, _ string, mode
 		}
 	}
 
-	// Sanitize args for safe logging
-	sanitizedArgs := make([]string, len(args))
-	for i, arg := range args {
-		sanitizedArgs[i] = utils.SanitizeForLog(arg)
-	}
-	l.log.Infof("llamaCppArgs: %v", sanitizedArgs)
+	common.SanitizedArgsLog(l.log, "llamaCppArgs", args)
 	tailBuf := tailbuffer.NewTailBuffer(1024)
 	serverLogStream := l.serverLog.Writer()
 	out := io.MultiWriter(serverLogStream, tailBuf)
@@ -210,24 +204,11 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, _ string, mode
 
 	llamaCppErrors := make(chan error, 1)
 	go func() {
-		llamaCppErr := llamaCppSandbox.Command().Wait()
-		serverLogStream.Close()
-
-		errOutput := new(strings.Builder)
-		if _, err := io.Copy(errOutput, tailBuf); err != nil {
-			l.log.Warnf("failed to read server output tail: %w", err)
-		}
-
-		if len(errOutput.String()) != 0 {
-			llamaCppErr = fmt.Errorf("llama.cpp exit status: %w\nwith output: %s", llamaCppErr, errOutput.String())
-		} else {
-			llamaCppErr = fmt.Errorf("llama.cpp exit status: %w", llamaCppErr)
-		}
-
+		llamaCppErr := common.ProcessExitHandler(l.log, "llama.cpp", llamaCppSandbox, tailBuf, serverLogStream)
 		llamaCppErrors <- llamaCppErr
 		close(llamaCppErrors)
-		if err := os.Remove(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
-			l.log.Warnf("failed to remove socket file %s on exit: %w\n", socket, err)
+		if err := common.HandleSocketCleanup(socket); err != nil {
+			l.log.Warnf("failed to remove socket file %s on exit: %v\n", socket, err)
 		}
 	}()
 	defer func() {
diff --git a/pkg/inference/backends/vllm/vllm.go b/pkg/inference/backends/vllm/vllm.go
index c6233fa4c..f8d96a7e8 100644
--- a/pkg/inference/backends/vllm/vllm.go
+++ b/pkg/inference/backends/vllm/vllm.go
@@ -17,9 +17,9 @@ import (
 	"github.com/docker/model-runner/pkg/diskusage"
 	"github.com/docker/model-runner/pkg/distribution/types"
 	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/common"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/inference/platform"
-	"github.com/docker/model-runner/pkg/internal/utils"
 	"github.com/docker/model-runner/pkg/logging"
 	"github.com/docker/model-runner/pkg/sandbox"
 	"github.com/docker/model-runner/pkg/tailbuffer"
@@ -118,7 +118,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m
 		}
 	}
 
-	if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
+	if err := common.HandleSocketCleanup(socket); err != nil {
 		v.log.Warnf("failed to remove socket file %s: %v\n", socket, err)
 		v.log.Warnln("vLLM may not be able to start")
 	}
@@ -151,12 +151,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m
 
 	args = append(args, "--served-model-name", model, modelRef)
 
-	// Sanitize args for safe logging
-	sanitizedArgs := make([]string, len(args))
-	for i, arg := range args {
-		sanitizedArgs[i] = utils.SanitizeForLog(arg)
-	}
-	v.log.Infof("vLLM args: %v", sanitizedArgs)
+	common.SanitizedArgsLog(v.log, "vLLM args", args)
 	tailBuf := tailbuffer.NewTailBuffer(1024)
 	serverLogStream := v.serverLog.Writer()
 	out := io.MultiWriter(serverLogStream, tailBuf)
@@ -184,23 +179,10 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m
 
 	vllmErrors := make(chan error, 1)
 	go func() {
-		vllmErr := vllmSandbox.Command().Wait()
-		serverLogStream.Close()
-
-		errOutput := new(strings.Builder)
-		if _, err := io.Copy(errOutput, tailBuf); err != nil {
-			v.log.Warnf("failed to read server output tail: %v", err)
-		}
-
-		if len(errOutput.String()) != 0 {
-			vllmErr = fmt.Errorf("vLLM exit status: %w\nwith output: %s", vllmErr, errOutput.String())
-		} else {
-			vllmErr = fmt.Errorf("vLLM exit status: %w", vllmErr)
-		}
-
+		vllmErr := common.ProcessExitHandler(v.log, "vLLM", vllmSandbox, tailBuf, serverLogStream)
 		vllmErrors <- vllmErr
 		close(vllmErrors)
-		if err := os.Remove(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
+		if err := common.HandleSocketCleanup(socket); err != nil {
 			v.log.Warnf("failed to remove socket file %s on exit: %v\n", socket, err)
 		}
 	}()
diff --git a/pkg/inference/common/backend_utils.go b/pkg/inference/common/backend_utils.go
new file mode 100644
index 000000000..1f9ccf654
--- /dev/null
+++ b/pkg/inference/common/backend_utils.go
@@ -0,0 +1,111 @@
+package common
+
+import (
+	"errors"
+	"io"
+	"os"
+	"strings"
+
+	"github.com/docker/model-runner/pkg/internal/utils"
+	"github.com/docker/model-runner/pkg/logging"
+	"github.com/docker/model-runner/pkg/sandbox"
+)
+
+// SanitizedArgsLog logs command arguments with sanitization for safe logging
+func SanitizedArgsLog(log logging.Logger, label string, args []string) {
+	sanitizedArgs := make([]string, len(args))
+	for i, arg := range args {
+		sanitizedArgs[i] = utils.SanitizeForLog(arg)
+	}
+	log.Infof("%s: %v", label, sanitizedArgs)
+}
+
+// HandleSocketCleanup removes the socket file at the given path, ignoring if it doesn't exist
+func HandleSocketCleanup(socket string) error {
+	if err := os.RemoveAll(socket); err != nil && !errors.Is(err, os.ErrNotExist) {
+		return err
+	}
+	return nil
+}
+
+// ProcessExitHandler waits for a backend process to exit, closes its log
+// stream, drains the tail of its output, and returns a *BackendExitError
+// describing the exit. backendName (e.g. "llama.cpp") prefixes the error.
+func ProcessExitHandler(
+	log logging.Logger,
+	backendName string,
+	sandboxInstance sandbox.Sandbox,
+	tailBuf io.Reader,
+	serverLogStream io.Closer,
+) error {
+	// Wait for the process to exit before closing the log stream and
+	// draining the tail buffer, so output written up to exit is not lost.
+	cmdErr := sandboxInstance.Command().Wait()
+	serverLogStream.Close()
+
+	errOutput := new(strings.Builder)
+	if _, err := io.Copy(errOutput, tailBuf); err != nil {
+		log.Warnf("failed to read server output tail: %v", err)
+	}
+
+	return &BackendExitError{
+		Backend: backendName,
+		Err:     cmdErr,
+		Output:  errOutput.String(),
+	}
+}
+
+// BackendExitError reports the exit of a backend process, optionally with
+// the tail of the process output.
+type BackendExitError struct {
+	Backend string
+	Err     error
+	Output  string
+}
+
+// Error implements the error interface.
+func (e *BackendExitError) Error() string {
+	msg := e.Backend + " exit status: "
+	if e.Err != nil {
+		msg += e.Err.Error()
+	} else {
+		msg += "<nil>"
+	}
+	if e.Output != "" {
+		msg += "\nwith output: " + e.Output
+	}
+	return msg
+}
+
+// Unwrap returns the underlying process exit error so callers can use
+// errors.Is and errors.As.
+func (e *BackendExitError) Unwrap() error {
+	return e.Err
+}
+
+// SplitArgs splits a string into arguments, respecting quoted arguments
+func SplitArgs(s string) []string {
+	var args []string
+	var currentArg strings.Builder
+	inQuotes := false
+
+	for _, r := range s {
+		switch {
+		case r == '"' || r == '\'':
+			inQuotes = !inQuotes
+		case r == ' ' && !inQuotes:
+			if currentArg.Len() > 0 {
+				args = append(args, currentArg.String())
+				currentArg.Reset()
+			}
+		default:
+			currentArg.WriteRune(r)
+		}
+	}
+
+	if currentArg.Len() > 0 {
+		args = append(args, currentArg.String())
+	}
+
+	return args
+}