
Commit bf8d9e7

feat(vllm): enhance argument handling for vLLM backend configuration
1 parent 0bb2f02 commit bf8d9e7

File tree

3 files changed: +271 -8 lines changed

pkg/inference/backends/vllm/vllm.go

Lines changed: 8 additions & 7 deletions
@@ -93,7 +93,7 @@ func (v *vLLM) Install(_ context.Context, _ *http.Client) error {
 	return nil
 }
 
-func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, _ inference.BackendMode, _ *inference.BackendConfiguration) error {
+func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error {
 	if !platform.SupportsVLLM() {
 		v.log.Warn("vLLM backend is not yet supported")
 		return errors.New("not implemented")
@@ -109,13 +109,14 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, _
 		v.log.Warnln("vLLM may not be able to start")
 	}
 
-	args := []string{
-		"serve",
-		filepath.Dir(bundle.SafetensorsPath()),
-		"--uds", socket,
-		"--served-model-name", modelRef,
+	// Get arguments from config
+	args, err := v.config.GetArgs(bundle, socket, mode, backendConfig)
+	if err != nil {
+		return fmt.Errorf("failed to get vLLM arguments: %w", err)
 	}
-	// TODO: Add inference.BackendConfiguration.
+
+	// Add served model name
+	args = append(args, "--served-model-name", modelRef)
 
 	v.log.Infof("vLLM args: %v", args)
 	tailBuf := tailbuffer.NewTailBuffer(1024)
Lines changed: 72 additions & 1 deletion
@@ -1,10 +1,81 @@
 package vllm
 
+import (
+	"fmt"
+	"strconv"
+
+	"github.com/docker/model-runner/pkg/distribution/types"
+	"github.com/docker/model-runner/pkg/inference"
+)
+
 // Config is the configuration for the vLLM backend.
 type Config struct {
+	// Args are the base arguments that are always included.
+	Args []string
 }
 
 // NewDefaultVLLMConfig creates a new VLLMConfig with default values.
 func NewDefaultVLLMConfig() *Config {
-	return &Config{}
+	return &Config{
+		Args: []string{},
+	}
+}
+
+// GetArgs implements BackendConfig.GetArgs.
+func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
+	// Start with the arguments from VLLMConfig
+	args := append([]string{}, c.Args...)
+
+	// Add the serve command and model path (use directory for safetensors)
+	modelPath := bundle.SafetensorsPath()
+	if modelPath != "" {
+		// vLLM expects the directory containing the safetensors files
+		args = append(args, "serve", modelPath)
+	} else {
+		return nil, fmt.Errorf("safetensors path required by vLLM backend")
+	}
+
+	// Add socket arguments
+	args = append(args, "--uds", socket)
+
+	// Add mode-specific arguments
+	switch mode {
+	case inference.BackendModeCompletion:
+		// Default mode for vLLM
+	case inference.BackendModeEmbedding:
+		// vLLM doesn't have a specific embedding flag like llama.cpp
+		// Embedding models are detected automatically
+	default:
+		return nil, fmt.Errorf("unsupported backend mode %q", mode)
+	}
+
+	// Add max-model-len if specified in model config or backend config
+	if maxLen := GetMaxModelLen(bundle.RuntimeConfig(), config); maxLen != nil {
+		args = append(args, "--max-model-len", strconv.FormatUint(*maxLen, 10))
+	}
+	// If nil, vLLM will automatically derive from the model config
+
+	// Add arguments from backend config
+	if config != nil {
+		args = append(args, config.RuntimeFlags...)
+	}
+
+	return args, nil
+}
+
+// GetMaxModelLen returns the max model length (context size) from model config or backend config.
+// Model config takes precedence over backend config.
+// Returns nil if neither is specified (vLLM will auto-derive from model).
+func GetMaxModelLen(modelCfg types.Config, backendCfg *inference.BackendConfiguration) *uint64 {
+	// Model config takes precedence
+	if modelCfg.ContextSize != nil {
+		return modelCfg.ContextSize
+	}
+	// else use backend config
+	if backendCfg != nil && backendCfg.ContextSize > 0 {
+		val := uint64(backendCfg.ContextSize)
+		return &val
+	}
+	// Return nil to let vLLM auto-derive from model config
+	return nil
 }
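
Taken together, GetArgs and GetMaxModelLen give the Run path a single place to resolve the vLLM command line. Below is a minimal usage sketch, not part of the commit: the bundleStub type, paths, and flag values are hypothetical, the vllm import path is inferred from the file layout above, and the stub assumes types.ModelBundle exposes the same methods as the test mock that follows.

package main

import (
	"fmt"

	"github.com/docker/model-runner/pkg/distribution/types"
	"github.com/docker/model-runner/pkg/inference"
	"github.com/docker/model-runner/pkg/inference/backends/vllm"
)

// bundleStub is a hypothetical stand-in for a real model bundle.
type bundleStub struct{}

func (bundleStub) GGUFPath() string            { return "" }
func (bundleStub) SafetensorsPath() string     { return "/models/example" }
func (bundleStub) ChatTemplatePath() string    { return "" }
func (bundleStub) MMPROJPath() string          { return "" }
func (bundleStub) RuntimeConfig() types.Config { return types.Config{} } // no model-level ContextSize
func (bundleStub) RootDir() string             { return "/models" }

func main() {
	cfg := vllm.NewDefaultVLLMConfig()
	backendCfg := &inference.BackendConfiguration{
		ContextSize:  8192,                                        // applies because the model config has none
		RuntimeFlags: []string{"--gpu-memory-utilization", "0.9"}, // forwarded verbatim by GetArgs
	}

	args, err := cfg.GetArgs(bundleStub{}, "/run/vllm.sock", inference.BackendModeCompletion, backendCfg)
	if err != nil {
		panic(err)
	}
	fmt.Println(args)
	// [serve /models/example --uds /run/vllm.sock --max-model-len 8192 --gpu-memory-utilization 0.9]
}

Because the model's own runtime config takes precedence, setting ContextSize in the bundle (as the last test case below does) would override the 8192 from the backend configuration.
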
Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
+package vllm
+
+import (
+	"testing"
+
+	"github.com/docker/model-runner/pkg/distribution/types"
+	"github.com/docker/model-runner/pkg/inference"
+)
+
+type mockModelBundle struct {
+	safetensorsPath string
+	runtimeConfig   types.Config
+}
+
+func (m *mockModelBundle) GGUFPath() string {
+	return ""
+}
+
+func (m *mockModelBundle) SafetensorsPath() string {
+	return m.safetensorsPath
+}
+
+func (m *mockModelBundle) ChatTemplatePath() string {
+	return ""
+}
+
+func (m *mockModelBundle) MMPROJPath() string {
+	return ""
+}
+
+func (m *mockModelBundle) RuntimeConfig() types.Config {
+	return m.runtimeConfig
+}
+
+func (m *mockModelBundle) RootDir() string {
+	return "/path/to/bundle"
+}
+
+func TestGetArgs(t *testing.T) {
+	tests := []struct {
+		name     string
+		config   *inference.BackendConfiguration
+		bundle   *mockModelBundle
+		expected []string
+	}{
+		{
+			name: "basic args without context size",
+			bundle: &mockModelBundle{
+				safetensorsPath: "/path/to/model",
+			},
+			config: nil,
+			expected: []string{
+				"serve",
+				"/path/to/model",
+				"--uds",
+				"/tmp/socket",
+			},
+		},
+		{
+			name: "with backend context size",
+			bundle: &mockModelBundle{
+				safetensorsPath: "/path/to/model",
+			},
+			config: &inference.BackendConfiguration{
+				ContextSize: 8192,
+			},
+			expected: []string{
+				"serve",
+				"/path/to/model",
+				"--uds",
+				"/tmp/socket",
+				"--max-model-len",
+				"8192",
+			},
+		},
+		{
+			name: "with runtime flags",
+			bundle: &mockModelBundle{
+				safetensorsPath: "/path/to/model",
+			},
+			config: &inference.BackendConfiguration{
+				RuntimeFlags: []string{"--gpu-memory-utilization", "0.9"},
+			},
+			expected: []string{
+				"serve",
+				"/path/to/model",
+				"--uds",
+				"/tmp/socket",
+				"--gpu-memory-utilization",
+				"0.9",
+			},
+		},
+		{
+			name: "with model context size (takes precedence)",
+			bundle: &mockModelBundle{
+				safetensorsPath: "/path/to/model",
+				runtimeConfig: types.Config{
+					ContextSize: ptrUint64(16384),
+				},
+			},
+			config: &inference.BackendConfiguration{
+				ContextSize: 8192,
+			},
+			expected: []string{
+				"serve",
+				"/path/to/model",
+				"--uds",
+				"/tmp/socket",
+				"--max-model-len",
+				"16384",
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			config := NewDefaultVLLMConfig()
+			args, err := config.GetArgs(tt.bundle, "/tmp/socket", inference.BackendModeCompletion, tt.config)
+			if err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+
+			if len(args) != len(tt.expected) {
+				t.Fatalf("expected %d args, got %d\nexpected: %v\ngot: %v", len(tt.expected), len(args), tt.expected, args)
+			}
+
+			for i, arg := range args {
+				if arg != tt.expected[i] {
+					t.Errorf("arg[%d]: expected %q, got %q", i, tt.expected[i], arg)
+				}
+			}
+		})
+	}
+}
+
+func TestGetMaxModelLen(t *testing.T) {
+	tests := []struct {
+		name          string
+		modelCfg      types.Config
+		backendCfg    *inference.BackendConfiguration
+		expectedValue *uint64
+	}{
+		{
+			name:          "no config",
+			modelCfg:      types.Config{},
+			backendCfg:    nil,
+			expectedValue: nil,
+		},
+		{
+			name:     "backend config only",
+			modelCfg: types.Config{},
+			backendCfg: &inference.BackendConfiguration{
+				ContextSize: 4096,
+			},
+			expectedValue: ptrUint64(4096),
+		},
+		{
+			name: "model config only",
+			modelCfg: types.Config{
+				ContextSize: ptrUint64(8192),
+			},
+			backendCfg:    nil,
+			expectedValue: ptrUint64(8192),
+		},
+		{
+			name: "model config takes precedence",
+			modelCfg: types.Config{
+				ContextSize: ptrUint64(16384),
+			},
+			backendCfg: &inference.BackendConfiguration{
+				ContextSize: 4096,
+			},
+			expectedValue: ptrUint64(16384),
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := GetMaxModelLen(tt.modelCfg, tt.backendCfg)
+			if (result == nil) != (tt.expectedValue == nil) {
+				t.Errorf("expected nil=%v, got nil=%v", tt.expectedValue == nil, result == nil)
+			} else if result != nil && *result != *tt.expectedValue {
+				t.Errorf("expected %d, got %d", *tt.expectedValue, *result)
+			}
+		})
+	}
+}
+
+func ptrUint64(v uint64) *uint64 {
+	return &v
+}
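
These table-driven tests live in the same package as the backend, so they should run with the standard Go tooling from the repository root, e.g. go test ./pkg/inference/backends/vllm/...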
