Skip to content

Commit c32e013

Browse files
committed
feat(aistudio): track Gemini usage and improve stream errors
1 parent 3839d93 commit c32e013

File tree

3 files changed

+52
-22
lines changed

3 files changed

+52
-22
lines changed

internal/runtime/executor/aistudio_executor.go

Lines changed: 39 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ import (
1414
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
1515
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
1616
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
17+
"github.com/tidwall/gjson"
1718
"github.com/tidwall/sjson"
1819
)
1920

@@ -37,10 +38,13 @@ func (e *AistudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth)
3738
return nil
3839
}
3940

40-
func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
41+
func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
42+
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
43+
defer reporter.trackFailure(ctx, &err)
44+
4145
translatedReq, body, err := e.translateRequest(req, opts, false)
4246
if err != nil {
43-
return cliproxyexecutor.Response{}, err
47+
return resp, err
4448
}
4549
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
4650
wsReq := &wsrelay.HTTPRequest{
@@ -68,24 +72,29 @@ func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
6872
AuthValue: authValue,
6973
})
7074

71-
resp, err := e.relay.RoundTrip(ctx, e.provider, wsReq)
75+
wsResp, err := e.relay.RoundTrip(ctx, e.provider, wsReq)
7276
if err != nil {
7377
recordAPIResponseError(ctx, e.cfg, err)
74-
return cliproxyexecutor.Response{}, err
78+
return resp, err
7579
}
76-
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
77-
if len(resp.Body) > 0 {
78-
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body))
80+
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
81+
if len(wsResp.Body) > 0 {
82+
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(wsResp.Body))
7983
}
80-
if resp.Status < 200 || resp.Status >= 300 {
81-
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
84+
if wsResp.Status < 200 || wsResp.Status >= 300 {
85+
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
8286
}
87+
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
8388
var param any
84-
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), &param)
85-
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
89+
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), &param)
90+
resp = cliproxyexecutor.Response{Payload: []byte(out)}
91+
return resp, nil
8692
}
8793

88-
func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
94+
func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
95+
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
96+
defer reporter.trackFailure(ctx, &err)
97+
8998
translatedReq, body, err := e.translateRequest(req, opts, true)
9099
if err != nil {
91100
return nil, err
@@ -114,20 +123,22 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
114123
AuthType: authType,
115124
AuthValue: authValue,
116125
})
117-
stream, err := e.relay.Stream(ctx, e.provider, wsReq)
126+
wsStream, err := e.relay.Stream(ctx, e.provider, wsReq)
118127
if err != nil {
119128
recordAPIResponseError(ctx, e.cfg, err)
120129
return nil, err
121130
}
122131
out := make(chan cliproxyexecutor.StreamChunk)
132+
stream = out
123133
go func() {
124134
defer close(out)
125135
var param any
126136
metadataLogged := false
127-
for event := range stream {
137+
for event := range wsStream {
128138
if event.Err != nil {
129139
recordAPIResponseError(ctx, e.cfg, event.Err)
130-
out <- cliproxyexecutor.StreamChunk{Err: event.Err}
140+
reporter.publishFailure(ctx)
141+
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
131142
return
132143
}
133144
switch event.Type {
@@ -139,6 +150,9 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
139150
case wsrelay.MessageTypeStreamChunk:
140151
if len(event.Payload) > 0 {
141152
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
153+
if detail, ok := parseGeminiStreamUsage(event.Payload); ok {
154+
reporter.publish(ctx, detail)
155+
}
142156
}
143157
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), &param)
144158
for i := range lines {
@@ -158,19 +172,21 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
158172
for i := range lines {
159173
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
160174
}
175+
reporter.publish(ctx, parseGeminiUsage(event.Payload))
161176
return
162177
case wsrelay.MessageTypeError:
163178
recordAPIResponseError(ctx, e.cfg, event.Err)
179+
reporter.publishFailure(ctx)
164180
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
165181
return
166182
}
167183
}
168184
}()
169-
return out, nil
185+
return stream, nil
170186
}
171187

172188
func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
173-
translatedReq, body, err := e.translateRequest(req, opts, false)
189+
_, body, err := e.translateRequest(req, opts, false)
174190
if err != nil {
175191
return cliproxyexecutor.Response{}, err
176192
}
@@ -210,9 +226,12 @@ func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
210226
if resp.Status < 200 || resp.Status >= 300 {
211227
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
212228
}
213-
var param any
214-
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), &param)
215-
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
229+
totalTokens := gjson.GetBytes(resp.Body, "totalTokens").Int()
230+
if totalTokens <= 0 {
231+
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
232+
}
233+
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, bytes.Clone(resp.Body))
234+
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
216235
}
217236

218237
func (e *AistudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {

internal/wsrelay/manager.go

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -142,11 +142,16 @@ func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) {
142142
s.provider = strings.ToLower(s.id)
143143
}
144144
m.sessMutex.Lock()
145+
var replaced *session
145146
if existing, ok := m.sessions[s.provider]; ok {
146-
existing.cleanup(errors.New("replaced by new connection"))
147+
replaced = existing
147148
}
148149
m.sessions[s.provider] = s
149150
m.sessMutex.Unlock()
151+
152+
if replaced != nil {
153+
replaced.cleanup(errors.New("replaced by new connection"))
154+
}
150155
if m.onConnected != nil {
151156
m.onConnected(s.provider)
152157
}

sdk/cliproxy/service.go

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -203,7 +203,9 @@ func (s *Service) wsOnConnected(provider string) {
203203
}
204204
if s.coreManager != nil {
205205
if existing, ok := s.coreManager.GetByID(provider); ok && existing != nil {
206-
return
206+
if !existing.Disabled && existing.Status == coreauth.StatusActive {
207+
return
208+
}
207209
}
208210
}
209211
now := time.Now().UTC()
@@ -225,6 +227,10 @@ func (s *Service) wsOnDisconnected(provider string, reason error) {
225227
return
226228
}
227229
if reason != nil {
230+
if strings.Contains(reason.Error(), "replaced by new connection") {
231+
log.Infof("websocket provider replaced: %s", provider)
232+
return
233+
}
228234
log.Warnf("websocket provider disconnected: %s (%v)", provider, reason)
229235
} else {
230236
log.Infof("websocket provider disconnected: %s", provider)

0 commit comments

Comments
 (0)