Skip to content

Commit 8103b4b

Browse files
authored
Merge pull request #3009 from seefs001/feature/improve-param-override
feat: improve channel override ui/ux
2 parents 30061f3 + a53e139 commit 8103b4b

26 files changed

+6174
-255
lines changed

controller/channel-test.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
366366
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
367367
}
368368
}
369-
jsonData, err := json.Marshal(convertedRequest)
369+
jsonData, err := common.Marshal(convertedRequest)
370370
if err != nil {
371371
return testResult{
372372
context: c,
@@ -385,8 +385,15 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
385385
//}
386386

387387
if len(info.ParamOverride) > 0 {
388-
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
388+
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
389389
if err != nil {
390+
if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
391+
return testResult{
392+
context: c,
393+
localErr: fixedErr,
394+
newAPIError: relaycommon.NewAPIErrorFromParamOverride(fixedErr),
395+
}
396+
}
390397
return testResult{
391398
context: c,
392399
localErr: err,

controller/relay.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
182182
ModelName: relayInfo.OriginModelName,
183183
Retry: common.GetPointer(0),
184184
}
185+
relayInfo.RetryIndex = 0
186+
relayInfo.LastError = nil
185187

186188
for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
189+
relayInfo.RetryIndex = retryParam.GetRetry()
187190
channel, channelErr := getChannel(c, relayInfo, retryParam)
188191
if channelErr != nil {
189192
logger.LogError(c, channelErr.Error())
@@ -216,10 +219,12 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
216219
}
217220

218221
if newAPIError == nil {
222+
relayInfo.LastError = nil
219223
return
220224
}
221225

222226
newAPIError = service.NormalizeViolationFeeError(newAPIError)
227+
relayInfo.LastError = newAPIError
223228

224229
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
225230

middleware/distributor.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
348348
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
349349
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
350350
common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings())
351-
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
352-
common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, channel.GetHeaderOverride())
351+
paramOverride := channel.GetParamOverride()
352+
headerOverride := channel.GetHeaderOverride()
353+
if mergedParam, applied := service.ApplyChannelAffinityOverrideTemplate(c, paramOverride); applied {
354+
paramOverride = mergedParam
355+
}
356+
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, paramOverride)
357+
common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, headerOverride)
353358
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
354359
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
355360
}

relay/channel/api_request.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,12 +169,17 @@ func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey str
169169
// Passthrough rules are applied first, then normal overrides are applied, so explicit overrides win.
170170
func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
171171
headerOverride := make(map[string]string)
172+
if info == nil {
173+
return headerOverride, nil
174+
}
175+
176+
headerOverrideSource := common.GetEffectiveHeaderOverride(info)
172177

173178
passAll := false
174179
var passthroughRegex []*regexp.Regexp
175180
if !info.IsChannelTest {
176-
for k := range info.HeadersOverride {
177-
key := strings.TrimSpace(k)
181+
for k := range headerOverrideSource {
182+
key := strings.TrimSpace(strings.ToLower(k))
178183
if key == "" {
179184
continue
180185
}
@@ -183,12 +188,11 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
183188
continue
184189
}
185190

186-
lower := strings.ToLower(key)
187191
var pattern string
188192
switch {
189-
case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
193+
case strings.HasPrefix(key, headerPassthroughRegexPrefix):
190194
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
191-
case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
195+
case strings.HasPrefix(key, headerPassthroughRegexPrefixV2):
192196
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
193197
default:
194198
continue
@@ -229,15 +233,15 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
229233
if value == "" {
230234
continue
231235
}
232-
headerOverride[name] = value
236+
headerOverride[strings.ToLower(strings.TrimSpace(name))] = value
233237
}
234238
}
235239

236-
for k, v := range info.HeadersOverride {
240+
for k, v := range headerOverrideSource {
237241
if isHeaderPassthroughRuleKey(k) {
238242
continue
239243
}
240-
key := strings.TrimSpace(k)
244+
key := strings.TrimSpace(strings.ToLower(k))
241245
if key == "" {
242246
continue
243247
}

relay/channel/api_request_test.go

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testin
5353

5454
headers, err := processHeaderOverride(info, ctx)
5555
require.NoError(t, err)
56-
_, ok := headers["X-Upstream-Trace"]
56+
_, ok := headers["x-upstream-trace"]
5757
require.False(t, ok)
5858
}
5959

@@ -77,7 +77,38 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T)
7777

7878
headers, err := processHeaderOverride(info, ctx)
7979
require.NoError(t, err)
80-
require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
80+
require.Equal(t, "trace-123", headers["x-upstream-trace"])
81+
}
82+
83+
func TestProcessHeaderOverride_RuntimeOverrideIsFinalHeaderMap(t *testing.T) {
84+
t.Parallel()
85+
86+
gin.SetMode(gin.TestMode)
87+
recorder := httptest.NewRecorder()
88+
ctx, _ := gin.CreateTestContext(recorder)
89+
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
90+
91+
info := &relaycommon.RelayInfo{
92+
IsChannelTest: false,
93+
UseRuntimeHeadersOverride: true,
94+
RuntimeHeadersOverride: map[string]any{
95+
"x-static": "runtime-value",
96+
"x-runtime": "runtime-only",
97+
},
98+
ChannelMeta: &relaycommon.ChannelMeta{
99+
HeadersOverride: map[string]any{
100+
"X-Static": "legacy-value",
101+
"X-Legacy": "legacy-only",
102+
},
103+
},
104+
}
105+
106+
headers, err := processHeaderOverride(info, ctx)
107+
require.NoError(t, err)
108+
require.Equal(t, "runtime-value", headers["x-static"])
109+
require.Equal(t, "runtime-only", headers["x-runtime"])
110+
_, exists := headers["x-legacy"]
111+
require.False(t, exists)
81112
}
82113

83114
func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
@@ -101,8 +132,62 @@ func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
101132

102133
headers, err := processHeaderOverride(info, ctx)
103134
require.NoError(t, err)
104-
require.Equal(t, "trace-123", headers["X-Trace-Id"])
135+
require.Equal(t, "trace-123", headers["x-trace-id"])
105136

106-
_, hasAcceptEncoding := headers["Accept-Encoding"]
137+
_, hasAcceptEncoding := headers["accept-encoding"]
107138
require.False(t, hasAcceptEncoding)
108139
}
140+
141+
func TestProcessHeaderOverride_PassHeadersTemplateSetsRuntimeHeaders(t *testing.T) {
142+
t.Parallel()
143+
144+
gin.SetMode(gin.TestMode)
145+
recorder := httptest.NewRecorder()
146+
ctx, _ := gin.CreateTestContext(recorder)
147+
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
148+
ctx.Request.Header.Set("Originator", "Codex CLI")
149+
ctx.Request.Header.Set("Session_id", "sess-123")
150+
151+
info := &relaycommon.RelayInfo{
152+
IsChannelTest: false,
153+
RequestHeaders: map[string]string{
154+
"Originator": "Codex CLI",
155+
"Session_id": "sess-123",
156+
},
157+
ChannelMeta: &relaycommon.ChannelMeta{
158+
ParamOverride: map[string]any{
159+
"operations": []any{
160+
map[string]any{
161+
"mode": "pass_headers",
162+
"value": []any{"Originator", "Session_id", "X-Codex-Beta-Features"},
163+
},
164+
},
165+
},
166+
HeadersOverride: map[string]any{
167+
"X-Static": "legacy-value",
168+
},
169+
},
170+
}
171+
172+
_, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-4.1"}`), info)
173+
require.NoError(t, err)
174+
require.True(t, info.UseRuntimeHeadersOverride)
175+
require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["originator"])
176+
require.Equal(t, "sess-123", info.RuntimeHeadersOverride["session_id"])
177+
_, exists := info.RuntimeHeadersOverride["x-codex-beta-features"]
178+
require.False(t, exists)
179+
require.Equal(t, "legacy-value", info.RuntimeHeadersOverride["x-static"])
180+
181+
headers, err := processHeaderOverride(info, ctx)
182+
require.NoError(t, err)
183+
require.Equal(t, "Codex CLI", headers["originator"])
184+
require.Equal(t, "sess-123", headers["session_id"])
185+
_, exists = headers["x-codex-beta-features"]
186+
require.False(t, exists)
187+
188+
upstreamReq := httptest.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil)
189+
applyHeaderOverrideToRequest(upstreamReq, headers)
190+
require.Equal(t, "Codex CLI", upstreamReq.Header.Get("Originator"))
191+
require.Equal(t, "sess-123", upstreamReq.Header.Get("Session_id"))
192+
require.Empty(t, upstreamReq.Header.Get("X-Codex-Beta-Features"))
193+
}

relay/chat_completions_via_responses.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ func applySystemPromptIfNeeded(c *gin.Context, info *relaycommon.RelayInfo, requ
7070
}
7171

7272
func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, adaptor channel.Adaptor, request *dto.GeneralOpenAIRequest) (*dto.Usage, *types.NewAPIError) {
73-
overrideCtx := relaycommon.BuildParamOverrideContext(info)
7473
chatJSON, err := common.Marshal(request)
7574
if err != nil {
7675
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
@@ -82,9 +81,9 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
8281
}
8382

8483
if len(info.ParamOverride) > 0 {
85-
chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx)
84+
chatJSON, err = relaycommon.ApplyParamOverrideWithRelayInfo(chatJSON, info)
8685
if err != nil {
87-
return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
86+
return nil, newAPIErrorFromParamOverride(err)
8887
}
8988
}
9089

relay/claude_handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,9 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
153153

154154
// apply param override
155155
if len(info.ParamOverride) > 0 {
156-
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
156+
jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
157157
if err != nil {
158-
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
158+
return newAPIErrorFromParamOverride(err)
159159
}
160160
}
161161

0 commit comments

Comments
 (0)