Skip to content

Commit e6fd394

Browse files
authored
Do not html encode pull/push progress if accept header is set to text/json (#38)
1 parent 6e32aa4 commit e6fd394

File tree

1 file changed

+42
-11
lines changed

1 file changed

+42
-11
lines changed

pkg/inference/models/manager.go

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ func (m *Manager) handleCreateModel(w http.ResponseWriter, r *http.Request) {
129129

130130
// Pull the model. In the future, we may support additional operations here
131131
// besides pulling (such as model building).
132-
if err := m.PullModel(r.Context(), request.From, w); err != nil {
132+
if err := m.PullModel(request.From, r, w); err != nil {
133133
if errors.Is(err, registry.ErrInvalidReference) {
134134
m.log.Warnf("Invalid model reference %q: %v", request.From, err)
135135
http.Error(w, "Invalid model reference", http.StatusBadRequest)
@@ -378,7 +378,7 @@ func (m *Manager) handlePushModel(w http.ResponseWriter, r *http.Request, model
378378
}
379379

380380
// Call the PushModel method on the distribution client.
381-
if err := m.PushModel(r.Context(), model, w); err != nil {
381+
if err := m.PushModel(model, r, w); err != nil {
382382
if errors.Is(err, distribution.ErrInvalidReference) {
383383
m.log.Warnf("Invalid model reference %q: %v", model, err)
384384
http.Error(w, "Invalid model reference", http.StatusBadRequest)
@@ -428,23 +428,33 @@ func (m *Manager) GetModelPath(ref string) (string, error) {
428428

429429
// PullModel pulls a model to local storage. Any error it returns is suitable
430430
// for writing back to the client.
431-
func (m *Manager) PullModel(ctx context.Context, model string, w http.ResponseWriter) error {
431+
func (m *Manager) PullModel(model string, r *http.Request, w http.ResponseWriter) error {
432432
// Restrict model pull concurrency.
433433
select {
434434
case <-m.pullTokens:
435-
case <-ctx.Done():
435+
case <-r.Context().Done():
436436
return context.Canceled
437437
}
438438
defer func() {
439439
m.pullTokens <- struct{}{}
440440
}()
441441

442442
// Set up response headers for streaming
443-
w.Header().Set("Content-Type", "text/event-stream")
444443
w.Header().Set("Cache-Control", "no-cache")
445444
w.Header().Set("Connection", "keep-alive")
446445
w.Header().Set("Transfer-Encoding", "chunked")
447446

447+
// Check Accept header to determine content type
448+
acceptHeader := r.Header.Get("Accept")
449+
isJSON := acceptHeader == "application/json"
450+
451+
if isJSON {
452+
w.Header().Set("Content-Type", "application/json")
453+
} else {
454+
// Defaults to text/plain
455+
w.Header().Set("Content-Type", "text/plain")
456+
}
457+
448458
// Create a flusher to ensure chunks are sent immediately
449459
flusher, ok := w.(http.Flusher)
450460
if !ok {
@@ -455,11 +465,12 @@ func (m *Manager) PullModel(ctx context.Context, model string, w http.ResponseWr
455465
progressWriter := &progressResponseWriter{
456466
writer: w,
457467
flusher: flusher,
468+
isJSON: isJSON,
458469
}
459470

460471
// Pull the model using the Docker model distribution client
461472
m.log.Infoln("Pulling model:", model)
462-
err := m.distributionClient.PullModel(ctx, model, progressWriter)
473+
err := m.distributionClient.PullModel(r.Context(), model, progressWriter)
463474
if err != nil {
464475
return fmt.Errorf("error while pulling model: %w", err)
465476
}
@@ -468,13 +479,22 @@ func (m *Manager) PullModel(ctx context.Context, model string, w http.ResponseWr
468479
}
469480

470481
// PushModel pushes a model from the store to the registry.
471-
func (m *Manager) PushModel(ctx context.Context, model string, w http.ResponseWriter) error {
482+
func (m *Manager) PushModel(model string, r *http.Request, w http.ResponseWriter) error {
472483
// Set up response headers for streaming
473-
w.Header().Set("Content-Type", "text/event-stream")
474484
w.Header().Set("Cache-Control", "no-cache")
475485
w.Header().Set("Connection", "keep-alive")
476486
w.Header().Set("Transfer-Encoding", "chunked")
477487

488+
// Check Accept header to determine content type
489+
acceptHeader := r.Header.Get("Accept")
490+
isJSON := acceptHeader == "application/json"
491+
492+
if isJSON {
493+
w.Header().Set("Content-Type", "application/json")
494+
} else {
495+
w.Header().Set("Content-Type", "text/plain")
496+
}
497+
478498
// Create a flusher to ensure chunks are sent immediately
479499
flusher, ok := w.(http.Flusher)
480500
if !ok {
@@ -485,11 +505,12 @@ func (m *Manager) PushModel(ctx context.Context, model string, w http.ResponseWr
485505
progressWriter := &progressResponseWriter{
486506
writer: w,
487507
flusher: flusher,
508+
isJSON: isJSON,
488509
}
489510

490511
// Pull the model using the Docker model distribution client
491512
m.log.Infoln("Pushing model:", model)
492-
err := m.distributionClient.PushModel(ctx, model, progressWriter)
513+
err := m.distributionClient.PushModel(r.Context(), model, progressWriter)
493514
if err != nil {
494515
return fmt.Errorf("error while pushing model: %w", err)
495516
}
@@ -501,11 +522,21 @@ func (m *Manager) PushModel(ctx context.Context, model string, w http.ResponseWr
501522
type progressResponseWriter struct {
502523
writer http.ResponseWriter
503524
flusher http.Flusher
525+
isJSON bool
504526
}
505527

506528
func (w *progressResponseWriter) Write(p []byte) (n int, err error) {
507-
escapedData := html.EscapeString(string(p))
508-
n, err = w.writer.Write([]byte(escapedData))
529+
var data []byte
530+
if w.isJSON {
531+
// For JSON, write the raw bytes without escaping
532+
data = p
533+
} else {
534+
// For plain text, escape HTML
535+
escapedData := html.EscapeString(string(p))
536+
data = []byte(escapedData)
537+
}
538+
539+
n, err = w.writer.Write(data)
509540
if err != nil {
510541
return 0, err
511542
}

0 commit comments

Comments
 (0)