diff --git a/.github/workflows/promote-to-latest.yml b/.github/workflows/promote-to-latest.yml index 246dc6f3e..8e223ac14 100644 --- a/.github/workflows/promote-to-latest.yml +++ b/.github/workflows/promote-to-latest.yml @@ -5,7 +5,7 @@ on: workflow_dispatch: inputs: version: - description: 'version' + description: "version" required: true type: string @@ -42,6 +42,11 @@ jobs: echo "Promoting vLLM CUDA images" crane tag "docker/model-runner:${{ inputs.version }}-vllm-cuda" "latest-vllm-cuda" + - name: Promote SGLang CUDA images + run: | + echo "Promoting SGLang CUDA images" + crane tag "docker/model-runner:${{ inputs.version }}-sglang-cuda" "latest-sglang-cuda" + + - name: Promote ROCm images run: | echo "Promoting ROCm images" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2ab6d8bfa..e679baa8c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -5,28 +5,33 @@ on: workflow_dispatch: inputs: pushLatest: - description: 'Tag images produced by this job as latest' + description: "Tag images produced by this job as latest" required: false type: boolean default: false releaseTag: - description: 'Release tag' + description: "Release tag" required: false type: string default: "test" llamaServerVersion: - description: 'llama-server version' + description: "llama-server version" required: false type: string default: "latest" vllmVersion: - description: 'vLLM version' + description: "vLLM version" required: false type: string default: "0.12.0" + sglangVersion: + description: "SGLang version" + required: false + type: string + default: "0.4.0" # This can be removed once we have llama.cpp built for MUSA and CANN. buildMusaCann: - description: 'Build MUSA and CANN images' + description: "Build MUSA and CANN images" required: false type: boolean default: false @@ -76,6 +81,12 @@ jobs: echo "docker/model-runner:latest-vllm-cuda" >> "$GITHUB_OUTPUT" fi echo 'EOF' >> "$GITHUB_OUTPUT" + echo "sglang-cuda<<EOF" >> "$GITHUB_OUTPUT" + echo "docker/model-runner:${{ inputs.releaseTag }}-sglang-cuda" >> "$GITHUB_OUTPUT" + if [ "${{ inputs.pushLatest }}" == "true" ]; then + echo "docker/model-runner:latest-sglang-cuda" >> "$GITHUB_OUTPUT" + fi + echo 'EOF' >> "$GITHUB_OUTPUT" echo "rocm<<EOF" >> "$GITHUB_OUTPUT" echo "docker/model-runner:${{ inputs.releaseTag }}-rocm" >> "$GITHUB_OUTPUT" if [ "${{ inputs.pushLatest }}" == "true" ]; then @@ -155,6 +166,22 @@ jobs: provenance: mode=max tags: ${{ steps.tags.outputs.vllm-cuda }} + - name: Build SGLang CUDA image + uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 + with: + file: Dockerfile + target: final-sglang + platforms: linux/amd64 + build-args: | + "LLAMA_SERVER_VERSION=${{ inputs.llamaServerVersion }}" + "LLAMA_SERVER_VARIANT=cuda" + "BASE_IMAGE=nvidia/cuda:12.9.0-runtime-ubuntu24.04" + "SGLANG_VERSION=${{ inputs.sglangVersion }}" + push: true + sbom: true + provenance: mode=max + tags: ${{ steps.tags.outputs.sglang-cuda }} + + - name: Build ROCm image uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 with: diff --git a/Dockerfile b/Dockerfile index f26157c11..7d71f3dc1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,7 +33,13 @@ COPY --link . . # Build the Go binary (static build) RUN --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ - CGO_ENABLED=1 GOOS=linux go build -ldflags="-s -w" -o model-runner ./main.go + CGO_ENABLED=1 GOOS=linux go build -ldflags="-s -w" -o model-runner .
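Note on the go build change just above: the builder stage now compiles the whole package (".") instead of the single file "./main.go". That matters for the -tags=novllm seam introduced later in this patch (backends_vllm.go vs. backends_vllm_stub.go): a file-list build of main.go alone would not pick up the tag-gated files (and would no longer compile), while a package build selects between them according to their //go:build constraints. A minimal sketch of the pattern, using hypothetical file and function names:

// vllm_enabled.go: compiled by the default build ("go build .")
//go:build !novllm

package main

func optionalBackendStatus() string { return "vllm backend compiled in" }

// vllm_disabled.go: selected by "go build -tags=novllm ." (as in the builder-sglang stage)
//go:build novllm

package main

func optionalBackendStatus() string { return "vllm backend stubbed out" }

// main_sketch.go: no constraint, so it links against whichever variant the tags select.
package main

import "fmt"

func main() { fmt.Println(optionalBackendStatus()) }

The Makefile's build target switches from ./main.go to "." for the same reason.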
+ +# Build the Go binary for SGLang (without vLLM) +FROM builder AS builder-sglang +RUN --mount=type=cache,target=/go/pkg/mod \ + --mount=type=cache,target=/root/.cache/go-build \ + CGO_ENABLED=1 GOOS=linux go build -tags=novllm -ldflags="-s -w" -o model-runner . # --- Get llama.cpp binary --- FROM docker/docker-model-backend-llamacpp:${LLAMA_SERVER_VERSION}-${LLAMA_SERVER_VARIANT} AS llama-server @@ -97,17 +103,50 @@ USER modelrunner # Install uv and vLLM as modelrunner user RUN curl -LsSf https://astral.sh/uv/install.sh | sh \ - && ~/.local/bin/uv venv --python /usr/bin/python3 /opt/vllm-env \ - && if [ "$TARGETARCH" = "amd64" ]; then \ - WHEEL_ARCH="manylinux_2_31_x86_64"; \ - WHEEL_URL="https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}%2B${VLLM_CUDA_VERSION}-${VLLM_PYTHON_TAG}-${WHEEL_ARCH}.whl"; \ - ~/.local/bin/uv pip install --python /opt/vllm-env/bin/python "$WHEEL_URL"; \ + && ~/.local/bin/uv venv --python /usr/bin/python3 /opt/vllm-env \ + && if [ "$TARGETARCH" = "amd64" ]; then \ + WHEEL_ARCH="manylinux_2_31_x86_64"; \ + WHEEL_URL="https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}%2B${VLLM_CUDA_VERSION}-${VLLM_PYTHON_TAG}-${WHEEL_ARCH}.whl"; \ + ~/.local/bin/uv pip install --python /opt/vllm-env/bin/python "$WHEEL_URL"; \ else \ - ~/.local/bin/uv pip install --python /opt/vllm-env/bin/python "vllm==${VLLM_VERSION}"; \ + ~/.local/bin/uv pip install --python /opt/vllm-env/bin/python "vllm==${VLLM_VERSION}"; \ fi RUN /opt/vllm-env/bin/python -c "import vllm; print(vllm.__version__)" > /opt/vllm-env/version +# --- SGLang variant --- +FROM llamacpp AS sglang + +ARG SGLANG_VERSION=0.5.6 + +USER root + +# Install CUDA toolkit 13 for nvcc (needed for flashinfer JIT compilation) +RUN apt update && apt install -y \ + python3 python3-venv python3-dev \ + curl ca-certificates build-essential \ + libnuma1 libnuma-dev numactl ninja-build \ + wget gnupg \ + && wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb \ + && dpkg -i cuda-keyring_1.1-1_all.deb \ + && apt update && apt install -y cuda-toolkit-13-0 \ + && rm cuda-keyring_1.1-1_all.deb \ + && rm -rf /var/lib/apt/lists/* + +RUN mkdir -p /opt/sglang-env && chown -R modelrunner:modelrunner /opt/sglang-env + +USER modelrunner + +# Set CUDA paths for nvcc (needed during flashinfer compilation) +ENV PATH=/usr/local/cuda-13.0/bin:$PATH +ENV LD_LIBRARY_PATH=/usr/local/cuda-13.0/lib64:$LD_LIBRARY_PATH + +# Install uv and SGLang as modelrunner user +RUN curl -LsSf https://astral.sh/uv/install.sh | sh \ + && ~/.local/bin/uv venv --python /usr/bin/python3 /opt/sglang-env \ + && ~/.local/bin/uv pip install --python /opt/sglang-env/bin/python "sglang==${SGLANG_VERSION}" + +RUN /opt/sglang-env/bin/python -c "import sglang; print(sglang.__version__)" > /opt/sglang-env/version FROM llamacpp AS final-llamacpp # Copy the built binary from builder COPY --from=builder /app/model-runner /app/model-runner @@ -115,3 +154,7 @@ COPY --from=builder /app/model-runner /app/model-runner FROM vllm AS final-vllm # Copy the built binary from builder COPY --from=builder /app/model-runner /app/model-runner + +FROM sglang AS final-sglang +# Copy the built binary from builder-sglang (without vLLM) +COPY --from=builder-sglang /app/model-runner /app/model-runner diff --git a/Makefile b/Makefile index 9ecb2e2d5..f9fead629 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,8 @@ BASE_IMAGE := ubuntu:24.04 VLLM_BASE_IMAGE := 
nvidia/cuda:13.0.2-runtime-ubuntu24.04 DOCKER_IMAGE := docker/model-runner:latest DOCKER_IMAGE_VLLM := docker/model-runner:latest-vllm-cuda +DOCKER_IMAGE_SGLANG := docker/model-runner:latest-sglang +DOCKER_IMAGE_SGLANG_CUDA := docker/model-runner:latest-sglang-cuda DOCKER_TARGET ?= final-llamacpp PORT := 8080 MODELS_PATH := $(shell pwd)/models-store @@ -31,13 +33,13 @@ LICENSE ?= BUILD_DMR ?= 1 # Main targets -.PHONY: build run clean test integration-tests test-docker-ce-installation docker-build docker-build-multiplatform docker-run docker-build-vllm docker-run-vllm docker-run-impl help validate lint model-distribution-tool +.PHONY: build run clean test integration-tests test-docker-ce-installation docker-build docker-build-multiplatform docker-run docker-build-vllm docker-run-vllm docker-build-sglang docker-run-sglang docker-build-sglang-cuda docker-run-sglang-cuda docker-run-impl help validate lint model-distribution-tool # Default target .DEFAULT_GOAL := build # Build the Go application build: - CGO_ENABLED=1 go build -ldflags="-s -w" -o $(APP_NAME) ./main.go + CGO_ENABLED=1 go build -ldflags="-s -w" -o $(APP_NAME) . # Build model-distribution-tool model-distribution-tool: @@ -116,6 +118,30 @@ docker-build-vllm: docker-run-vllm: docker-build-vllm @$(MAKE) -s docker-run-impl DOCKER_IMAGE=$(DOCKER_IMAGE_VLLM) +# Build SGLang Docker image (CPU variant) +docker-build-sglang: + @$(MAKE) docker-build \ + DOCKER_TARGET=final-sglang \ + DOCKER_IMAGE=$(DOCKER_IMAGE_SGLANG) \ + LLAMA_SERVER_VARIANT=cpu \ + BASE_IMAGE=$(BASE_IMAGE) + +# Run SGLang Docker container (CPU variant) with TCP port access and mounted model storage +docker-run-sglang: docker-build-sglang + @$(MAKE) -s docker-run-impl DOCKER_IMAGE=$(DOCKER_IMAGE_SGLANG) + +# Build SGLang Docker image (CUDA variant) +docker-build-sglang-cuda: + @$(MAKE) docker-build \ + DOCKER_TARGET=final-sglang \ + DOCKER_IMAGE=$(DOCKER_IMAGE_SGLANG_CUDA) \ + LLAMA_SERVER_VARIANT=cuda \ + BASE_IMAGE=$(VLLM_BASE_IMAGE) + +# Run SGLang Docker container (CUDA variant) with TCP port access and mounted model storage +docker-run-sglang-cuda: docker-build-sglang-cuda + @$(MAKE) -s docker-run-impl DOCKER_IMAGE=$(DOCKER_IMAGE_SGLANG_CUDA) + # Common implementation for running Docker container docker-run-impl: @echo "" @@ -178,6 +204,10 @@ help: @echo " docker-run - Run in Docker container with TCP port access and mounted model storage" @echo " docker-build-vllm - Build vLLM Docker image" @echo " docker-run-vllm - Run vLLM Docker container" + @echo " docker-build-sglang - Build SGLang Docker image (CPU)" + @echo " docker-run-sglang - Run SGLang Docker container (CPU)" + @echo " docker-build-sglang-cuda - Build SGLang Docker image (CUDA)" + @echo " docker-run-sglang-cuda - Run SGLang Docker container (CUDA)" @echo " help - Show this help message" @echo "" @echo "Model distribution tool targets:" diff --git a/backends_vllm.go b/backends_vllm.go new file mode 100644 index 000000000..66c6de1c5 --- /dev/null +++ b/backends_vllm.go @@ -0,0 +1,23 @@ +//go:build !novllm + +package main + +import ( + "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/backends/vllm" + "github.com/docker/model-runner/pkg/inference/models" + "github.com/sirupsen/logrus" +) + +func initVLLMBackend(log *logrus.Logger, modelManager *models.Manager) (inference.Backend, error) { + return vllm.New( + log, + modelManager, + log.WithFields(logrus.Fields{"component": vllm.Name}), + nil, + ) +} + +func registerVLLMBackend(backends 
map[string]inference.Backend, backend inference.Backend) { + backends[vllm.Name] = backend +} diff --git a/backends_vllm_stub.go b/backends_vllm_stub.go new file mode 100644 index 000000000..dceb094a7 --- /dev/null +++ b/backends_vllm_stub.go @@ -0,0 +1,17 @@ +//go:build novllm + +package main + +import ( + "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/models" + "github.com/sirupsen/logrus" +) + +func initVLLMBackend(log *logrus.Logger, modelManager *models.Manager) (inference.Backend, error) { + return nil, nil +} + +func registerVLLMBackend(backends map[string]inference.Backend, backend inference.Backend) { + // No-op when vLLM is disabled +} diff --git a/main.go b/main.go index 90aacbb0b..7e188b719 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/backends/llamacpp" "github.com/docker/model-runner/pkg/inference/backends/mlx" + "github.com/docker/model-runner/pkg/inference/backends/sglang" "github.com/docker/model-runner/pkg/inference/backends/vllm" "github.com/docker/model-runner/pkg/inference/config" "github.com/docker/model-runner/pkg/inference/models" @@ -106,12 +107,7 @@ func main() { log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err) } - vllmBackend, err := vllm.New( - log, - modelManager, - log.WithFields(logrus.Fields{"component": vllm.Name}), - nil, - ) + vllmBackend, err := initVLLMBackend(log, modelManager) if err != nil { log.Fatalf("unable to initialize %s backend: %v", vllm.Name, err) } @@ -126,13 +122,26 @@ func main() { log.Fatalf("unable to initialize %s backend: %v", mlx.Name, err) } + sglangBackend, err := sglang.New( + log, + modelManager, + log.WithFields(logrus.Fields{"component": sglang.Name}), + nil, + ) + if err != nil { + log.Fatalf("unable to initialize %s backend: %v", sglang.Name, err) + } + + backends := map[string]inference.Backend{ + llamacpp.Name: llamaCppBackend, + mlx.Name: mlxBackend, + sglang.Name: sglangBackend, + } + registerVLLMBackend(backends, vllmBackend) + scheduler := scheduling.NewScheduler( log, - map[string]inference.Backend{ - llamacpp.Name: llamaCppBackend, - vllm.Name: vllmBackend, - mlx.Name: mlxBackend, - }, + backends, llamaCppBackend, modelManager, http.DefaultClient, diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go index 36b7580a1..595ad7d12 100644 --- a/pkg/inference/backend.go +++ b/pkg/inference/backend.go @@ -132,6 +132,10 @@ type Backend interface { // external model management system and false if the backend uses the shared // model manager. UsesExternalModelManagement() bool + // UsesTCP returns true if the backend uses TCP for communication instead + // of Unix sockets. When true, the scheduler will create a TCP transport + // and pass a "host:port" address to Run instead of a Unix socket path. + UsesTCP() bool // Install ensures that the backend is installed. It should return a nil // error if installation succeeds or if the backend is already installed. // The provided HTTP client should be used for any HTTP operations. diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go index 8c8ba6f76..802cb7084 100644 --- a/pkg/inference/backends/llamacpp/llamacpp.go +++ b/pkg/inference/backends/llamacpp/llamacpp.go @@ -89,6 +89,11 @@ func (l *llamaCpp) UsesExternalModelManagement() bool { return false } +// UsesTCP implements inference.Backend.UsesTCP. 
+func (l *llamaCpp) UsesTCP() bool { + return false +} + // Install implements inference.Backend.Install. func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error { l.updatedLlamaCpp = false diff --git a/pkg/inference/backends/mlx/mlx.go b/pkg/inference/backends/mlx/mlx.go index 420c7c5f3..27ae59419 100644 --- a/pkg/inference/backends/mlx/mlx.go +++ b/pkg/inference/backends/mlx/mlx.go @@ -65,6 +65,11 @@ func (m *mlx) UsesExternalModelManagement() bool { return false } +// UsesTCP implements inference.Backend.UsesTCP. +func (m *mlx) UsesTCP() bool { + return false +} + // Install implements inference.Backend.Install. func (m *mlx) Install(ctx context.Context, httpClient *http.Client) error { if !platform.SupportsMLX() { diff --git a/pkg/inference/backends/sglang/sglang.go b/pkg/inference/backends/sglang/sglang.go new file mode 100644 index 000000000..e1821e1d8 --- /dev/null +++ b/pkg/inference/backends/sglang/sglang.go @@ -0,0 +1,204 @@ +package sglang + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/docker/model-runner/pkg/diskusage" + "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/backends" + "github.com/docker/model-runner/pkg/inference/models" + "github.com/docker/model-runner/pkg/inference/platform" + "github.com/docker/model-runner/pkg/logging" +) + +const ( + // Name is the backend name. + Name = "sglang" + sglangDir = "/opt/sglang-env" +) + +var ( + ErrNotImplemented = errors.New("not implemented") + ErrSGLangNotFound = errors.New("sglang package not installed") + ErrPythonNotFound = errors.New("python3 not found in PATH") +) + +// sglang is the SGLang-based backend implementation. +type sglang struct { + // log is the associated logger. + log logging.Logger + // modelManager is the shared model manager. + modelManager *models.Manager + // serverLog is the logger to use for the SGLang server process. + serverLog logging.Logger + // config is the configuration for the SGLang backend. + config *Config + // status is the state in which the SGLang backend is in. + status string + // pythonPath is the path to the python3 binary. + pythonPath string +} + +// New creates a new SGLang-based backend. +func New(log logging.Logger, modelManager *models.Manager, serverLog logging.Logger, conf *Config) (inference.Backend, error) { + // If no config is provided, use the default configuration + if conf == nil { + conf = NewDefaultSGLangConfig() + } + + return &sglang{ + log: log, + modelManager: modelManager, + serverLog: serverLog, + config: conf, + status: "not installed", + }, nil +} + +// Name implements inference.Backend.Name. +func (s *sglang) Name() string { + return Name +} + +func (s *sglang) UsesExternalModelManagement() bool { + return false +} + +// UsesTCP implements inference.Backend.UsesTCP. +// SGLang only supports TCP, not Unix sockets. 
+func (s *sglang) UsesTCP() bool { + return true +} + +func (s *sglang) Install(_ context.Context, _ *http.Client) error { + if !platform.SupportsSGLang() { + return ErrNotImplemented + } + + venvPython := filepath.Join(sglangDir, "bin", "python3") + pythonPath := venvPython + + if _, err := os.Stat(venvPython); err != nil { + // Fall back to system Python + systemPython, err := exec.LookPath("python3") + if err != nil { + s.status = ErrPythonNotFound.Error() + return ErrPythonNotFound + } + pythonPath = systemPython + } + + s.pythonPath = pythonPath + + // Check if sglang is installed + if err := s.pythonCmd("-c", "import sglang").Run(); err != nil { + s.status = "sglang package not installed" + s.log.Warnf("sglang package not found. Install with: uv pip install sglang") + return ErrSGLangNotFound + } + + // Get version + output, err := s.pythonCmd("-c", "import sglang; print(sglang.__version__)").Output() + if err != nil { + s.log.Warnf("could not get sglang version: %v", err) + s.status = "running sglang version: unknown" + } else { + s.status = fmt.Sprintf("running sglang version: %s", strings.TrimSpace(string(output))) + } + + return nil +} + +func (s *sglang) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error { + if !platform.SupportsSGLang() { + s.log.Warn("sglang backend is not yet supported") + return ErrNotImplemented + } + + bundle, err := s.modelManager.GetBundle(model) + if err != nil { + return fmt.Errorf("failed to get model: %w", err) + } + + args, err := s.config.GetArgs(bundle, socket, mode, backendConfig) + if err != nil { + return fmt.Errorf("failed to get SGLang arguments: %w", err) + } + + // Add served model name and weight version + if model != "" { + // SGLang 0.5.6+ doesn't allow colons in served-model-name (reserved for LoRA syntax) + // Replace colons with underscores to sanitize the model name + sanitizedModel := strings.ReplaceAll(model, ":", "_") + args = append(args, "--served-model-name", sanitizedModel) + } + if modelRef != "" { + args = append(args, "--weight-version", modelRef) + } + + if s.pythonPath == "" { + return fmt.Errorf("sglang: python runtime not configured; did you forget to call Install?") + } + + sandboxPath := "" + if _, err := os.Stat(sglangDir); err == nil { + sandboxPath = sglangDir + } + + return backends.RunBackend(ctx, backends.RunnerConfig{ + BackendName: "SGLang", + Socket: socket, + BinaryPath: s.pythonPath, + SandboxPath: sandboxPath, + SandboxConfig: "", + Args: args, + Logger: s.log, + ServerLogWriter: s.serverLog.Writer(), + }) +} + +func (s *sglang) Status() string { + return s.status +} + +func (s *sglang) GetDiskUsage() (int64, error) { + // Check if Docker installation exists + if _, err := os.Stat(sglangDir); err == nil { + size, err := diskusage.Size(sglangDir) + if err != nil { + return 0, fmt.Errorf("error while getting sglang dir size: %w", err) + } + return size, nil + } + // Python installation doesn't have a dedicated installation directory + // It's installed via pip in the system Python environment + return 0, nil +} + +func (s *sglang) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (inference.RequiredMemory, error) { + if !platform.SupportsSGLang() { + return inference.RequiredMemory{}, ErrNotImplemented + } + + return inference.RequiredMemory{ + RAM: 1, + VRAM: 1, + }, nil +} + +// pythonCmd creates an exec.Cmd that runs python with the given arguments. 
+// It uses the configured pythonPath if available, otherwise falls back to "python3". +func (s *sglang) pythonCmd(args ...string) *exec.Cmd { + pythonBinary := "python3" + if s.pythonPath != "" { + pythonBinary = s.pythonPath + } + return exec.Command(pythonBinary, args...) +} diff --git a/pkg/inference/backends/sglang/sglang_config.go b/pkg/inference/backends/sglang/sglang_config.go new file mode 100644 index 000000000..8b500a906 --- /dev/null +++ b/pkg/inference/backends/sglang/sglang_config.go @@ -0,0 +1,79 @@ +package sglang + +import ( + "fmt" + "net" + "path/filepath" + "strconv" + + "github.com/docker/model-runner/pkg/distribution/types" + "github.com/docker/model-runner/pkg/inference" +) + +// Config is the configuration for the SGLang backend. +type Config struct { + // Args are the base arguments that are always included. + Args []string +} + +// NewDefaultSGLangConfig creates a new SGLangConfig with default values. +func NewDefaultSGLangConfig() *Config { + return &Config{} +} + +// GetArgs implements BackendConfig.GetArgs. +func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) { + // Start with the arguments from SGLangConfig + args := append([]string{}, c.Args...) + + // SGLang uses Python module: python -m sglang.launch_server + args = append(args, "-m", "sglang.launch_server") + + // Add model path + safetensorsPath := bundle.SafetensorsPath() + if safetensorsPath == "" { + return nil, fmt.Errorf("safetensors path required by SGLang backend") + } + modelPath := filepath.Dir(safetensorsPath) + args = append(args, "--model-path", modelPath) + + host, port, err := net.SplitHostPort(socket) + if err != nil { + return nil, fmt.Errorf("failed to parse host:port from %q: %w", socket, err) + } + args = append(args, "--host", host, "--port", port) + + // Add mode-specific arguments + switch mode { + case inference.BackendModeCompletion: + // Default mode for SGLang + case inference.BackendModeEmbedding: + args = append(args, "--is-embedding") + case inference.BackendModeReranking: + default: + return nil, fmt.Errorf("unsupported backend mode %q", mode) + } + + // Add context-length if specified in model config or backend config + if contextLen := GetContextLength(bundle.RuntimeConfig(), config); contextLen != nil { + args = append(args, "--context-length", strconv.Itoa(int(*contextLen))) + } + + return args, nil +} + +// GetContextLength returns the context length (context size) from model config or backend config. +// Model config takes precedence over backend config. +// Returns nil if neither is specified (SGLang will auto-derive from model). 
+func GetContextLength(modelCfg types.Config, backendCfg *inference.BackendConfiguration) *int32 { + // Model config takes precedence + if modelCfg.ContextSize != nil && *modelCfg.ContextSize > 0 { + return modelCfg.ContextSize + } + // Fallback to backend config + if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 { + return backendCfg.ContextSize + } + // Return nil to let SGLang auto-derive from model config + return nil +} diff --git a/pkg/inference/backends/sglang/sglang_config_test.go b/pkg/inference/backends/sglang/sglang_config_test.go new file mode 100644 index 000000000..28886527a --- /dev/null +++ b/pkg/inference/backends/sglang/sglang_config_test.go @@ -0,0 +1,254 @@ +package sglang + +import ( + "testing" + + "github.com/docker/model-runner/pkg/distribution/types" + "github.com/docker/model-runner/pkg/inference" +) + +type mockModelBundle struct { + safetensorsPath string + runtimeConfig types.Config +} + +func (m *mockModelBundle) GGUFPath() string { + return "" +} + +func (m *mockModelBundle) SafetensorsPath() string { + return m.safetensorsPath +} + +func (m *mockModelBundle) ChatTemplatePath() string { + return "" +} + +func (m *mockModelBundle) MMPROJPath() string { + return "" +} + +func (m *mockModelBundle) RuntimeConfig() types.Config { + return m.runtimeConfig +} + +func (m *mockModelBundle) RootDir() string { + return "/path/to/bundle" +} + +func TestGetArgs(t *testing.T) { + tests := []struct { + name string + config *inference.BackendConfiguration + bundle *mockModelBundle + mode inference.BackendMode + expected []string + expectError bool + }{ + { + name: "empty safetensors path should error", + bundle: &mockModelBundle{ + safetensorsPath: "", + }, + mode: inference.BackendModeCompletion, + config: nil, + expected: nil, + expectError: true, + }, + { + name: "basic args without context size", + bundle: &mockModelBundle{ + safetensorsPath: "/path/to/model/model.safetensors", + }, + mode: inference.BackendModeCompletion, + config: nil, + expected: []string{ + "-m", + "sglang.launch_server", + "--model-path", + "/path/to/model", + "--host", + "127.0.0.1", + "--port", + "30000", + }, + }, + { + name: "with backend context size", + bundle: &mockModelBundle{ + safetensorsPath: "/path/to/model/model.safetensors", + }, + mode: inference.BackendModeCompletion, + config: &inference.BackendConfiguration{ + ContextSize: int32ptr(8192), + }, + expected: []string{ + "-m", + "sglang.launch_server", + "--model-path", + "/path/to/model", + "--host", + "127.0.0.1", + "--port", + "30000", + "--context-length", + "8192", + }, + }, + { + name: "with model context size (takes precedence)", + bundle: &mockModelBundle{ + safetensorsPath: "/path/to/model/model.safetensors", + runtimeConfig: types.Config{ + ContextSize: int32ptr(16384), + }, + }, + mode: inference.BackendModeCompletion, + config: &inference.BackendConfiguration{ + ContextSize: int32ptr(8192), + }, + expected: []string{ + "-m", + "sglang.launch_server", + "--model-path", + "/path/to/model", + "--host", + "127.0.0.1", + "--port", + "30000", + "--context-length", + "16384", + }, + }, + { + name: "embedding mode", + bundle: &mockModelBundle{ + safetensorsPath: "/path/to/model/model.safetensors", + }, + mode: inference.BackendModeEmbedding, + config: nil, + expected: []string{ + "-m", + "sglang.launch_server", + "--model-path", + "/path/to/model", + "--host", + "127.0.0.1", + "--port", + "30000", + "--is-embedding", + }, + }, + { + name: "reranking mode", + bundle: &mockModelBundle{ + 
safetensorsPath: "/path/to/model/model.safetensors", + }, + mode: inference.BackendModeReranking, + config: nil, + expected: []string{ + "-m", + "sglang.launch_server", + "--model-path", + "/path/to/model", + "--host", + "127.0.0.1", + "--port", + "30000", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := NewDefaultSGLangConfig() + args, err := config.GetArgs(tt.bundle, "127.0.0.1:30000", tt.mode, tt.config) + + if tt.expectError { + if err == nil { + t.Fatalf("expected error but got none") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(args) != len(tt.expected) { + t.Fatalf("expected %d args, got %d\nexpected: %v\ngot: %v", len(tt.expected), len(args), tt.expected, args) + } + + for i, arg := range args { + if arg != tt.expected[i] { + t.Errorf("arg[%d]: expected %q, got %q", i, tt.expected[i], arg) + } + } + }) + } +} + +func TestGetContextLength(t *testing.T) { + tests := []struct { + name string + modelCfg types.Config + backendCfg *inference.BackendConfiguration + expectedValue *int32 + }{ + { + name: "no config", + modelCfg: types.Config{}, + backendCfg: nil, + expectedValue: nil, + }, + { + name: "backend config only", + modelCfg: types.Config{}, + backendCfg: &inference.BackendConfiguration{ + ContextSize: int32ptr(4096), + }, + expectedValue: int32ptr(4096), + }, + { + name: "model config only", + modelCfg: types.Config{ + ContextSize: int32ptr(8192), + }, + backendCfg: nil, + expectedValue: int32ptr(8192), + }, + { + name: "model config takes precedence", + modelCfg: types.Config{ + ContextSize: int32ptr(16384), + }, + backendCfg: &inference.BackendConfiguration{ + ContextSize: int32ptr(4096), + }, + expectedValue: int32ptr(16384), + }, + { + name: "zero context size in backend config returns nil", + modelCfg: types.Config{}, + backendCfg: &inference.BackendConfiguration{ + ContextSize: int32ptr(0), + }, + expectedValue: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetContextLength(tt.modelCfg, tt.backendCfg) + if (result == nil) != (tt.expectedValue == nil) { + t.Errorf("expected nil=%v, got nil=%v", tt.expectedValue == nil, result == nil) + } else if result != nil && *result != *tt.expectedValue { + t.Errorf("expected %d, got %d", *tt.expectedValue, *result) + } + }) + } +} + +func int32ptr(v int32) *int32 { + return &v +} diff --git a/pkg/inference/backends/vllm/vllm.go b/pkg/inference/backends/vllm/vllm.go index aa50bb280..9792b2a1a 100644 --- a/pkg/inference/backends/vllm/vllm.go +++ b/pkg/inference/backends/vllm/vllm.go @@ -67,6 +67,11 @@ func (v *vLLM) UsesExternalModelManagement() bool { return false } +// UsesTCP implements inference.Backend.UsesTCP. +func (v *vLLM) UsesTCP() bool { + return false +} + func (v *vLLM) Install(_ context.Context, _ *http.Client) error { if !platform.SupportsVLLM() { return errors.New("not implemented") diff --git a/pkg/inference/platform/platform.go b/pkg/inference/platform/platform.go index b6f3d7f5c..49bffb75e 100644 --- a/pkg/inference/platform/platform.go +++ b/pkg/inference/platform/platform.go @@ -12,3 +12,8 @@ func SupportsVLLM() bool { func SupportsMLX() bool { return runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" } + +// SupportsSGLang returns true if SGLang is supported on the current platform. 
+func SupportsSGLang() bool { + return runtime.GOOS == "linux" +} diff --git a/pkg/inference/scheduling/loader_test.go b/pkg/inference/scheduling/loader_test.go index 7ac5841b2..5a7ab7e04 100644 --- a/pkg/inference/scheduling/loader_test.go +++ b/pkg/inference/scheduling/loader_test.go @@ -47,6 +47,10 @@ func (m *mockBackend) UsesExternalModelManagement() bool { return m.usesExternalModelMgmt } +func (m *mockBackend) UsesTCP() bool { + return false +} + // fastFailBackend is a backend that immediately fails on Run to short-circuit wait() type fastFailBackend struct{ mockBackend } diff --git a/pkg/inference/scheduling/runner.go b/pkg/inference/scheduling/runner.go index ff3ea9b1e..73ccd7625 100644 --- a/pkg/inference/scheduling/runner.go +++ b/pkg/inference/scheduling/runner.go @@ -11,6 +11,7 @@ import ( "net/http" "net/http/httputil" "net/url" + "strconv" "time" "github.com/docker/model-runner/pkg/inference" @@ -26,6 +27,10 @@ const ( // readinessRetryInterval is the interval at which a runner will retry // readiness checks for a backend. readinessRetryInterval = 500 * time.Millisecond + // tcpBackendBasePort is the base port number for TCP-based backends. + // Each slot gets a unique port: basePort + slot (e.g., 30000, 30001, 30002). + // Port 30000+ is used to avoid conflicts with common services. + tcpBackendBasePort = 30000 ) // errBackendNotReadyInTime indicates that an inference backend took too @@ -83,14 +88,22 @@ func run( openAIRecorder *metrics.OpenAIRecorder, ) (*runner, error) { // Create a dialer / transport that target backend on the specified slot. - socket, err := RunnerSocketPath(slot) - if err != nil { - return nil, fmt.Errorf("unable to determine runner socket path: %w", err) + network := "tcp" + socket := net.JoinHostPort("127.0.0.1", strconv.Itoa(tcpBackendBasePort+slot)) + + if !backend.UsesTCP() { + var err error + socket, err = RunnerSocketPath(slot) + if err != nil { + return nil, fmt.Errorf("unable to determine runner socket path: %w", err) + } + network = "unix" } + dialer := &net.Dialer{} transport := &http.Transport{ DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { - return dialer.DialContext(ctx, "unix", socket) + return dialer.DialContext(ctx, network, socket) }, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index fdafb1ee9..7f4247a10 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -11,6 +11,7 @@ import ( "github.com/docker/model-runner/pkg/distribution/types" "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/backends/llamacpp" + "github.com/docker/model-runner/pkg/inference/backends/sglang" "github.com/docker/model-runner/pkg/inference/backends/vllm" "github.com/docker/model-runner/pkg/inference/models" "github.com/docker/model-runner/pkg/internal/utils" @@ -100,10 +101,15 @@ func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.B } if config.Format == types.FormatSafetensors { + // Prefer vLLM for safetensors models if vllmBackend, ok := s.backends[vllm.Name]; ok && vllmBackend != nil { return vllmBackend } - s.log.Warnf("Model %s is in safetensors format but vLLM backend is not available. 
"+ + // Fall back to SGLang if vLLM is not available + if sglangBackend, ok := s.backends[sglang.Name]; ok && sglangBackend != nil { + return sglangBackend + } + s.log.Warnf("Model %s is in safetensors format but vLLM and SGLang backends are not available. "+ "Backend %s may not support this format and could fail at runtime.", utils.SanitizeForLog(modelRef), backend.Name()) }