Commit 90e631f (1 parent: ff48957)

add GPU memory utilization configuration for model executor

6 files changed: 313 additions, 2 deletions

cmd/cli/commands/configure.go (1 addition, 1 deletion)

@@ -11,7 +11,7 @@ func newConfigureCmd() *cobra.Command {
     var flags ConfigureFlags

     c := &cobra.Command{
-        Use:    "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--mode=<mode>] [--think] MODEL",
+        Use:    "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL",
         Short:  "Configure runtime options for a model",
         Hidden: true,
         Args: func(cmd *cobra.Command, args []string) error {
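For orientation, a usage sketch of the extended subcommand; the `<cli>` binary name and model name below are placeholders, since the commit only defines the `configure` subcommand and its flags:

```sh
# Cap vLLM at 70% of GPU memory for this model (placeholder binary and model names).
<cli> configure --gpu-memory-utilization=0.7 my-org/my-model

# Omitting the flag leaves the setting unset, so vLLM falls back to its own default of 0.9.
<cli> configure my-org/my-model
```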

cmd/cli/commands/configure_flags.go (46 additions, 1 deletion)

@@ -84,6 +84,37 @@ func (v *BoolPtrValue) IsBoolFlag() bool {
     return true
 }

+// Float64PtrValue implements pflag.Value interface for *float64 pointers
+// This allows flags to have a nil default value instead of 0.0
+type Float64PtrValue struct {
+    ptr **float64
+}
+
+// NewFloat64PtrValue creates a new Float64PtrValue for the given pointer
+func NewFloat64PtrValue(p **float64) *Float64PtrValue {
+    return &Float64PtrValue{ptr: p}
+}
+
+func (v *Float64PtrValue) String() string {
+    if v.ptr == nil || *v.ptr == nil {
+        return ""
+    }
+    return strconv.FormatFloat(**v.ptr, 'f', -1, 64)
+}
+
+func (v *Float64PtrValue) Set(s string) error {
+    val, err := strconv.ParseFloat(s, 64)
+    if err != nil {
+        return err
+    }
+    *v.ptr = &val
+    return nil
+}
+
+func (v *Float64PtrValue) Type() string {
+    return "float64"
+}
+
 // ptr is a helper function to create a pointer to int32
 func ptr(v int32) *int32 {
     return &v

@@ -100,7 +131,8 @@ type ConfigureFlags struct {
     NumTokens         int
     MinAcceptanceRate float64
     // vLLM-specific flags
-    HFOverrides string
+    HFOverrides          string
+    GPUMemoryUtilization *float64
     // Think parameter for reasoning models
     Think *bool
 }

@@ -112,6 +144,7 @@ func (f *ConfigureFlags) RegisterFlags(cmd *cobra.Command) {
     cmd.Flags().IntVar(&f.NumTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively")
     cmd.Flags().Float64Var(&f.MinAcceptanceRate, "speculative-min-acceptance-rate", 0, "minimum acceptance rate for speculative decoding")
     cmd.Flags().StringVar(&f.HFOverrides, "hf_overrides", "", "HuggingFace model config overrides (JSON) - vLLM only")
+    cmd.Flags().Var(NewFloat64PtrValue(&f.GPUMemoryUtilization), "gpu-memory-utilization", "fraction of GPU memory to use for the model executor (0.0-1.0) - vLLM only")
     cmd.Flags().Var(NewBoolPtrValue(&f.Think), "think", "enable reasoning mode for thinking models")
     cmd.Flags().StringVar(&f.Mode, "mode", "", "backend operation mode (completion, embedding, reranking)")
 }

@@ -151,6 +184,18 @@ func (f *ConfigureFlags) BuildConfigureRequest(model string) (scheduling.Configu
         req.VLLM.HFOverrides = hfo
     }

+    // Set GPU memory utilization if provided (vLLM-specific)
+    if f.GPUMemoryUtilization != nil {
+        utilization := *f.GPUMemoryUtilization
+        if utilization < 0.0 || utilization > 1.0 {
+            return req, fmt.Errorf("--gpu-memory-utilization must be between 0.0 and 1.0, got %f", utilization)
+        }
+        if req.VLLM == nil {
+            req.VLLM = &inference.VLLMConfig{}
+        }
+        req.VLLM.GPUMemoryUtilization = f.GPUMemoryUtilization
+    }
+
     // Set reasoning budget from --think flag
     reasoningBudget := f.getReasoningBudget()
     if reasoningBudget != nil {
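The point of the `**float64` indirection above is that "flag not provided" stays distinguishable from an explicit `0.0`. A minimal, self-contained sketch of the same pattern against `spf13/pflag` directly; the flag-set name and demo arguments are invented for illustration:

```go
package main

import (
	"fmt"
	"strconv"

	"github.com/spf13/pflag"
)

// float64PtrValue mirrors the commit's Float64PtrValue: it writes through a
// **float64, so an unset flag leaves the target pointer nil instead of 0.0.
type float64PtrValue struct{ ptr **float64 }

func (v *float64PtrValue) String() string {
	if v.ptr == nil || *v.ptr == nil {
		return "" // unset: pflag reports an empty default
	}
	return strconv.FormatFloat(**v.ptr, 'f', -1, 64)
}

func (v *float64PtrValue) Set(s string) error {
	f, err := strconv.ParseFloat(s, 64)
	if err != nil {
		return err
	}
	*v.ptr = &f
	return nil
}

func (v *float64PtrValue) Type() string { return "float64" }

func main() {
	var util *float64
	fs := pflag.NewFlagSet("demo", pflag.ContinueOnError)
	fs.Var(&float64PtrValue{ptr: &util}, "gpu-memory-utilization", "fraction of GPU memory (0.0-1.0)")

	_ = fs.Parse([]string{}) // flag absent
	fmt.Println(util == nil) // true: absence is observable, not conflated with 0.0

	_ = fs.Parse([]string{"--gpu-memory-utilization=0.7"})
	fmt.Println(*util) // 0.7
}
```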

cmd/cli/commands/configure_test.go (120 additions, 0 deletions)

@@ -128,6 +128,126 @@ func TestConfigureCmdThinkFlag(t *testing.T) {
     }
 }

+func TestConfigureCmdGPUMemoryUtilizationFlag(t *testing.T) {
+    // Create the configure command
+    cmd := newConfigureCmd()
+
+    // Verify the --gpu-memory-utilization flag exists
+    gpuMemFlag := cmd.Flags().Lookup("gpu-memory-utilization")
+    if gpuMemFlag == nil {
+        t.Fatal("--gpu-memory-utilization flag not found")
+    }
+
+    // Verify the default value is empty (nil pointer)
+    if gpuMemFlag.DefValue != "" {
+        t.Errorf("Expected default gpu-memory-utilization value to be '' (nil), got '%s'", gpuMemFlag.DefValue)
+    }
+
+    // Verify the flag type
+    if gpuMemFlag.Value.Type() != "float64" {
+        t.Errorf("Expected gpu-memory-utilization flag type to be 'float64', got '%s'", gpuMemFlag.Value.Type())
+    }
+
+    // Test setting the flag value
+    err := cmd.Flags().Set("gpu-memory-utilization", "0.7")
+    if err != nil {
+        t.Errorf("Failed to set gpu-memory-utilization flag: %v", err)
+    }
+
+    // Verify the value was set
+    gpuMemValue := gpuMemFlag.Value.String()
+    if gpuMemValue != "0.7" {
+        t.Errorf("Expected gpu-memory-utilization flag value to be '0.7', got '%s'", gpuMemValue)
+    }
+}
+
+func TestGPUMemoryUtilizationBehavior(t *testing.T) {
+    // Helper to create float64 pointer
+    float64Ptr := func(f float64) *float64 { return &f }
+
+    tests := []struct {
+        name               string
+        gpuMemValue        *float64
+        expectError        bool
+        expectGPUMemSet    bool
+        expectedGPUMemUtil float64
+    }{
+        {
+            name:            "default - not set (nil)",
+            gpuMemValue:     nil,
+            expectError:     false,
+            expectGPUMemSet: false,
+        },
+        {
+            name:               "valid value 0.5",
+            gpuMemValue:        float64Ptr(0.5),
+            expectError:        false,
+            expectGPUMemSet:    true,
+            expectedGPUMemUtil: 0.5,
+        },
+        {
+            name:               "edge case 0.0",
+            gpuMemValue:        float64Ptr(0.0),
+            expectError:        false,
+            expectGPUMemSet:    true,
+            expectedGPUMemUtil: 0.0,
+        },
+        {
+            name:               "edge case 1.0",
+            gpuMemValue:        float64Ptr(1.0),
+            expectError:        false,
+            expectGPUMemSet:    true,
+            expectedGPUMemUtil: 1.0,
+        },
+        {
+            name:        "invalid - negative value",
+            gpuMemValue: float64Ptr(-0.1),
+            expectError: true,
+        },
+        {
+            name:        "invalid - value > 1.0",
+            gpuMemValue: float64Ptr(1.5),
+            expectError: true,
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            flags := ConfigureFlags{
+                GPUMemoryUtilization: tt.gpuMemValue,
+            }
+
+            req, err := flags.BuildConfigureRequest("test-model")
+
+            if tt.expectError {
+                if err == nil {
+                    t.Fatal("Expected error but got none")
+                }
+                return
+            }
+
+            if err != nil {
+                t.Fatalf("Unexpected error: %v", err)
+            }
+
+            if tt.expectGPUMemSet {
+                // GPU memory utilization should be set
+                if req.VLLM == nil || req.VLLM.GPUMemoryUtilization == nil {
+                    t.Fatal("Expected GPU memory utilization to be set")
+                }
+                if *req.VLLM.GPUMemoryUtilization != tt.expectedGPUMemUtil {
+                    t.Errorf("Expected GPU memory utilization to be %f, got %f", tt.expectedGPUMemUtil, *req.VLLM.GPUMemoryUtilization)
+                }
+            } else {
+                // GPU memory utilization should NOT be set
+                if req.VLLM != nil && req.VLLM.GPUMemoryUtilization != nil {
+                    t.Errorf("Expected GPU memory utilization to be nil when not set, got %f", *req.VLLM.GPUMemoryUtilization)
+                }
+            }
+        })
+    }
+}
+
 func TestThinkFlagBehavior(t *testing.T) {
     // Helper to create bool pointer
     boolPtr := func(b bool) *bool { return &b }

pkg/inference/backend.go (4 additions, 0 deletions)

@@ -67,6 +67,10 @@ type VLLMConfig struct {
     // HFOverrides contains HuggingFace model configuration overrides.
     // This maps to vLLM's --hf-overrides flag which accepts a JSON dictionary.
     HFOverrides HFOverrides `json:"hf-overrides,omitempty"`
+    // GPUMemoryUtilization sets the fraction of GPU memory to be used for the model executor.
+    // Must be between 0.0 and 1.0. If not specified, vLLM uses its default value of 0.9.
+    // This maps to vLLM's --gpu-memory-utilization flag.
+    GPUMemoryUtilization *float64 `json:"gpu-memory-utilization,omitempty"`
 }

 // LlamaCppConfig contains llama.cpp-specific configuration options.
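To see why the field is a pointer with `omitempty`, here is a small sketch, with the struct trimmed to the one field: a nil pointer drops the key from the serialized config so vLLM's own default applies downstream, while an explicit `0.0`, which the validation range permits, still round-trips:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// vllmConfig is a trimmed stand-in for the VLLMConfig field added above.
type vllmConfig struct {
	GPUMemoryUtilization *float64 `json:"gpu-memory-utilization,omitempty"`
}

func main() {
	// nil pointer: the key is omitted entirely, so vLLM's default (0.9) applies.
	unset, _ := json.Marshal(vllmConfig{})
	fmt.Println(string(unset)) // {}

	// An explicit 0.0 is a non-nil pointer, so it survives omitempty.
	z := 0.0
	zero, _ := json.Marshal(vllmConfig{GPUMemoryUtilization: &z})
	fmt.Println(string(zero)) // {"gpu-memory-utilization":0}

	u := 0.7
	set, _ := json.Marshal(vllmConfig{GPUMemoryUtilization: &u})
	fmt.Println(string(set)) // {"gpu-memory-utilization":0.7}
}
```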

pkg/inference/backends/vllm/vllm_config.go (9 additions, 0 deletions)

@@ -60,6 +60,15 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference

     // Add vLLM-specific arguments from backend config
     if config != nil && config.VLLM != nil {
+        // Add GPU memory utilization if specified
+        if config.VLLM.GPUMemoryUtilization != nil {
+            utilization := *config.VLLM.GPUMemoryUtilization
+            if utilization < 0.0 || utilization > 1.0 {
+                return nil, fmt.Errorf("gpu-memory-utilization must be between 0.0 and 1.0, got %f", utilization)
+            }
+            args = append(args, "--gpu-memory-utilization", strconv.FormatFloat(utilization, 'f', -1, 64))
+        }
+
         // Add HuggingFace overrides if specified
         if len(config.VLLM.HFOverrides) > 0 {
             hfOverridesJSON, err := json.Marshal(config.VLLM.HFOverrides)
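The argument is rendered with `strconv.FormatFloat(utilization, 'f', -1, 64)`; precision -1 means the fewest digits that still round-trip to the same float64, which is why the tests below expect "0.5", "0", and "1" rather than zero-padded forms. A quick runnable check:

```go
package main

import (
	"fmt"
	"strconv"
)

func main() {
	// Precision -1: shortest decimal form that parses back to the same float64.
	for _, u := range []float64{0.5, 0.0, 1.0, 0.7} {
		fmt.Println(strconv.FormatFloat(u, 'f', -1, 64))
	}
	// Prints: 0.5, 0, 1, 0.7 (one per line)
}
```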

pkg/inference/backends/vllm/vllm_config_test.go (133 additions, 0 deletions)

@@ -203,6 +203,135 @@ func TestGetArgs(t *testing.T) {
                 `{"model_type":"bert"}`,
             },
         },
+        {
+            name: "with GPU memory utilization 0.5",
+            bundle: &mockModelBundle{
+                safetensorsPath: "/path/to/model",
+            },
+            config: &inference.BackendConfiguration{
+                VLLM: &inference.VLLMConfig{
+                    GPUMemoryUtilization: float64ptr(0.5),
+                },
+            },
+            expected: []string{
+                "serve",
+                "/path/to",
+                "--uds",
+                "/tmp/socket",
+                "--gpu-memory-utilization",
+                "0.5",
+            },
+        },
+        {
+            name: "with GPU memory utilization 0.0 (edge case)",
+            bundle: &mockModelBundle{
+                safetensorsPath: "/path/to/model",
+            },
+            config: &inference.BackendConfiguration{
+                VLLM: &inference.VLLMConfig{
+                    GPUMemoryUtilization: float64ptr(0.0),
+                },
+            },
+            expected: []string{
+                "serve",
+                "/path/to",
+                "--uds",
+                "/tmp/socket",
+                "--gpu-memory-utilization",
+                "0",
+            },
+        },
+        {
+            name: "with GPU memory utilization 1.0 (edge case)",
+            bundle: &mockModelBundle{
+                safetensorsPath: "/path/to/model",
+            },
+            config: &inference.BackendConfiguration{
+                VLLM: &inference.VLLMConfig{
+                    GPUMemoryUtilization: float64ptr(1.0),
+                },
+            },
+            expected: []string{
+                "serve",
+                "/path/to",
+                "--uds",
+                "/tmp/socket",
+                "--gpu-memory-utilization",
+                "1",
+            },
+        },
+        {
+            name: "with GPU memory utilization negative (invalid)",
+            bundle: &mockModelBundle{
+                safetensorsPath: "/path/to/model",
+            },
+            config: &inference.BackendConfiguration{
+                VLLM: &inference.VLLMConfig{
+                    GPUMemoryUtilization: float64ptr(-0.1),
+                },
+            },
+            expectError: true,
+        },
+        {
+            name: "with GPU memory utilization > 1.0 (invalid)",
+            bundle: &mockModelBundle{
+                safetensorsPath: "/path/to/model",
+            },
+            config: &inference.BackendConfiguration{
+                VLLM: &inference.VLLMConfig{
+                    GPUMemoryUtilization: float64ptr(1.5),
+                },
+            },
+            expectError: true,
+        },
+        {
+            name: "with GPU memory utilization and other parameters",
+            bundle: &mockModelBundle{
+                safetensorsPath: "/path/to/model",
+            },
+            config: &inference.BackendConfiguration{
+                ContextSize: int32ptr(8192),
+                VLLM: &inference.VLLMConfig{
+                    GPUMemoryUtilization: float64ptr(0.7),
+                    HFOverrides: inference.HFOverrides{
+                        "architectures": []interface{}{"LlamaForCausalLM"},
+                    },
+                },
+            },
+            expected: []string{
+                "serve",
+                "/path/to",
+                "--uds",
+                "/tmp/socket",
+                "--max-model-len",
+                "8192",
+                "--gpu-memory-utilization",
+                "0.7",
+                "--hf-overrides",
+                `{"architectures":["LlamaForCausalLM"]}`,
+            },
+        },
+        {
+            name: "without GPU memory utilization (should not add flag)",
+            bundle: &mockModelBundle{
+                safetensorsPath: "/path/to/model",
+            },
+            config: &inference.BackendConfiguration{
+                VLLM: &inference.VLLMConfig{
+                    HFOverrides: inference.HFOverrides{
+                        "model_type": "llama",
+                    },
+                },
+            },
+            expected: []string{
+                "serve",
+                "/path/to",
+                "--uds",
+                "/tmp/socket",
+                "--hf-overrides",
+                `{"model_type":"llama"}`,
+            },
+        },
     }

     for _, tt := range tests {

@@ -290,3 +419,7 @@ func TestGetMaxModelLen(t *testing.T) {
 func int32ptr(n int32) *int32 {
     return &n
 }
+
+func float64ptr(n float64) *float64 {
+    return &n
+}
