
Commit 19b7803

refactor: refactor token usage and response creation logic

- Add import for `go-openai` package
- Refactor response creation to include detailed token usage if cached tokens are present
- Modify return statements to use the new `usage` variable for both `Completion` and `GetSummaryPrefix` functions

Signed-off-by: Bo-Yi Wu <[email protected]>
1 parent 0122019 commit 19b7803

File tree

2 files changed: +54 -20 lines

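Both diffs below populate a `core.Usage` value and attach an `openai.PromptTokensDetails` when the provider reports cached tokens. For context, here is a minimal sketch of the shape this implies for `core.Usage` (the real definition lives in this repository's `core` package; the exact field set is an assumption inferred from the diffs):

```go
package core

import "github.com/sashabaranov/go-openai"

// Usage is a sketch of the token-accounting struct the diffs below fill in;
// fields are inferred from how the two providers populate it.
type Usage struct {
	PromptTokens     int // tokens consumed by the prompt
	CompletionTokens int // tokens produced by the model
	TotalTokens      int // prompt + completion, or the provider's own total

	// PromptTokensDetails reuses go-openai's detail struct and stays nil
	// unless the provider reported cached prompt tokens.
	PromptTokensDetails *openai.PromptTokensDetails
}
```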

provider/anthropic/anthropic.go

Lines changed: 27 additions & 10 deletions
```diff
@@ -12,6 +12,7 @@ import (
 
 	"github.com/appleboy/com/convert"
 	"github.com/liushuangls/go-anthropic/v2"
+	"github.com/sashabaranov/go-openai"
 )
 
 var _ core.Generative = (*Client)(nil)
@@ -45,13 +46,21 @@ func (c *Client) Completion(ctx context.Context, content string) (*core.Response
 		return nil, err
 	}
 
+	usage := core.Usage{
+		PromptTokens:     resp.Usage.InputTokens,
+		CompletionTokens: resp.Usage.OutputTokens,
+		TotalTokens:      resp.Usage.InputTokens + resp.Usage.OutputTokens,
+	}
+
+	if resp.Usage.CacheCreationInputTokens > 0 || resp.Usage.CacheReadInputTokens > 0 {
+		usage.PromptTokensDetails = &openai.PromptTokensDetails{
+			CachedTokens: resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens,
+		}
+	}
+
 	return &core.Response{
 		Content: resp.Content[0].GetText(),
-		Usage: core.Usage{
-			PromptTokens:     resp.Usage.InputTokens,
-			CompletionTokens: resp.Usage.OutputTokens,
-			TotalTokens:      resp.Usage.InputTokens + resp.Usage.OutputTokens,
-		},
+		Usage: usage,
 	}, nil
 }
 
@@ -88,13 +97,21 @@ func (c *Client) GetSummaryPrefix(ctx context.Context, content string) (*core.Re
 		return nil, fmt.Errorf("failed to unmarshal tool use input: %w", err)
 	}
 
+	usage := core.Usage{
+		PromptTokens:     resp.Usage.InputTokens,
+		CompletionTokens: resp.Usage.OutputTokens,
+		TotalTokens:      resp.Usage.InputTokens + resp.Usage.OutputTokens,
+	}
+
+	if resp.Usage.CacheCreationInputTokens > 0 || resp.Usage.CacheReadInputTokens > 0 {
+		usage.PromptTokensDetails = &openai.PromptTokensDetails{
+			CachedTokens: resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens,
+		}
+	}
+
 	return &core.Response{
 		Content: result.Prefix,
-		Usage: core.Usage{
-			PromptTokens:     resp.Usage.InputTokens,
-			CompletionTokens: resp.Usage.OutputTokens,
-			TotalTokens:      resp.Usage.InputTokens + resp.Usage.OutputTokens,
-		},
+		Usage: usage,
 	}, nil
 }
```
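With both functions now returning the shared `usage` value, callers can branch on the optional details. A hypothetical caller-side helper (the `logUsage` name and the logging are illustrative, not part of this commit):

```go
import "log"

// logUsage is a hypothetical helper showing how a caller can read the
// new optional cached-token details off a core.Response.
func logUsage(resp *core.Response) {
	log.Printf("prompt=%d completion=%d total=%d",
		resp.Usage.PromptTokens,
		resp.Usage.CompletionTokens,
		resp.Usage.TotalTokens,
	)

	// PromptTokensDetails stays nil when no cached tokens were reported,
	// so guard before dereferencing.
	if d := resp.Usage.PromptTokensDetails; d != nil {
		log.Printf("cached prompt tokens: %d", d.CachedTokens)
	}
}
```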

provider/gemini/gemini.go

Lines changed: 27 additions & 10 deletions
```diff
@@ -9,6 +9,7 @@ import (
 	"github.com/appleboy/com/convert"
 
 	"github.com/google/generative-ai-go/genai"
+	"github.com/sashabaranov/go-openai"
 	"google.golang.org/api/option"
 )
 
@@ -36,13 +37,21 @@ func (c *Client) Completion(ctx context.Context, content string) (*core.Response
 		}
 	}
 
+	usage := core.Usage{
+		PromptTokens:     int(resp.UsageMetadata.PromptTokenCount),
+		CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount),
+		TotalTokens:      int(resp.UsageMetadata.TotalTokenCount),
+	}
+
+	if resp.UsageMetadata.CachedContentTokenCount > 0 {
+		usage.PromptTokensDetails = &openai.PromptTokensDetails{
+			CachedTokens: int(resp.UsageMetadata.CachedContentTokenCount),
+		}
+	}
+
 	return &core.Response{
 		Content: ret,
-		Usage: core.Usage{
-			PromptTokens:     int(resp.UsageMetadata.PromptTokenCount),
-			CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount),
-			TotalTokens:      int(resp.UsageMetadata.TotalTokenCount),
-		},
+		Usage: usage,
 	}, nil
 }
 
@@ -61,13 +70,21 @@ func (c *Client) GetSummaryPrefix(ctx context.Context, content string) (*core.Re
 
 	part := resp.Candidates[0].Content.Parts[0]
 
+	usage := core.Usage{
+		PromptTokens:     int(resp.UsageMetadata.PromptTokenCount),
+		CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount),
+		TotalTokens:      int(resp.UsageMetadata.TotalTokenCount),
+	}
+
+	if resp.UsageMetadata.CachedContentTokenCount > 0 {
+		usage.PromptTokensDetails = &openai.PromptTokensDetails{
+			CachedTokens: int(resp.UsageMetadata.CachedContentTokenCount),
+		}
+	}
+
 	r := &core.Response{
 		Content: strings.TrimSpace(strings.TrimSuffix(fmt.Sprintf("%v", part), "\n")),
-		Usage: core.Usage{
-			PromptTokens:     int(resp.UsageMetadata.PromptTokenCount),
-			CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount),
-			TotalTokens:      int(resp.UsageMetadata.TotalTokenCount),
-		},
+		Usage: usage,
 	}
 
 	if c.debug {
```
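The populate-then-guard pattern now appears four times across the two providers, differing only in where the counts come from: Anthropic reports cache creation and cache read tokens separately, so the diff sums them, while Gemini exposes a single `CachedContentTokenCount`. One possible follow-up, purely a sketch and not part of this commit, would be a shared constructor in the `core` package:

```go
// NewUsage is a hypothetical helper, not part of this commit: it centralizes
// the populate-then-guard logic that both providers now repeat.
func NewUsage(prompt, completion, total, cached int) Usage {
	u := Usage{
		PromptTokens:     prompt,
		CompletionTokens: completion,
		TotalTokens:      total,
	}
	// Attach details only when the provider actually reported cached tokens.
	if cached > 0 {
		u.PromptTokensDetails = &openai.PromptTokensDetails{CachedTokens: cached}
	}
	return u
}
```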
