Skip to content
This repository was archived by the owner on Sep 18, 2025. It is now read-only.

Commit c08f5a7

Browse files
committed
fix cancell
1 parent 36f201d commit c08f5a7

File tree

4 files changed

+42
-1
lines changed

4 files changed

+42
-1
lines changed

internal/llm/provider/anthropic.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
9292
}
9393

9494
func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
95+
messages = cleanupMessages(messages)
9596
anthropicMessages := a.convertToAnthropicMessages(messages)
9697
anthropicTools := a.convertToAnthropicTools(tools)
9798

@@ -135,6 +136,7 @@ func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message
135136
}
136137

137138
func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
139+
messages = cleanupMessages(messages)
138140
anthropicMessages := a.convertToAnthropicMessages(messages)
139141
anthropicTools := a.convertToAnthropicTools(tools)
140142

internal/llm/provider/gemini.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse)
154154
}
155155

156156
func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
157+
messages = cleanupMessages(messages)
157158
model := p.client.GenerativeModel(p.model.APIModel)
158159
model.SetMaxOutputTokens(p.maxTokens)
159160

@@ -206,6 +207,7 @@ func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Me
206207
}
207208

208209
func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
210+
messages = cleanupMessages(messages)
209211
model := p.client.GenerativeModel(p.model.APIModel)
210212
model.SetMaxOutputTokens(p.maxTokens)
211213

internal/llm/provider/openai.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUs
163163
}
164164

165165
func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
166+
messages = cleanupMessages(messages)
166167
chatMessages := p.convertToOpenAIMessages(messages)
167168
openaiTools := p.convertToOpenAITools(tools)
168169

@@ -206,6 +207,7 @@ func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Me
206207
}
207208

208209
func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
210+
messages = cleanupMessages(messages)
209211
chatMessages := p.convertToOpenAIMessages(messages)
210212
openaiTools := p.convertToOpenAITools(tools)
211213

@@ -276,4 +278,3 @@ func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.
276278

277279
return eventChan, nil
278280
}
279-

internal/llm/provider/provider.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,39 @@ type Provider interface {
5252

5353
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error)
5454
}
55+
56+
func cleanupMessages(messages []message.Message) []message.Message {
57+
// First pass: filter out canceled messages
58+
var cleanedMessages []message.Message
59+
for _, msg := range messages {
60+
if msg.FinishReason() != "canceled" {
61+
cleanedMessages = append(cleanedMessages, msg)
62+
}
63+
}
64+
65+
// Second pass: filter out tool messages without a corresponding tool call
66+
var result []message.Message
67+
toolMessageIDs := make(map[string]bool)
68+
69+
for _, msg := range cleanedMessages {
70+
if msg.Role == message.Assistant {
71+
for _, toolCall := range msg.ToolCalls() {
72+
toolMessageIDs[toolCall.ID] = true // Mark as referenced
73+
}
74+
}
75+
}
76+
77+
// Keep only messages that aren't unreferenced tool messages
78+
for _, msg := range cleanedMessages {
79+
if msg.Role == message.Tool {
80+
for _, toolCall := range msg.ToolResults() {
81+
if referenced, exists := toolMessageIDs[toolCall.ToolCallID]; exists && referenced {
82+
result = append(result, msg)
83+
}
84+
}
85+
} else {
86+
result = append(result, msg)
87+
}
88+
}
89+
return result
90+
}

0 commit comments

Comments
 (0)