2 changes: 1 addition & 1 deletion pkg/inference/backend.go
@@ -91,5 +91,5 @@ type Backend interface {
     GetDiskUsage() (int64, error)
     // GetRequiredMemoryForModel returns the required working memory for a given
     // model.
-    GetRequiredMemoryForModel(ctx context.Context, model string, config *BackendConfiguration) (*RequiredMemory, error)
+    GetRequiredMemoryForModel(ctx context.Context, model string, config *BackendConfiguration) (RequiredMemory, error)
 }
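The interface above now returns RequiredMemory by value rather than by pointer, so error paths hand back the zero value instead of nil and callers never risk a nil dereference. A minimal sketch of that idiom, with a hypothetical Estimate function and an assumed field layout (the RequiredMemory definition is not part of this diff; only its RAM and VRAM fields are visible in the changes below):

package main

import (
    "errors"
    "fmt"
)

// Assumed layout: the diff only shows that RequiredMemory carries RAM and
// VRAM figures, so the field types here are a guess.
type RequiredMemory struct {
    RAM  uint64 // bytes of system memory
    VRAM uint64 // bytes of GPU memory
}

// Estimate is a hypothetical stand-in showing the value-return idiom this PR
// adopts: the zero value, not a nil pointer, accompanies an error.
func Estimate(model string) (RequiredMemory, error) {
    if model == "" {
        return RequiredMemory{}, errors.New("no model specified")
    }
    return RequiredMemory{RAM: 8 << 30, VRAM: 4 << 30}, nil
}

func main() {
    req, err := Estimate("example/model")
    if err != nil {
        fmt.Println("estimate failed:", err)
        return
    }
    fmt.Printf("RAM=%d bytes, VRAM=%d bytes\n", req.RAM, req.VRAM)
}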
10 changes: 5 additions & 5 deletions pkg/inference/backends/llamacpp/llamacpp.go
@@ -230,22 +230,22 @@ func (l *llamaCpp) GetDiskUsage() (int64, error) {
     return size, nil
 }

-func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (inference.RequiredMemory, error) {
     var mdlGguf *parser.GGUFFile
     var mdlConfig types.Config
     inStore, err := l.modelManager.IsModelInStore(model)
     if err != nil {
-        return nil, fmt.Errorf("checking if model is in local store: %w", err)
+        return inference.RequiredMemory{}, fmt.Errorf("checking if model is in local store: %w", err)
     }
     if inStore {
         mdlGguf, mdlConfig, err = l.parseLocalModel(model)
         if err != nil {
-            return nil, &inference.ErrGGUFParse{Err: err}
+            return inference.RequiredMemory{}, &inference.ErrGGUFParse{Err: err}
         }
     } else {
         mdlGguf, mdlConfig, err = l.parseRemoteModel(ctx, model)
         if err != nil {
-            return nil, &inference.ErrGGUFParse{Err: err}
+            return inference.RequiredMemory{}, &inference.ErrGGUFParse{Err: err}
         }
     }

@@ -278,7 +278,7 @@ func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string,
         vram = 1
     }

-    return &inference.RequiredMemory{
+    return inference.RequiredMemory{
         RAM: ram,
         VRAM: vram,
     }, nil
4 changes: 2 additions & 2 deletions pkg/inference/backends/mlx/mlx.go
@@ -63,6 +63,6 @@ func (m *mlx) GetDiskUsage() (int64, error) {
     return 0, nil
 }

-func (m *mlx) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
-    return nil, errors.New("not implemented")
+func (m *mlx) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (inference.RequiredMemory, error) {
+    return inference.RequiredMemory{}, errors.New("not implemented")
 }
4 changes: 2 additions & 2 deletions pkg/inference/backends/vllm/vllm.go
@@ -63,6 +63,6 @@ func (v *vLLM) GetDiskUsage() (int64, error) {
     return 0, nil
 }

-func (v *vLLM) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
-    return nil, errors.New("not implemented")
+func (v *vLLM) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (inference.RequiredMemory, error) {
+    return inference.RequiredMemory{}, errors.New("not implemented")
 }
16 changes: 8 additions & 8 deletions pkg/inference/memory/estimator.go
@@ -10,12 +10,12 @@ import (

 type MemoryEstimator interface {
     SetDefaultBackend(MemoryEstimatorBackend)
-    GetRequiredMemoryForModel(context.Context, string, *inference.BackendConfiguration) (*inference.RequiredMemory, error)
-    HaveSufficientMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (bool, error)
+    GetRequiredMemoryForModel(context.Context, string, *inference.BackendConfiguration) (inference.RequiredMemory, error)
+    HaveSufficientMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (bool, inference.RequiredMemory, inference.RequiredMemory, error)
 }

 type MemoryEstimatorBackend interface {
-    GetRequiredMemoryForModel(context.Context, string, *inference.BackendConfiguration) (*inference.RequiredMemory, error)
+    GetRequiredMemoryForModel(context.Context, string, *inference.BackendConfiguration) (inference.RequiredMemory, error)
 }

 type memoryEstimator struct {
@@ -31,18 +31,18 @@ func (m *memoryEstimator) SetDefaultBackend(backend MemoryEstimatorBackend) {
     m.defaultBackend = backend
 }

-func (m *memoryEstimator) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+func (m *memoryEstimator) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (inference.RequiredMemory, error) {
     if m.defaultBackend == nil {
-        return nil, errors.New("default backend not configured")
+        return inference.RequiredMemory{}, errors.New("default backend not configured")
     }

     return m.defaultBackend.GetRequiredMemoryForModel(ctx, model, config)
 }

-func (m *memoryEstimator) HaveSufficientMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (bool, error) {
+func (m *memoryEstimator) HaveSufficientMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (bool, inference.RequiredMemory, inference.RequiredMemory, error) {
     req, err := m.GetRequiredMemoryForModel(ctx, model, config)
     if err != nil {
-        return false, fmt.Errorf("estimating required memory for model: %w", err)
+        return false, inference.RequiredMemory{}, inference.RequiredMemory{}, fmt.Errorf("estimating required memory for model: %w", err)
     }
-    return m.systemMemoryInfo.HaveSufficientMemory(*req), nil
+    return m.systemMemoryInfo.HaveSufficientMemory(req), req, m.systemMemoryInfo.GetTotalMemory(), nil
 }
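HaveSufficientMemoryForModel now reports the estimate and the system's total memory alongside the verdict, which is what lets the manager (below) put concrete numbers into its log line and HTTP error. A sketch of consuming the new four-value signature; the RequiredMemory, BackendConfiguration, and estimator types here are local stand-ins so the snippet compiles on its own, and the comparison inside fixedEstimator is only an assumption about what the real HaveSufficientMemory check does:

package main

import (
    "context"
    "fmt"
)

// Local stand-ins for the types in pkg/inference and pkg/inference/memory.
type RequiredMemory struct{ RAM, VRAM uint64 }
type BackendConfiguration struct{}

type MemoryEstimator interface {
    HaveSufficientMemoryForModel(ctx context.Context, model string, config *BackendConfiguration) (bool, RequiredMemory, RequiredMemory, error)
}

// fixedEstimator fakes an estimate; the <= comparison is an assumed stand-in
// for the real sufficiency check, which this diff does not show.
type fixedEstimator struct{ req, total RequiredMemory }

func (f fixedEstimator) HaveSufficientMemoryForModel(_ context.Context, _ string, _ *BackendConfiguration) (bool, RequiredMemory, RequiredMemory, error) {
    fits := f.req.RAM <= f.total.RAM && f.req.VRAM <= f.total.VRAM
    return fits, f.req, f.total, nil
}

// fitsOnThisMachine mirrors handleCreateModel: tolerate estimation errors,
// but refuse models whose estimate exceeds total memory, reporting both sides.
func fitsOnThisMachine(ctx context.Context, est MemoryEstimator, model string) bool {
    ok, req, total, err := est.HaveSufficientMemoryForModel(ctx, model, nil)
    if err != nil {
        fmt.Printf("memory estimation failed for %q: %v\n", model, err)
        return true // prefer staying functional, as the manager does
    }
    if !ok {
        fmt.Printf("model %q needs %d RAM / %d VRAM; system has %d RAM / %d VRAM\n",
            model, req.RAM, req.VRAM, total.RAM, total.VRAM)
    }
    return ok
}

func main() {
    est := fixedEstimator{
        req:   RequiredMemory{RAM: 16 << 30, VRAM: 12 << 30},
        total: RequiredMemory{RAM: 32 << 30, VRAM: 8 << 30},
    }
    fmt.Println("fits:", fitsOnThisMachine(context.Background(), est, "example/model"))
}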
7 changes: 4 additions & 3 deletions pkg/inference/models/manager.go
@@ -163,15 +163,16 @@ func (m *Manager) handleCreateModel(w http.ResponseWriter, r *http.Request) {
     // besides pulling (such as model building).
     if !request.IgnoreRuntimeMemoryCheck {
         m.log.Infof("Will estimate memory required for %q", request.From)
-        proceed, err := m.memoryEstimator.HaveSufficientMemoryForModel(r.Context(), request.From, nil)
+        proceed, req, totalMem, err := m.memoryEstimator.HaveSufficientMemoryForModel(r.Context(), request.From, nil)
         if err != nil {
             m.log.Warnf("Failed to calculate memory required for model %q: %s", request.From, err)
             // Prefer staying functional in case of unexpected estimation errors.
             proceed = true
         }
         if !proceed {
-            m.log.Warnf("Runtime memory requirement for model %q exceeds total system memory", request.From)
-            http.Error(w, "Runtime memory requirement for model exceeds total system memory", http.StatusInsufficientStorage)
+            errstr := fmt.Sprintf("Runtime memory requirement for model %q exceeds total system memory: required %d RAM %d VRAM, system %d RAM %d VRAM", request.From, req.RAM, req.VRAM, totalMem.RAM, totalMem.VRAM)
+            m.log.Warnf(errstr)
+            http.Error(w, errstr, http.StatusInsufficientStorage)
             return
         }
     }
8 changes: 4 additions & 4 deletions pkg/inference/models/manager_test.go
@@ -26,12 +26,12 @@ type mockMemoryEstimator struct{}

 func (me *mockMemoryEstimator) SetDefaultBackend(_ memory.MemoryEstimatorBackend) {}

-func (me *mockMemoryEstimator) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
-    return &inference.RequiredMemory{RAM: 0, VRAM: 0}, nil
+func (me *mockMemoryEstimator) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (inference.RequiredMemory, error) {
+    return inference.RequiredMemory{RAM: 0, VRAM: 0}, nil
 }

-func (me *mockMemoryEstimator) HaveSufficientMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (bool, error) {
-    return true, nil
+func (me *mockMemoryEstimator) HaveSufficientMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (bool, inference.RequiredMemory, inference.RequiredMemory, error) {
+    return true, inference.RequiredMemory{}, inference.RequiredMemory{}, nil
 }

 // getProjectRoot returns the absolute path to the project root directory
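Since the mock above has to track both signature changes by hand, a compile-time assertion is a cheap way to catch the next drift. This is a common Go idiom rather than something this PR adds; it assumes mockMemoryEstimator is meant to satisfy the memory.MemoryEstimator interface updated earlier in this diff:

// Hypothetical addition to manager_test.go (not part of this PR): fail the
// build, not a test run, if the mock stops satisfying the interface.
var _ memory.MemoryEstimator = (*mockMemoryEstimator)(nil)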
2 changes: 1 addition & 1 deletion pkg/inference/scheduling/loader.go
@@ -401,7 +401,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
         // e.g. model is too new for gguf-parser-go to know. We should provide a cleaner
         // way to bypass these checks.
         l.log.Warnf("Could not parse model(%s), memory checks will be ignored for it. Error: %s", modelID, parseErr)
-        memory = &inference.RequiredMemory{
+        memory = inference.RequiredMemory{
             RAM: 0,
             VRAM: 0,
         }
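The fallback above replaces a pointer literal with a zero-valued struct when the model can't be parsed, which in effect disables the memory gate for that model: assuming the sufficiency check is a plain "required does not exceed available" comparison (the real check lives behind systemMemoryInfo and isn't shown in this diff), a requirement of zero always passes. A tiny illustration:

package main

import "fmt"

type RequiredMemory struct{ RAM, VRAM uint64 }

// fits is an assumed stand-in for the real sufficiency check; it exists only
// to show why a zero requirement can never fail it.
func fits(required, available RequiredMemory) bool {
    return required.RAM <= available.RAM && required.VRAM <= available.VRAM
}

func main() {
    system := RequiredMemory{RAM: 16 << 30, VRAM: 0} // a machine with no usable VRAM
    fallback := RequiredMemory{RAM: 0, VRAM: 0}      // the unparseable-model fallback
    fmt.Println(fits(fallback, system))              // always true, so loading proceeds
}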