Skip to content

Commit 21d7973

Browse files
authored
Improve content-length handling (#115)
ref: See #114 * Improve content-length handling - Content length was not always being sent - Add tests for content-length
1 parent cc450e9 commit 21d7973

File tree

3 files changed

+52
-11
lines changed

3 files changed

+52
-11
lines changed

misc/simple-responder/simple-responder.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,17 @@ func main() {
3333

3434
// Set up the handler function using the provided response message
3535
r.POST("/v1/chat/completions", func(c *gin.Context) {
36-
c.Header("Content-Type", "text/plain")
36+
c.Header("Content-Type", "application/json")
3737

3838
// add a wait to simulate a slow query
3939
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
4040
time.Sleep(wait)
4141
}
4242

43-
c.String(200, *responseMessage)
43+
c.JSON(http.StatusOK, gin.H{
44+
"responseMessage": *responseMessage,
45+
"h_content_length": c.Request.Header.Get("Content-Length"),
46+
})
4447
})
4548

4649
// for issue #62 to check model name strips profile slug
@@ -63,8 +66,11 @@ func main() {
6366
})
6467

6568
r.POST("/v1/completions", func(c *gin.Context) {
66-
c.Header("Content-Type", "text/plain")
67-
c.String(200, *responseMessage)
69+
c.Header("Content-Type", "application/json")
70+
c.JSON(http.StatusOK, gin.H{
71+
"responseMessage": *responseMessage,
72+
})
73+
6874
})
6975

7076
// issue #41
@@ -104,6 +110,10 @@ func main() {
104110
c.JSON(http.StatusOK, gin.H{
105111
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
106112
"model": model,
113+
114+
// expose some header values for testing
115+
"h_content_type": c.GetHeader("Content-Type"),
116+
"h_content_length": c.GetHeader("Content-Length"),
107117
})
108118
})
109119

proxy/proxymanager.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,6 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
371371

372372
// dechunk it as we already have all the body bytes see issue #11
373373
c.Request.Header.Del("transfer-encoding")
374-
c.Request.Header.Del("content-length")
375374
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
376375

377376
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
@@ -382,11 +381,6 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
382381
}
383382

384383
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
385-
// We need to reconstruct the multipart form in any case since the body is consumed
386-
// Create a new buffer for the reconstructed request
387-
var requestBuffer bytes.Buffer
388-
multipartWriter := multipart.NewWriter(&requestBuffer)
389-
390384
// Parse multipart form
391385
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
392386
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
@@ -406,6 +400,11 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
406400
return
407401
}
408402

403+
// We need to reconstruct the multipart form in any case since the body is consumed
404+
// Create a new buffer for the reconstructed request
405+
var requestBuffer bytes.Buffer
406+
multipartWriter := multipart.NewWriter(&requestBuffer)
407+
409408
// Copy all form values
410409
for key, values := range c.Request.MultipartForm.Value {
411410
for _, value := range values {
@@ -479,6 +478,10 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
479478
modifiedReq.Header = c.Request.Header.Clone()
480479
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
481480

481+
// set the content length of the body
482+
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
483+
modifiedReq.ContentLength = int64(requestBuffer.Len())
484+
482485
// Use the modified request for proxying
483486
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
484487
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))

proxy/proxymanager_test.go

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"mime/multipart"
99
"net/http"
1010
"net/http/httptest"
11+
"strconv"
1112
"sync"
1213
"testing"
1314
"time"
@@ -165,7 +166,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
165166

166167
mu.Lock()
167168

168-
results[key] = w.Body.String()
169+
var response map[string]string
170+
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
171+
results[key] = response["responseMessage"]
169172
mu.Unlock()
170173
}(key)
171174

@@ -442,6 +445,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
442445
assert.NoError(t, err)
443446
assert.Equal(t, "TheExpectedModel", response["model"])
444447
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
448+
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
445449
}
446450

447451
// Test useModelName in configuration sends overrides what is sent to upstream
@@ -592,3 +596,27 @@ func TestProxyManager_Upstream(t *testing.T) {
592596
assert.Equal(t, http.StatusOK, rec.Code)
593597
assert.Equal(t, "model1", rec.Body.String())
594598
}
599+
600+
func TestProxyManager_ChatContentLength(t *testing.T) {
601+
config := AddDefaultGroupToConfig(Config{
602+
HealthCheckTimeout: 15,
603+
Models: map[string]ModelConfig{
604+
"model1": getTestSimpleResponderConfig("model1"),
605+
},
606+
LogLevel: "error",
607+
})
608+
609+
proxy := New(config)
610+
defer proxy.StopProcesses()
611+
612+
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
613+
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
614+
w := httptest.NewRecorder()
615+
616+
proxy.HandlerFunc(w, req)
617+
assert.Equal(t, http.StatusOK, w.Code)
618+
var response map[string]string
619+
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
620+
assert.Equal(t, "81", response["h_content_length"])
621+
assert.Equal(t, "model1", response["responseMessage"])
622+
}

0 commit comments

Comments
 (0)