Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions pkg/inference/models/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"testing"

"github.com/google/go-containerregistry/pkg/registry"
Expand All @@ -20,6 +21,39 @@ import (
"github.com/sirupsen/logrus"
)

// mutexResponseRecorder wraps httptest.ResponseRecorder with a mutex for thread-safe access
type mutexResponseRecorder struct {
*httptest.ResponseRecorder
mu sync.Mutex
}

func newMutexResponseRecorder() *mutexResponseRecorder {
return &mutexResponseRecorder{
ResponseRecorder: httptest.NewRecorder(),
}
}

// WriteHeader wraps the underlying WriteHeader with mutex protection
func (m *mutexResponseRecorder) WriteHeader(code int) {
m.mu.Lock()
defer m.mu.Unlock()
m.ResponseRecorder.WriteHeader(code)
}

// Write wraps the underlying Write with mutex protection
func (m *mutexResponseRecorder) Write(b []byte) (int, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.ResponseRecorder.Write(b)
}

// Header wraps the underlying Header with mutex protection
func (m *mutexResponseRecorder) Header() http.Header {
m.mu.Lock()
defer m.mu.Unlock()
return m.ResponseRecorder.Header()
}

// getProjectRoot returns the absolute path to the project root directory
func getProjectRoot(t *testing.T) string {
// Start from the current test file's directory
Expand Down Expand Up @@ -119,7 +153,7 @@ func TestPullModel(t *testing.T) {
r.Header.Set("Accept", tt.acceptHeader)
}

w := httptest.NewRecorder()
w := newMutexResponseRecorder()
Copy link
Contributor

@ekcasey ekcasey Jun 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we can assume the real http.ResponseWriter is thread safe and we might be papering over real issues with this change to use a thread-safe version in the tests. It looks to me like the concurrent writes are happening within the pull implementation.

My best guess is this is that this is appearing now b/c with the per-layer progress reporting we are no longer waiting for the progress writing from one layer to complete before we start writing progress for the next layer.

err = m.PullModel(tag, r, w)
if err != nil {
t.Fatalf("Failed to pull model: %v", err)
Expand Down Expand Up @@ -229,7 +263,7 @@ func TestHandleGetModel(t *testing.T) {
// First pull the model if we're testing local access
if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") {
r := httptest.NewRequest("POST", "/models/create", strings.NewReader(`{"from": "`+tt.modelName+`"}`))
w := httptest.NewRecorder()
w := newMutexResponseRecorder()
err = m.PullModel(tt.modelName, r, w)
if err != nil {
t.Fatalf("Failed to pull model: %v", err)
Expand All @@ -242,7 +276,7 @@ func TestHandleGetModel(t *testing.T) {
path += "?remote=true"
}
r := httptest.NewRequest("GET", path, nil)
w := httptest.NewRecorder()
w := newMutexResponseRecorder()

// Set the path value for {name} so r.PathValue("name") works
r.SetPathValue("name", tt.modelName)
Expand Down