Skip to content

Commit 13e2b4f

Browse files
committed
fix(vllm): validate safetensors path before serving and update test cases
1 parent b0b6ccc commit 13e2b4f

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

pkg/inference/backends/vllm/vllm_config.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
2828
args := append([]string{}, c.Args...)
2929

3030
// Add the serve command and model path (use directory for safetensors)
31-
modelPath := filepath.Dir(bundle.SafetensorsPath())
32-
if modelPath != "" {
33-
// vLLM expects the directory containing the safetensors files
34-
args = append(args, "serve", modelPath)
35-
} else {
31+
safetensorsPath := bundle.SafetensorsPath()
32+
if safetensorsPath == "" {
3633
return nil, fmt.Errorf("safetensors path required by vLLM backend")
3734
}
35+
modelPath := filepath.Dir(safetensorsPath)
36+
// vLLM expects the directory containing the safetensors files
37+
args = append(args, "serve", modelPath)
3838

3939
// Add socket arguments
4040
args = append(args, "--uds", socket)

pkg/inference/backends/vllm/vllm_config_test.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,21 @@ func (m *mockModelBundle) RootDir() string {
3838

3939
func TestGetArgs(t *testing.T) {
4040
tests := []struct {
41-
name string
42-
config *inference.BackendConfiguration
43-
bundle *mockModelBundle
44-
expected []string
41+
name string
42+
config *inference.BackendConfiguration
43+
bundle *mockModelBundle
44+
expected []string
45+
expectError bool
4546
}{
47+
{
48+
name: "empty safetensors path should error",
49+
bundle: &mockModelBundle{
50+
safetensorsPath: "",
51+
},
52+
config: nil,
53+
expected: nil,
54+
expectError: true,
55+
},
4656
{
4757
name: "basic args without context size",
4858
bundle: &mockModelBundle{
@@ -116,6 +126,14 @@ func TestGetArgs(t *testing.T) {
116126
t.Run(tt.name, func(t *testing.T) {
117127
config := NewDefaultVLLMConfig()
118128
args, err := config.GetArgs(tt.bundle, "/tmp/socket", inference.BackendModeCompletion, tt.config)
129+
130+
if tt.expectError {
131+
if err == nil {
132+
t.Fatalf("expected error but got none")
133+
}
134+
return
135+
}
136+
119137
if err != nil {
120138
t.Fatalf("unexpected error: %v", err)
121139
}

0 commit comments

Comments
 (0)