Skip to content

Commit 9ecad90

Browse files
committed
refactor: optimize billing flow for the OpenAI-to-Anthropic conversion
1 parent deff59a commit 9ecad90

16 files changed

+809
-433
lines changed

dto/openai_response.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,10 +220,12 @@ type CompletionsStreamResponse struct {
220220
}
221221

222222
type Usage struct {
223-
PromptTokens int `json:"prompt_tokens"`
224-
CompletionTokens int `json:"completion_tokens"`
225-
TotalTokens int `json:"total_tokens"`
226-
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
223+
PromptTokens int `json:"prompt_tokens"`
224+
CompletionTokens int `json:"completion_tokens"`
225+
TotalTokens int `json:"total_tokens"`
226+
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
227+
UsageSemantic string `json:"usage_semantic,omitempty"`
228+
UsageSource string `json:"usage_source,omitempty"`
227229

228230
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
229231
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
@@ -251,7 +253,7 @@ type OpenAIVideoResponse struct {
251253

252254
type InputTokenDetails struct {
253255
CachedTokens int `json:"cached_tokens"`
254-
CachedCreationTokens int `json:"-"`
256+
CachedCreationTokens int `json:"cached_creation_tokens,omitempty"`
255257
TextTokens int `json:"text_tokens"`
256258
AudioTokens int `json:"audio_tokens"`
257259
ImageTokens int `json:"image_tokens"`

relay/audio_handler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
7070
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
7171
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
7272
} else {
73-
postConsumeQuota(c, info, usage.(*dto.Usage))
73+
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
7474
}
7575

7676
return nil

relay/channel/claude/relay-claude.go

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,35 @@ type ClaudeResponseInfo struct {
555555
Done bool
556556
}
557557

558+
func cacheCreationTokensForOpenAIUsage(usage *dto.Usage) int {
559+
if usage == nil {
560+
return 0
561+
}
562+
splitCacheCreationTokens := usage.ClaudeCacheCreation5mTokens + usage.ClaudeCacheCreation1hTokens
563+
if splitCacheCreationTokens == 0 {
564+
return usage.PromptTokensDetails.CachedCreationTokens
565+
}
566+
if usage.PromptTokensDetails.CachedCreationTokens > splitCacheCreationTokens {
567+
return usage.PromptTokensDetails.CachedCreationTokens
568+
}
569+
return splitCacheCreationTokens
570+
}
571+
572+
func buildOpenAIStyleUsageFromClaudeUsage(usage *dto.Usage) dto.Usage {
573+
if usage == nil {
574+
return dto.Usage{}
575+
}
576+
clone := *usage
577+
cacheCreationTokens := cacheCreationTokensForOpenAIUsage(usage)
578+
totalInputTokens := usage.PromptTokens + usage.PromptTokensDetails.CachedTokens + cacheCreationTokens
579+
clone.PromptTokens = totalInputTokens
580+
clone.InputTokens = totalInputTokens
581+
clone.TotalTokens = totalInputTokens + usage.CompletionTokens
582+
clone.UsageSemantic = "openai"
583+
clone.UsageSource = "anthropic"
584+
return clone
585+
}
586+
558587
func buildMessageDeltaPatchUsage(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ClaudeUsage {
559588
usage := &dto.ClaudeUsage{}
560589
if claudeResponse != nil && claudeResponse.Usage != nil {
@@ -643,6 +672,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d
643672
// message_start, 获取usage
644673
if claudeResponse.Message != nil && claudeResponse.Message.Usage != nil {
645674
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
675+
claudeInfo.Usage.UsageSemantic = "anthropic"
646676
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
647677
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
648678
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Message.Usage.GetCacheCreation5mTokens()
@@ -661,6 +691,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d
661691
} else if claudeResponse.Type == "message_delta" {
662692
// 最终的usage获取
663693
if claudeResponse.Usage != nil {
694+
claudeInfo.Usage.UsageSemantic = "anthropic"
664695
if claudeResponse.Usage.InputTokens > 0 {
665696
// 不叠加,只取最新的
666697
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
@@ -754,12 +785,16 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
754785
}
755786
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
756787
}
788+
if claudeInfo.Usage != nil {
789+
claudeInfo.Usage.UsageSemantic = "anthropic"
790+
}
757791

758792
if info.RelayFormat == types.RelayFormatClaude {
759793
//
760794
} else if info.RelayFormat == types.RelayFormatOpenAI {
761795
if info.ShouldIncludeUsage {
762-
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
796+
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(claudeInfo.Usage)
797+
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, openAIUsage)
763798
err := helper.ObjectData(c, response)
764799
if err != nil {
765800
common.SysLog("send final response failed: " + err.Error())
@@ -810,6 +845,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
810845
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
811846
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
812847
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
848+
claudeInfo.Usage.UsageSemantic = "anthropic"
813849
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
814850
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
815851
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Usage.GetCacheCreation5mTokens()
@@ -819,7 +855,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
819855
switch info.RelayFormat {
820856
case types.RelayFormatOpenAI:
821857
openaiResponse := ResponseClaude2OpenAI(&claudeResponse)
822-
openaiResponse.Usage = *claudeInfo.Usage
858+
openaiResponse.Usage = buildOpenAIStyleUsageFromClaudeUsage(claudeInfo.Usage)
823859
responseData, err = json.Marshal(openaiResponse)
824860
if err != nil {
825861
return types.NewError(err, types.ErrorCodeBadResponseBody)

relay/channel/claude/relay_claude_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,85 @@ func TestFormatClaudeResponseInfo_ContentBlockDelta(t *testing.T) {
173173
t.Errorf("ResponseText = %q, want %q", claudeInfo.ResponseText.String(), "hello")
174174
}
175175
}
176+
177+
func TestBuildOpenAIStyleUsageFromClaudeUsage(t *testing.T) {
178+
usage := &dto.Usage{
179+
PromptTokens: 100,
180+
CompletionTokens: 20,
181+
PromptTokensDetails: dto.InputTokenDetails{
182+
CachedTokens: 30,
183+
CachedCreationTokens: 50,
184+
},
185+
ClaudeCacheCreation5mTokens: 10,
186+
ClaudeCacheCreation1hTokens: 20,
187+
UsageSemantic: "anthropic",
188+
}
189+
190+
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(usage)
191+
192+
if openAIUsage.PromptTokens != 180 {
193+
t.Fatalf("PromptTokens = %d, want 180", openAIUsage.PromptTokens)
194+
}
195+
if openAIUsage.InputTokens != 180 {
196+
t.Fatalf("InputTokens = %d, want 180", openAIUsage.InputTokens)
197+
}
198+
if openAIUsage.TotalTokens != 200 {
199+
t.Fatalf("TotalTokens = %d, want 200", openAIUsage.TotalTokens)
200+
}
201+
if openAIUsage.UsageSemantic != "openai" {
202+
t.Fatalf("UsageSemantic = %s, want openai", openAIUsage.UsageSemantic)
203+
}
204+
if openAIUsage.UsageSource != "anthropic" {
205+
t.Fatalf("UsageSource = %s, want anthropic", openAIUsage.UsageSource)
206+
}
207+
}
208+
209+
func TestBuildOpenAIStyleUsageFromClaudeUsagePreservesCacheCreationRemainder(t *testing.T) {
210+
tests := []struct {
211+
name string
212+
cachedCreationTokens int
213+
cacheCreationTokens5m int
214+
cacheCreationTokens1h int
215+
expectedTotalInputToken int
216+
}{
217+
{
218+
name: "prefers aggregate when it includes remainder",
219+
cachedCreationTokens: 50,
220+
cacheCreationTokens5m: 10,
221+
cacheCreationTokens1h: 20,
222+
expectedTotalInputToken: 180,
223+
},
224+
{
225+
name: "falls back to split tokens when aggregate missing",
226+
cachedCreationTokens: 0,
227+
cacheCreationTokens5m: 10,
228+
cacheCreationTokens1h: 20,
229+
expectedTotalInputToken: 160,
230+
},
231+
}
232+
233+
for _, tt := range tests {
234+
t.Run(tt.name, func(t *testing.T) {
235+
usage := &dto.Usage{
236+
PromptTokens: 100,
237+
CompletionTokens: 20,
238+
PromptTokensDetails: dto.InputTokenDetails{
239+
CachedTokens: 30,
240+
CachedCreationTokens: tt.cachedCreationTokens,
241+
},
242+
ClaudeCacheCreation5mTokens: tt.cacheCreationTokens5m,
243+
ClaudeCacheCreation1hTokens: tt.cacheCreationTokens1h,
244+
UsageSemantic: "anthropic",
245+
}
246+
247+
openAIUsage := buildOpenAIStyleUsageFromClaudeUsage(usage)
248+
249+
if openAIUsage.PromptTokens != tt.expectedTotalInputToken {
250+
t.Fatalf("PromptTokens = %d, want %d", openAIUsage.PromptTokens, tt.expectedTotalInputToken)
251+
}
252+
if openAIUsage.InputTokens != tt.expectedTotalInputToken {
253+
t.Fatalf("InputTokens = %d, want %d", openAIUsage.InputTokens, tt.expectedTotalInputToken)
254+
}
255+
})
256+
}
257+
}

relay/claude_handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
122122
return newApiErr
123123
}
124124

125-
service.PostClaudeConsumeQuota(c, info, usage)
125+
service.PostTextConsumeQuota(c, info, usage, nil)
126126
return nil
127127
}
128128

@@ -190,6 +190,6 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
190190
return newAPIError
191191
}
192192

193-
service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage))
193+
service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
194194
return nil
195195
}

0 commit comments

Comments (0)