Skip to content

Commit 197b89e

Browse files
committed
feat: refactor request body handling to use BodyStorage for improved efficiency
1 parent 75e533e commit 197b89e

File tree

12 files changed

+69
-66
lines changed

12 files changed

+69
-66
lines changed

common/body_storage.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,12 @@ func CreateBodyStorageFromReader(reader io.Reader, contentLength int64, maxBytes
302302
return storage, nil
303303
}
304304

305+
// ReaderOnly wraps an io.Reader to hide io.Closer, preventing http.NewRequest
306+
// from type-asserting io.ReadCloser and closing the underlying BodyStorage.
307+
func ReaderOnly(r io.Reader) io.Reader {
308+
return struct{ io.Reader }{r}
309+
}
310+
305311
// CleanupOldCacheFiles 清理旧的缓存文件(用于启动时清理残留)
306312
func CleanupOldCacheFiles() {
307313
// 使用统一的缓存管理

common/gin.go

Lines changed: 32 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,27 @@ func IsRequestBodyTooLargeError(err error) bool {
3333
return errors.As(err, &mbe)
3434
}
3535

36-
func GetRequestBody(c *gin.Context) ([]byte, error) {
36+
func GetRequestBody(c *gin.Context) (io.Seeker, error) {
3737
// 首先检查是否有 BodyStorage 缓存
3838
if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil {
3939
if bs, ok := storage.(BodyStorage); ok {
4040
if _, err := bs.Seek(0, io.SeekStart); err != nil {
4141
return nil, fmt.Errorf("failed to seek body storage: %w", err)
4242
}
43-
return bs.Bytes()
43+
return bs, nil
4444
}
4545
}
4646

4747
// 检查旧的缓存方式
4848
cached, exists := c.Get(KeyRequestBody)
4949
if exists && cached != nil {
5050
if b, ok := cached.([]byte); ok {
51-
return b, nil
51+
bs, err := CreateBodyStorage(b)
52+
if err != nil {
53+
return nil, err
54+
}
55+
c.Set(KeyBodyStorage, bs)
56+
return bs, nil
5257
}
5358
}
5459

@@ -74,47 +79,20 @@ func GetRequestBody(c *gin.Context) ([]byte, error) {
7479
// 缓存存储对象
7580
c.Set(KeyBodyStorage, storage)
7681

77-
// 获取字节数据
78-
body, err := storage.Bytes()
79-
if err != nil {
80-
return nil, err
81-
}
82-
83-
// 同时设置旧的缓存键以保持兼容性
84-
c.Set(KeyRequestBody, body)
85-
86-
return body, nil
82+
return storage, nil
8783
}
8884

8985
// GetBodyStorage 获取请求体存储对象(用于需要多次读取的场景)
9086
func GetBodyStorage(c *gin.Context) (BodyStorage, error) {
91-
// 检查是否已有存储
92-
if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil {
93-
if bs, ok := storage.(BodyStorage); ok {
94-
if _, err := bs.Seek(0, io.SeekStart); err != nil {
95-
return nil, fmt.Errorf("failed to seek body storage: %w", err)
96-
}
97-
return bs, nil
98-
}
99-
}
100-
101-
// 如果没有,调用 GetRequestBody 创建存储
102-
_, err := GetRequestBody(c)
87+
seeker, err := GetRequestBody(c)
10388
if err != nil {
10489
return nil, err
10590
}
106-
107-
// 再次获取存储
108-
if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil {
109-
if bs, ok := storage.(BodyStorage); ok {
110-
if _, err := bs.Seek(0, io.SeekStart); err != nil {
111-
return nil, fmt.Errorf("failed to seek body storage: %w", err)
112-
}
113-
return bs, nil
114-
}
91+
bs, ok := seeker.(BodyStorage)
92+
if !ok {
93+
return nil, errors.New("unexpected body storage type")
11594
}
116-
117-
return nil, errors.New("failed to get body storage")
95+
return bs, nil
11896
}
11997

12098
// CleanupBodyStorage 清理请求体存储(应在请求结束时调用)
@@ -128,13 +106,14 @@ func CleanupBodyStorage(c *gin.Context) {
128106
}
129107

130108
func UnmarshalBodyReusable(c *gin.Context, v any) error {
131-
requestBody, err := GetRequestBody(c)
109+
storage, err := GetBodyStorage(c)
110+
if err != nil {
111+
return err
112+
}
113+
requestBody, err := storage.Bytes()
132114
if err != nil {
133115
return err
134116
}
135-
//if DebugEnabled {
136-
// println("UnmarshalBodyReusable request body:", string(requestBody))
137-
//}
138117
contentType := c.Request.Header.Get("Content-Type")
139118
if strings.HasPrefix(contentType, "application/json") {
140119
err = Unmarshal(requestBody, v)
@@ -150,7 +129,10 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
150129
return err
151130
}
152131
// Reset request body
153-
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
132+
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
133+
return seekErr
134+
}
135+
c.Request.Body = io.NopCloser(storage)
154136
return nil
155137
}
156138

@@ -252,7 +234,11 @@ func init() {
252234
}
253235

254236
func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
255-
requestBody, err := GetRequestBody(c)
237+
storage, err := GetBodyStorage(c)
238+
if err != nil {
239+
return nil, err
240+
}
241+
requestBody, err := storage.Bytes()
256242
if err != nil {
257243
return nil, err
258244
}
@@ -270,7 +256,10 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
270256
}
271257

272258
// Reset request body
273-
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
259+
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
260+
return nil, seekErr
261+
}
262+
c.Request.Body = io.NopCloser(storage)
274263
return form, nil
275264
}
276265

controller/relay.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package controller
22

33
import (
4-
"bytes"
54
"errors"
65
"fmt"
76
"io"
@@ -193,7 +192,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
193192
}
194193

195194
addUsedChannel(c, channel.Id)
196-
requestBody, bodyErr := common.GetRequestBody(c)
195+
bodyStorage, bodyErr := common.GetBodyStorage(c)
197196
if bodyErr != nil {
198197
// Ensure consistent 413 for oversized bodies even when error occurs later (e.g., retry path)
199198
if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
@@ -203,7 +202,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
203202
}
204203
break
205204
}
206-
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
205+
c.Request.Body = io.NopCloser(bodyStorage)
207206

208207
switch relayFormat {
209208
case types.RelayFormatOpenAIRealtime:
@@ -483,7 +482,7 @@ func RelayTask(c *gin.Context) {
483482
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
484483
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
485484

486-
requestBody, err := common.GetRequestBody(c)
485+
bodyStorage, err := common.GetBodyStorage(c)
487486
if err != nil {
488487
if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
489488
taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge)
@@ -492,7 +491,7 @@ func RelayTask(c *gin.Context) {
492491
}
493492
break
494493
}
495-
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
494+
c.Request.Body = io.NopCloser(bodyStorage)
496495
taskErr = taskRelayHandler(c, relayInfo)
497496
}
498497
useChannel := c.GetStringSlice("use_channel")

relay/channel/aws/relay-aws.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,14 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
165165
// buildAwsRequestBody prepares the payload for AWS requests, applying passthrough rules when enabled.
166166
func buildAwsRequestBody(c *gin.Context, info *relaycommon.RelayInfo, awsClaudeReq any) ([]byte, error) {
167167
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
168-
body, err := common.GetRequestBody(c)
168+
storage, err := common.GetBodyStorage(c)
169169
if err != nil {
170170
return nil, errors.Wrap(err, "get request body for pass-through fail")
171171
}
172+
body, err := storage.Bytes()
173+
if err != nil {
174+
return nil, errors.Wrap(err, "get request body bytes fail")
175+
}
172176
var data map[string]interface{}
173177
if err := common.Unmarshal(body, &data); err != nil {
174178
return nil, errors.Wrap(err, "pass-through unmarshal request body fail")

relay/channel/task/sora/adaptor.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package sora
22

33
import (
4-
"bytes"
54
"fmt"
65
"io"
76
"net/http"
@@ -104,11 +103,11 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
104103
}
105104

106105
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
107-
cachedBody, err := common.GetRequestBody(c)
106+
storage, err := common.GetBodyStorage(c)
108107
if err != nil {
109108
return nil, errors.Wrap(err, "get_request_body_failed")
110109
}
111-
return bytes.NewReader(cachedBody), nil
110+
return common.ReaderOnly(storage), nil
112111
}
113112

114113
// DoRequest delegates to common helper.

relay/claude_handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,11 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
129129

130130
var requestBody io.Reader
131131
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
132-
body, err := common.GetRequestBody(c)
132+
storage, err := common.GetBodyStorage(c)
133133
if err != nil {
134134
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
135135
}
136-
requestBody = bytes.NewBuffer(body)
136+
requestBody = common.ReaderOnly(storage)
137137
} else {
138138
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request)
139139
if err != nil {

relay/compatible_handler.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,16 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
100100
var requestBody io.Reader
101101

102102
if passThroughGlobal || info.ChannelSetting.PassThroughBodyEnabled {
103-
body, err := common.GetRequestBody(c)
103+
storage, err := common.GetBodyStorage(c)
104104
if err != nil {
105105
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
106106
}
107107
if common.DebugEnabled {
108-
println("requestBody: ", string(body))
108+
if debugBytes, bErr := storage.Bytes(); bErr == nil {
109+
println("requestBody: ", string(debugBytes))
110+
}
109111
}
110-
requestBody = bytes.NewBuffer(body)
112+
requestBody = common.ReaderOnly(storage)
111113
} else {
112114
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
113115
if err != nil {

relay/gemini_handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,11 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
138138

139139
var requestBody io.Reader
140140
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
141-
body, err := common.GetRequestBody(c)
141+
storage, err := common.GetBodyStorage(c)
142142
if err != nil {
143143
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
144144
}
145-
requestBody = bytes.NewReader(body)
145+
requestBody = common.ReaderOnly(storage)
146146
} else {
147147
// 使用 ConvertGeminiRequest 转换请求格式
148148
convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request)

relay/image_handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
4747
var requestBody io.Reader
4848

4949
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
50-
body, err := common.GetRequestBody(c)
50+
storage, err := common.GetBodyStorage(c)
5151
if err != nil {
5252
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
5353
}
54-
requestBody = bytes.NewBuffer(body)
54+
requestBody = common.ReaderOnly(storage)
5555
} else {
5656
convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request)
5757
if err != nil {

relay/rerank_handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
4343

4444
var requestBody io.Reader
4545
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
46-
body, err := common.GetRequestBody(c)
46+
storage, err := common.GetBodyStorage(c)
4747
if err != nil {
4848
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
4949
}
50-
requestBody = bytes.NewBuffer(body)
50+
requestBody = common.ReaderOnly(storage)
5151
} else {
5252
convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *request)
5353
if err != nil {

0 commit comments

Comments
 (0)