Skip to content

Commit ef40911

Browse files
committed
remove memory estimation logic and related structures
1 parent f12b17b commit ef40911

File tree

10 files changed

+24
-415
lines changed

10 files changed

+24
-415
lines changed

main.go

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,11 @@ import (
1111
"syscall"
1212
"time"
1313

14-
"github.com/docker/model-runner/pkg/gpuinfo"
1514
"github.com/docker/model-runner/pkg/inference"
1615
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
1716
"github.com/docker/model-runner/pkg/inference/backends/mlx"
1817
"github.com/docker/model-runner/pkg/inference/backends/vllm"
1918
"github.com/docker/model-runner/pkg/inference/config"
20-
"github.com/docker/model-runner/pkg/inference/memory"
2119
"github.com/docker/model-runner/pkg/inference/models"
2220
"github.com/docker/model-runner/pkg/inference/scheduling"
2321
"github.com/docker/model-runner/pkg/metrics"
@@ -65,15 +63,6 @@ func main() {
6563
llamaServerPath = "/Applications/Docker.app/Contents/Resources/model-runner/bin"
6664
}
6765

68-
gpuInfo := gpuinfo.New(llamaServerPath)
69-
70-
sysMemInfo, err := memory.NewSystemMemoryInfo(log, gpuInfo)
71-
if err != nil {
72-
log.Fatalf("unable to initialize system memory info: %v", err)
73-
}
74-
75-
memEstimator := memory.NewEstimator(sysMemInfo)
76-
7766
// Create a proxy-aware HTTP transport
7867
// Use a safe type assertion with fallback, and explicitly set Proxy to http.ProxyFromEnvironment
7968
var baseTransport *http.Transport
@@ -93,7 +82,6 @@ func main() {
9382
log,
9483
clientConfig,
9584
nil,
96-
memEstimator,
9785
)
9886
modelManager := models.NewManager(log.WithFields(logrus.Fields{"component": "model-manager"}), clientConfig)
9987
log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)
@@ -118,12 +106,6 @@ func main() {
118106
log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
119107
}
120108

121-
if os.Getenv("MODEL_RUNNER_RUNTIME_MEMORY_CHECK") == "1" {
122-
memory.SetRuntimeMemoryCheck(true)
123-
}
124-
125-
memEstimator.SetDefaultBackend(llamaCppBackend)
126-
127109
vllmBackend, err := vllm.New(
128110
log,
129111
modelManager,
@@ -160,7 +142,6 @@ func main() {
160142
"",
161143
false,
162144
),
163-
sysMemInfo,
164145
)
165146

166147
// Create the HTTP handler for the scheduler

pkg/inference/memory/estimator.go

Lines changed: 0 additions & 53 deletions
This file was deleted.

pkg/inference/memory/settings.go

Lines changed: 0 additions & 18 deletions
This file was deleted.

pkg/inference/memory/system.go

Lines changed: 0 additions & 64 deletions
This file was deleted.

pkg/inference/models/handler_test.go

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,10 @@ import (
1717
"github.com/docker/model-runner/pkg/distribution/builder"
1818
reg "github.com/docker/model-runner/pkg/distribution/registry"
1919
"github.com/docker/model-runner/pkg/inference"
20-
"github.com/docker/model-runner/pkg/inference/memory"
2120

2221
"github.com/sirupsen/logrus"
2322
)
2423

25-
type mockMemoryEstimator struct{}
26-
27-
func (me *mockMemoryEstimator) SetDefaultBackend(_ memory.MemoryEstimatorBackend) {}
28-
29-
func (me *mockMemoryEstimator) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (inference.RequiredMemory, error) {
30-
return inference.RequiredMemory{RAM: 0, VRAM: 0}, nil
31-
}
32-
33-
func (me *mockMemoryEstimator) HaveSufficientMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (bool, inference.RequiredMemory, inference.RequiredMemory, error) {
34-
return true, inference.RequiredMemory{}, inference.RequiredMemory{}, nil
35-
}
36-
3724
// getProjectRoot returns the absolute path to the project root directory
3825
func getProjectRoot(t *testing.T) string {
3926
// Start from the current test file's directory
@@ -123,11 +110,10 @@ func TestPullModel(t *testing.T) {
123110
for _, tt := range tests {
124111
t.Run(tt.name, func(t *testing.T) {
125112
log := logrus.NewEntry(logrus.StandardLogger())
126-
memEstimator := &mockMemoryEstimator{}
127113
handler := NewHTTPHandler(log, ClientConfig{
128114
StoreRootPath: tempDir,
129115
Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
130-
}, nil, memEstimator)
116+
}, nil)
131117

132118
r := httptest.NewRequest(http.MethodPost, "/models/create", strings.NewReader(`{"from": "`+tag+`"}`))
133119
if tt.acceptHeader != "" {
@@ -234,13 +220,12 @@ func TestHandleGetModel(t *testing.T) {
234220
for _, tt := range tests {
235221
t.Run(tt.name, func(t *testing.T) {
236222
log := logrus.NewEntry(logrus.StandardLogger())
237-
memEstimator := &mockMemoryEstimator{}
238223
handler := NewHTTPHandler(log, ClientConfig{
239224
StoreRootPath: tempDir,
240225
Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
241226
Transport: http.DefaultTransport,
242227
UserAgent: "test-agent",
243-
}, nil, memEstimator)
228+
}, nil)
244229

245230
// First pull the model if we're testing local access
246231
if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") {
@@ -315,11 +300,10 @@ func TestCors(t *testing.T) {
315300
for _, tt := range tests {
316301
t.Run(tt.path, func(t *testing.T) {
317302
t.Parallel()
318-
memEstimator := &mockMemoryEstimator{}
319303
discard := logrus.New()
320304
discard.SetOutput(io.Discard)
321305
log := logrus.NewEntry(discard)
322-
m := NewHTTPHandler(log, ClientConfig{}, []string{"*"}, memEstimator)
306+
m := NewHTTPHandler(log, ClientConfig{}, []string{"*"})
323307
req := httptest.NewRequest(http.MethodOptions, "http://model-runner.docker.internal"+tt.path, http.NoBody)
324308
req.Header.Set("Origin", "docker.com")
325309
w := httptest.NewRecorder()

pkg/inference/models/http_handler.go

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
"github.com/docker/model-runner/pkg/distribution/distribution"
1616
"github.com/docker/model-runner/pkg/distribution/registry"
1717
"github.com/docker/model-runner/pkg/inference"
18-
"github.com/docker/model-runner/pkg/inference/memory"
1918
"github.com/docker/model-runner/pkg/internal/utils"
2019
"github.com/docker/model-runner/pkg/logging"
2120
"github.com/docker/model-runner/pkg/middleware"
@@ -38,8 +37,6 @@ type HTTPHandler struct {
3837
httpHandler http.Handler
3938
// lock is used to synchronize access to the models manager's router.
4039
lock sync.RWMutex
41-
// memoryEstimator is used to calculate runtime memory requirements for models.
42-
memoryEstimator memory.MemoryEstimator
4340
// manager handles business logic for model operations.
4441
manager *Manager
4542
}
@@ -56,13 +53,12 @@ type ClientConfig struct {
5653
}
5754

5855
// NewHTTPHandler creates a new model's handler.
59-
func NewHTTPHandler(log logging.Logger, c ClientConfig, allowedOrigins []string, memoryEstimator memory.MemoryEstimator) *HTTPHandler {
56+
func NewHTTPHandler(log logging.Logger, c ClientConfig, allowedOrigins []string) *HTTPHandler {
6057
// Create the manager.
6158
m := &HTTPHandler{
62-
log: log,
63-
router: http.NewServeMux(),
64-
memoryEstimator: memoryEstimator,
65-
manager: NewManager(log.WithFields(logrus.Fields{"component": "service"}), c),
59+
log: log,
60+
router: http.NewServeMux(),
61+
manager: NewManager(log.WithFields(logrus.Fields{"component": "service"}), c),
6662
}
6763

6864
// Register routes.
@@ -163,23 +159,7 @@ func (h *HTTPHandler) handleCreateModel(w http.ResponseWriter, r *http.Request)
163159
// Normalize the model name to add defaults
164160
request.From = NormalizeModelName(request.From)
165161

166-
// Pull the model. In the future, we may support additional operations here
167-
// besides pulling (such as model building).
168-
if memory.RuntimeMemoryCheckEnabled() && !request.IgnoreRuntimeMemoryCheck {
169-
h.log.Infof("Will estimate memory required for %q", request.From)
170-
proceed, req, totalMem, err := h.memoryEstimator.HaveSufficientMemoryForModel(r.Context(), request.From, nil)
171-
if err != nil {
172-
h.log.Warnf("Failed to validate sufficient system memory for model %q: %s", request.From, err)
173-
// Prefer staying functional in case of unexpected estimation errors.
174-
proceed = true
175-
}
176-
if !proceed {
177-
errstr := fmt.Sprintf("Runtime memory requirement for model %q exceeds total system memory: required %d RAM %d VRAM, system %d RAM %d VRAM", request.From, req.RAM, req.VRAM, totalMem.RAM, totalMem.VRAM)
178-
h.log.Warnf(errstr)
179-
http.Error(w, errstr, http.StatusInsufficientStorage)
180-
return
181-
}
182-
}
162+
// Pull the model
183163
if err := h.manager.Pull(request.From, request.BearerToken, r, w); err != nil {
184164
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
185165
h.log.Infof("Request canceled/timed out while pulling model %q", request.From)

0 commit comments

Comments (0)