Skip to content

Commit d225558

Browse files
committed
feat: improve error handling with added status codes and headers
- Updated the Execute methods so that error handling extracts `StatusCode` and `Headers` from the returned error.
- Introduced structured error responses for cool-down scenarios, providing additional metadata and retry suggestions.
- Refined quota management to differentiate between cool-down, disabled, and other block reasons.
- Improved model filtering logic based on client availability and suspension criteria.
1 parent 9678be7 commit d225558

File tree

3 files changed

+211
-32
lines changed

3 files changed

+211
-32
lines changed

internal/registry/model_registry.go

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -352,14 +352,14 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
352352
if model == nil {
353353
return nil
354354
}
355-
copy := *model
355+
copyModel := *model
356356
if len(model.SupportedGenerationMethods) > 0 {
357-
copy.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
357+
copyModel.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
358358
}
359359
if len(model.SupportedParameters) > 0 {
360-
copy.SupportedParameters = append([]string(nil), model.SupportedParameters...)
360+
copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...)
361361
}
362-
return &copy
362+
return &copyModel
363363
}
364364

365365
// UnregisterClient removes a client and decrements counts for its models
@@ -532,17 +532,25 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
532532
}
533533
}
534534

535-
suspendedClients := 0
535+
cooldownSuspended := 0
536+
otherSuspended := 0
536537
if registration.SuspendedClients != nil {
537-
suspendedClients = len(registration.SuspendedClients)
538+
for _, reason := range registration.SuspendedClients {
539+
if strings.EqualFold(reason, "quota") {
540+
cooldownSuspended++
541+
continue
542+
}
543+
otherSuspended++
544+
}
538545
}
539-
effectiveClients := availableClients - expiredClients - suspendedClients
546+
547+
effectiveClients := availableClients - expiredClients - otherSuspended
540548
if effectiveClients < 0 {
541549
effectiveClients = 0
542550
}
543551

544-
// Only include models that have available clients
545-
if effectiveClients > 0 {
552+
// Include models that have available clients, or those solely cooling down.
553+
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
546554
model := r.convertModelToMap(registration.Info, handlerType)
547555
if model != nil {
548556
models = append(models, model)

sdk/api/handlers/handlers.go

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,19 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
156156
}
157157
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
158158
if err != nil {
159-
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err}
159+
status := http.StatusInternalServerError
160+
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
161+
if code := se.StatusCode(); code > 0 {
162+
status = code
163+
}
164+
}
165+
var addon http.Header
166+
if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil {
167+
if hdr := he.Headers(); hdr != nil {
168+
addon = hdr.Clone()
169+
}
170+
}
171+
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
160172
}
161173
return cloneBytes(resp.Payload), nil
162174
}
@@ -187,7 +199,19 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
187199
}
188200
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
189201
if err != nil {
190-
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err}
202+
status := http.StatusInternalServerError
203+
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
204+
if code := se.StatusCode(); code > 0 {
205+
status = code
206+
}
207+
}
208+
var addon http.Header
209+
if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil {
210+
if hdr := he.Headers(); hdr != nil {
211+
addon = hdr.Clone()
212+
}
213+
}
214+
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
191215
}
192216
return cloneBytes(resp.Payload), nil
193217
}
@@ -222,7 +246,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
222246
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
223247
if err != nil {
224248
errChan := make(chan *interfaces.ErrorMessage, 1)
225-
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err}
249+
status := http.StatusInternalServerError
250+
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
251+
if code := se.StatusCode(); code > 0 {
252+
status = code
253+
}
254+
}
255+
var addon http.Header
256+
if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil {
257+
if hdr := he.Headers(); hdr != nil {
258+
addon = hdr.Clone()
259+
}
260+
}
261+
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
226262
close(errChan)
227263
return nil, errChan
228264
}
@@ -233,7 +269,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
233269
defer close(errChan)
234270
for chunk := range chunks {
235271
if chunk.Err != nil {
236-
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: chunk.Err}
272+
status := http.StatusInternalServerError
273+
if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != nil {
274+
if code := se.StatusCode(); code > 0 {
275+
status = code
276+
}
277+
}
278+
var addon http.Header
279+
if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil {
280+
if hdr := he.Headers(); hdr != nil {
281+
addon = hdr.Clone()
282+
}
283+
}
284+
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon}
237285
return
238286
}
239287
if len(chunk.Payload) > 0 {
@@ -287,6 +335,17 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
287335
if msg != nil && msg.StatusCode > 0 {
288336
status = msg.StatusCode
289337
}
338+
if msg != nil && msg.Addon != nil {
339+
for key, values := range msg.Addon {
340+
if len(values) == 0 {
341+
continue
342+
}
343+
c.Writer.Header().Del(key)
344+
for _, value := range values {
345+
c.Writer.Header().Add(key, value)
346+
}
347+
}
348+
}
290349
c.Status(status)
291350
if msg != nil && msg.Error != nil {
292351
_, _ = c.Writer.Write([]byte(msg.Error.Error()))

sdk/cliproxy/auth/selector.go

Lines changed: 131 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@ package auth
22

33
import (
44
"context"
5+
"encoding/json"
6+
"fmt"
7+
"math"
8+
"net/http"
59
"sort"
10+
"strconv"
611
"sync"
712
"time"
813

@@ -15,6 +20,84 @@ type RoundRobinSelector struct {
1520
cursors map[string]int
1621
}
1722

23+
// blockReason classifies why isAuthBlockedForModel reported an auth as
// blocked (or that it is not blocked at all).
type blockReason int
24+
25+
// Possible blockReason values, as returned by isAuthBlockedForModel.
const (
26+
	// blockReasonNone: the auth is not blocked.
blockReasonNone blockReason = iota
27+
	// blockReasonCooldown: the auth is unavailable until a future retry time
	// and its quota is marked as exceeded.
blockReasonCooldown
28+
	// blockReasonDisabled: the auth, or its state for the requested model,
	// is disabled.
blockReasonDisabled
29+
	// blockReasonOther: blocked for any other reason (nil auth, or
	// unavailable until a future retry time without quota being exceeded).
blockReasonOther
30+
)
31+
32+
// modelCooldownError is returned by the selector when every candidate auth
// for a model is cooling down. It carries enough information to build a
// structured 429 response: the model, the provider and the time remaining
// until the earliest credential recovers.
type modelCooldownError struct {
	model    string
	resetIn  time.Duration
	provider string
}

// newModelCooldownError builds a modelCooldownError for the given model and
// provider. A negative resetIn is clamped to zero.
func newModelCooldownError(model, provider string, resetIn time.Duration) *modelCooldownError {
	if resetIn < 0 {
		resetIn = 0
	}
	return &modelCooldownError{
		model:    model,
		provider: provider,
		resetIn:  resetIn,
	}
}

// retrySeconds returns the remaining cool-down rounded UP to whole seconds
// and never negative. It is the single source for every place the wait is
// reported (JSON body and Retry-After header) so they cannot disagree.
func (e *modelCooldownError) retrySeconds() int {
	seconds := int(math.Ceil(e.resetIn.Seconds()))
	if seconds < 0 {
		return 0
	}
	return seconds
}

// Error renders the error as a JSON document of the form
//
//	{"error":{"code":"model_cooldown","message":...,"model":...,
//	          "reset_time":...,"reset_seconds":...[,"provider":...]}}
//
// reset_time is the human-readable form of reset_seconds; both are derived
// from the same rounded-up value so a client never sees e.g. "1s" next to 2.
func (e *modelCooldownError) Error() string {
	modelName := e.model
	if modelName == "" {
		modelName = "requested model"
	}
	message := fmt.Sprintf("All credentials for model %s are cooling down", modelName)
	if e.provider != "" {
		message = fmt.Sprintf("%s via provider %s", message, e.provider)
	}

	resetSeconds := e.retrySeconds()
	errorBody := map[string]any{
		"code":          "model_cooldown",
		"message":       message,
		"model":         e.model,
		"reset_time":    (time.Duration(resetSeconds) * time.Second).String(),
		"reset_seconds": resetSeconds,
	}
	if e.provider != "" {
		errorBody["provider"] = e.provider
	}

	data, err := json.Marshal(map[string]any{"error": errorBody})
	if err != nil {
		// Not expected for a map of strings and ints. %q (rather than a raw
		// %s) keeps quotes and backslashes in the message escaped so the
		// fallback body stays well-formed for ordinary model names.
		return fmt.Sprintf(`{"error":{"code":"model_cooldown","message":%q}}`, message)
	}
	return string(data)
}

// StatusCode reports the HTTP status to use for this error: 429 Too Many Requests.
func (e *modelCooldownError) StatusCode() int {
	return http.StatusTooManyRequests
}

// Headers returns the response headers that accompany this error:
// Content-Type is application/json because Error() yields a JSON body, and
// Retry-After carries the rounded-up number of seconds until reset.
func (e *modelCooldownError) Headers() http.Header {
	headers := make(http.Header)
	headers.Set("Content-Type", "application/json")
	headers.Set("Retry-After", strconv.Itoa(e.retrySeconds()))
	return headers
}
100+
18101
// Pick selects the next available auth for the provider in a round-robin manner.
19102
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
20103
_ = ctx
@@ -27,14 +110,30 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
27110
}
28111
available := make([]*Auth, 0, len(auths))
29112
now := time.Now()
113+
cooldownCount := 0
114+
var earliest time.Time
30115
for i := 0; i < len(auths); i++ {
31116
candidate := auths[i]
32-
if isAuthBlockedForModel(candidate, model, now) {
117+
blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
118+
if !blocked {
119+
available = append(available, candidate)
33120
continue
34121
}
35-
available = append(available, candidate)
122+
if reason == blockReasonCooldown {
123+
cooldownCount++
124+
if !next.IsZero() && (earliest.IsZero() || next.Before(earliest)) {
125+
earliest = next
126+
}
127+
}
36128
}
37129
if len(available) == 0 {
130+
if cooldownCount == len(auths) && !earliest.IsZero() {
131+
resetIn := earliest.Sub(now)
132+
if resetIn < 0 {
133+
resetIn = 0
134+
}
135+
return nil, newModelCooldownError(model, provider, resetIn)
136+
}
38137
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
39138
}
40139
// Make round-robin deterministic even if caller's candidate order is unstable.
@@ -55,41 +154,54 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
55154
return available[index%len(available)], nil
56155
}
57156

58-
func isAuthBlockedForModel(auth *Auth, model string, now time.Time) bool {
157+
func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, blockReason, time.Time) {
59158
if auth == nil {
60-
return true
159+
return true, blockReasonOther, time.Time{}
61160
}
62161
if auth.Disabled || auth.Status == StatusDisabled {
63-
return true
162+
return true, blockReasonDisabled, time.Time{}
64163
}
65-
// If a specific model is requested, prefer its per-model state over any aggregated
66-
// auth-level unavailable flag. This prevents a failure on one model (e.g., 429 quota)
67-
// from blocking other models of the same provider that have no errors.
68164
if model != "" {
69165
if len(auth.ModelStates) > 0 {
70166
if state, ok := auth.ModelStates[model]; ok && state != nil {
71167
if state.Status == StatusDisabled {
72-
return true
168+
return true, blockReasonDisabled, time.Time{}
73169
}
74170
if state.Unavailable {
75171
if state.NextRetryAfter.IsZero() {
76-
return false
172+
return false, blockReasonNone, time.Time{}
77173
}
78174
if state.NextRetryAfter.After(now) {
79-
return true
175+
next := state.NextRetryAfter
176+
if !state.Quota.NextRecoverAt.IsZero() && state.Quota.NextRecoverAt.After(now) {
177+
next = state.Quota.NextRecoverAt
178+
}
179+
if next.Before(now) {
180+
next = now
181+
}
182+
if state.Quota.Exceeded {
183+
return true, blockReasonCooldown, next
184+
}
185+
return true, blockReasonOther, next
80186
}
81187
}
82-
// Explicit state exists and is not blocking.
83-
return false
188+
return false, blockReasonNone, time.Time{}
84189
}
85190
}
86-
// No explicit state for this model; do not block based on aggregated
87-
// auth-level unavailable status. Allow trying this model.
88-
return false
191+
return false, blockReasonNone, time.Time{}
89192
}
90-
// No specific model context: fall back to auth-level unavailable window.
91193
if auth.Unavailable && auth.NextRetryAfter.After(now) {
92-
return true
194+
next := auth.NextRetryAfter
195+
if !auth.Quota.NextRecoverAt.IsZero() && auth.Quota.NextRecoverAt.After(now) {
196+
next = auth.Quota.NextRecoverAt
197+
}
198+
if next.Before(now) {
199+
next = now
200+
}
201+
if auth.Quota.Exceeded {
202+
return true, blockReasonCooldown, next
203+
}
204+
return true, blockReasonOther, next
93205
}
94-
return false
206+
return false, blockReasonNone, time.Time{}
95207
}

0 commit comments

Comments
 (0)