Skip to content

Commit b2ffe15

Browse files
committed
vllm: Linux support only
Signed-off-by: Dorin Geman <dorin.geman@docker.com>
1 parent c0f57de commit b2ffe15

File tree

4 files changed

+60
-19
lines changed

pkg/distribution/distribution/client.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@ import (
66
"fmt"
77
"io"
88
"net/http"
9-
"runtime"
109
"slices"
1110

12-
"github.com/docker/model-runner/pkg/distribution/internal/utils"
1311
"github.com/sirupsen/logrus"
1412

1513
"github.com/docker/model-runner/pkg/distribution/internal/progress"
1614
"github.com/docker/model-runner/pkg/distribution/internal/store"
15+
"github.com/docker/model-runner/pkg/distribution/internal/utils"
1716
"github.com/docker/model-runner/pkg/distribution/registry"
1817
"github.com/docker/model-runner/pkg/distribution/tarball"
1918
"github.com/docker/model-runner/pkg/distribution/types"
19+
"github.com/docker/model-runner/pkg/inference/platform"
2020
)
2121

2222
// Client provides model distribution functionality
@@ -403,7 +403,7 @@ func (c *Client) GetBundle(ref string) (types.ModelBundle, error) {
403403
}
404404

405405
func GetSupportedFormats() []types.Format {
406-
if runtime.GOOS == "linux" {
406+
if platform.SupportsVLLM() {
407407
return []types.Format{types.FormatGGUF, types.FormatSafetensors}
408408
}
409409
return []types.Format{types.FormatGGUF}

pkg/distribution/distribution/client_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"net/url"
1414
"os"
1515
"path/filepath"
16-
"runtime"
1716
"strings"
1817
"testing"
1918

@@ -27,6 +26,7 @@ import (
2726
"github.com/docker/model-runner/pkg/distribution/internal/progress"
2827
"github.com/docker/model-runner/pkg/distribution/internal/safetensors"
2928
mdregistry "github.com/docker/model-runner/pkg/distribution/registry"
29+
"github.com/docker/model-runner/pkg/inference/platform"
3030
)
3131

3232
var (
@@ -464,7 +464,7 @@ func TestClientPullModel(t *testing.T) {
464464

465465
// Try to pull the safetensors model
466466
err = testClient.PullModel(context.Background(), tag, nil)
467-
if runtime.GOOS == "linux" {
467+
if platform.SupportsVLLM() {
468468
// On Linux, safetensors should be supported
469469
if err != nil {
470470
t.Fatalf("Expected no error on Linux, got: %v", err)

pkg/inference/backends/vllm/vllm.go

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,19 @@ import (
1313
"runtime"
1414
"strings"
1515

16+
"github.com/docker/model-runner/pkg/diskusage"
1617
"github.com/docker/model-runner/pkg/inference"
1718
"github.com/docker/model-runner/pkg/inference/models"
19+
"github.com/docker/model-runner/pkg/inference/platform"
1820
"github.com/docker/model-runner/pkg/logging"
1921
"github.com/docker/model-runner/pkg/sandbox"
2022
"github.com/docker/model-runner/pkg/tailbuffer"
2123
)
2224

2325
const (
2426
// Name is the backend name.
25-
Name = "vllm"
27+
Name = "vllm"
28+
vllmDir = "/opt/vllm-env/bin/"
2629
)
2730

2831
// vLLM is the vLLM-based backend implementation.
@@ -64,28 +67,48 @@ func (v *vLLM) UsesExternalModelManagement() bool {
6467
return false
6568
}
6669

67-
func (v *vLLM) Install(ctx context.Context, httpClient *http.Client) error {
70+
func (v *vLLM) Install(_ context.Context, _ *http.Client) error {
71+
if !platform.SupportsVLLM() {
72+
return errors.New("not implemented")
73+
}
74+
75+
vllmBinaryPath := v.binaryPath()
76+
if _, err := os.Stat(vllmBinaryPath); err != nil {
77+
if errors.Is(err, fs.ErrNotExist) {
78+
return fmt.Errorf("vLLM binary not found at %s", vllmBinaryPath)
79+
}
80+
return fmt.Errorf("failed to check vLLM binary: %w", err)
81+
}
82+
83+
// TODO: Find a way to get vllm's version. Running `vllm --version` is too slow.
84+
v.status = "running"
85+
6886
return nil
6987
}
7088

71-
func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
89+
func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, _ inference.BackendMode, _ *inference.BackendConfiguration) error {
90+
if !platform.SupportsVLLM() {
91+
v.log.Warn("vLLM backend is not yet supported")
92+
return errors.New("not implemented")
93+
}
94+
7295
bundle, err := v.modelManager.GetBundle(model)
7396
if err != nil {
7497
return fmt.Errorf("failed to get model: %w", err)
7598
}
7699

77100
if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
78-
v.log.Warnf("failed to remove socket file %s: %w\n", socket, err)
101+
v.log.Warnf("failed to remove socket file %s: %v\n", socket, err)
79102
v.log.Warnln("vLLM may not be able to start")
80103
}
81104

82-
binPath := "/opt/vllm-env/bin"
83105
args := []string{
84106
"serve",
85107
filepath.Dir(bundle.SafetensorsPath()),
86108
"--uds", socket,
87109
"--served-model-name", modelRef,
88110
}
111+
// TODO: Add inference.BackendConfiguration.
89112

90113
v.log.Infof("vLLM args: %v", args)
91114
tailBuf := tailbuffer.NewTailBuffer(1024)
@@ -104,8 +127,8 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m
104127
command.Stdout = serverLogStream
105128
command.Stderr = out
106129
},
107-
binPath,
108-
filepath.Join(binPath, "vllm"),
130+
vllmDir,
131+
v.binaryPath(),
109132
args...,
110133
)
111134
if err != nil {
@@ -120,7 +143,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m
120143

121144
errOutput := new(strings.Builder)
122145
if _, err := io.Copy(errOutput, tailBuf); err != nil {
123-
v.log.Warnf("failed to read server output tail: %w", err)
146+
v.log.Warnf("failed to read server output tail: %v", err)
124147
}
125148

126149
if len(errOutput.String()) != 0 {
@@ -132,7 +155,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m
132155
vllmErrors <- vllmErr
133156
close(vllmErrors)
134157
if err := os.Remove(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
135-
v.log.Warnf("failed to remove socket file %s on exit: %w\n", socket, err)
158+
v.log.Warnf("failed to remove socket file %s on exit: %v\n", socket, err)
136159
}
137160
}()
138161
defer func() {
@@ -153,18 +176,28 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, m
153176
}
154177

155178
func (v *vLLM) Status() string {
156-
return "enabled"
179+
return v.status
157180
}
158181

159182
func (v *vLLM) GetDiskUsage() (int64, error) {
160-
// TODO implement me
161-
return 0, nil
183+
size, err := diskusage.Size(vllmDir)
184+
if err != nil {
185+
return 0, fmt.Errorf("error while getting store size: %v", err)
186+
}
187+
return size, nil
162188
}
163189

164-
func (v *vLLM) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (inference.RequiredMemory, error) {
165-
// TODO implement me
190+
func (v *vLLM) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (inference.RequiredMemory, error) {
191+
if !platform.SupportsVLLM() {
192+
return inference.RequiredMemory{}, errors.New("not implemented")
193+
}
194+
166195
return inference.RequiredMemory{
167196
RAM: 1,
168197
VRAM: 1,
169198
}, nil
170199
}
200+
201+
func (v *vLLM) binaryPath() string {
202+
return filepath.Join(vllmDir, "vllm")
203+
}

pkg/inference/platform/platform.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package platform
2+
3+
import "runtime"
4+
5+
// SupportsVLLM returns true if vLLM is supported on the current platform.
6+
func SupportsVLLM() bool {
7+
return runtime.GOOS == "linux"
8+
}

0 commit comments

Comments (0)