Skip to content

Commit cd07d3c

Browse files
committed
add reasoning parsing support
1 parent 9f51cf9 commit cd07d3c

File tree

2 files changed

+131
-10
lines changed

2 files changed

+131
-10
lines changed

go/plugins/ollama/ollama.go

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,6 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
252252
Images: images,
253253
Stream: stream,
254254
}
255-
256255
} else {
257256
var messages []*ollamaMessage
258257
// Translate all messages to ollama message format.
@@ -289,6 +288,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
289288
if err != nil {
290289
return nil, err
291290
}
291+
fmt.Printf("Ollama Request Payload: %s\n", string(payloadBytes))
292292

293293
// Determine the correct endpoint
294294
endpoint := g.serverAddress + "/api/chat"
@@ -322,6 +322,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
322322

323323
var response *ai.ModelResponse
324324
if isChatModel {
325+
fmt.Printf("translating chat response\n")
325326
response, err = translateChatResponse(body)
326327
} else {
327328
response, err = translateModelResponse(body)
@@ -453,6 +454,21 @@ func translateChatResponse(responseData []byte) (*ai.ModelResponse, error) {
453454
Role: ai.RoleModel,
454455
},
455456
}
457+
458+
// Check for thinking/reasoning first
459+
if response.Message.Thinking != "" {
460+
aiPart := ai.NewReasoningPart(response.Message.Thinking, nil)
461+
modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
462+
} else if strings.Contains(response.Message.Content, "<think>") {
463+
// If thinking is not explicitly returned, check if it's in the content
464+
thinking, content := parseThinking(response.Message.Content)
465+
if thinking != "" {
466+
aiPart := ai.NewReasoningPart(thinking, nil)
467+
modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
468+
response.Message.Content = content
469+
}
470+
}
471+
456472
if len(response.Message.ToolCalls) > 0 {
457473
for _, toolCall := range response.Message.ToolCalls {
458474
toolRequest := &ai.ToolRequest{
@@ -462,12 +478,11 @@ func translateChatResponse(responseData []byte) (*ai.ModelResponse, error) {
462478
toolPart := ai.NewToolRequestPart(toolRequest)
463479
modelResponse.Message.Content = append(modelResponse.Message.Content, toolPart)
464480
}
465-
} else if response.Message.Content != "" {
466-
aiPart := ai.NewTextPart(response.Message.Content)
467-
modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
468481
}
469-
if response.Message.Thinking != "" {
470-
aiPart := ai.NewReasoningPart(response.Message.Thinking, nil)
482+
483+
// Add remaining content as text if present
484+
if response.Message.Content != "" {
485+
aiPart := ai.NewTextPart(response.Message.Content)
471486
modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
472487
}
473488

@@ -502,6 +517,10 @@ func translateChatChunk(input string) (*ai.ModelResponseChunk, error) {
502517
return nil, fmt.Errorf("failed to parse response JSON: %v", err)
503518
}
504519
chunk := &ai.ModelResponseChunk{}
520+
if response.Message.Content != "" {
521+
aiPart := ai.NewTextPart(response.Message.Content)
522+
chunk.Content = append(chunk.Content, aiPart)
523+
}
505524
if len(response.Message.ToolCalls) > 0 {
506525
for _, toolCall := range response.Message.ToolCalls {
507526
toolRequest := &ai.ToolRequest{
@@ -511,9 +530,6 @@ func translateChatChunk(input string) (*ai.ModelResponseChunk, error) {
511530
toolPart := ai.NewToolRequestPart(toolRequest)
512531
chunk.Content = append(chunk.Content, toolPart)
513532
}
514-
} else if response.Message.Content != "" {
515-
aiPart := ai.NewTextPart(response.Message.Content)
516-
chunk.Content = append(chunk.Content, aiPart)
517533
}
518534

519535
if response.Message.Thinking != "" {
@@ -593,3 +609,15 @@ func concatImages(input *ai.ModelRequest, roleFilter []ai.Role) ([]string, error
593609
}
594610
return images, nil
595611
}
612+
613+
// parseThinking extracts the thinking content from the response string.
614+
func parseThinking(content string) (string, string) {
615+
start := strings.Index(content, "<think>")
616+
end := strings.Index(content, "</think>")
617+
if start != -1 && end != -1 && end > start {
618+
thinking := content[start+len("<think>") : end]
619+
rest := content[:start] + content[end+len("</think>"):]
620+
return strings.TrimSpace(thinking), strings.TrimSpace(rest)
621+
}
622+
return "", content
623+
}

go/plugins/ollama/ollama_test.go

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,102 @@ func equalContent(a, b []*ai.Part) bool {
125125
return false
126126
}
127127
for i := range a {
128-
if a[i].Text != b[i].Text || !a[i].IsText() || !b[i].IsText() {
128+
if a[i].IsText() {
129+
if !b[i].IsText() || a[i].Text != b[i].Text {
130+
return false
131+
}
132+
} else if a[i].IsReasoning() {
133+
if !b[i].IsReasoning() || a[i].Text != b[i].Text {
134+
return false
135+
}
136+
} else {
137+
// For other types, we might need more specific checks,
138+
// but for now return false if kinds don't match or not handled
129139
return false
130140
}
131141
}
132142
return true
133143
}
144+
145+
func TestTranslateChatResponse(t *testing.T) {
146+
tests := []struct {
147+
name string
148+
input string
149+
want *ai.ModelResponse
150+
wantReasoning string
151+
wantErr bool
152+
}{
153+
{
154+
name: "Thinking field present",
155+
input: `{"model": "deepseek-r1", "created_at": "2024-06-20T12:34:56Z", "message": {"role": "assistant", "content": "Hello", "thinking": "I should say hello"}}`,
156+
want: &ai.ModelResponse{
157+
Message: &ai.Message{
158+
Role: ai.RoleModel,
159+
Content: []*ai.Part{
160+
ai.NewReasoningPart("I should say hello", nil),
161+
ai.NewTextPart("Hello"),
162+
},
163+
},
164+
},
165+
wantReasoning: "I should say hello",
166+
},
167+
{
168+
name: "Thinking in content tag",
169+
input: `{"model": "deepseek-r1", "created_at": "2024-06-20T12:34:56Z", "message": {"role": "assistant", "content": "<think>I should say hello</think>Hello"}}`,
170+
want: &ai.ModelResponse{
171+
Message: &ai.Message{
172+
Role: ai.RoleModel,
173+
Content: []*ai.Part{
174+
ai.NewReasoningPart("I should say hello", nil),
175+
ai.NewTextPart("Hello"),
176+
},
177+
},
178+
},
179+
wantReasoning: "I should say hello",
180+
},
181+
{
182+
name: "Only thinking in content",
183+
input: `{"model": "deepseek-r1", "created_at": "2024-06-20T12:34:56Z", "message": {"role": "assistant", "content": "<think>Just thinking</think>"}}`,
184+
want: &ai.ModelResponse{
185+
Message: &ai.Message{
186+
Role: ai.RoleModel,
187+
Content: []*ai.Part{
188+
ai.NewReasoningPart("Just thinking", nil),
189+
},
190+
},
191+
},
192+
wantReasoning: "Just thinking",
193+
},
194+
{
195+
name: "No thinking",
196+
input: `{"model": "llama3", "created_at": "2024-06-20T12:34:56Z", "message": {"role": "assistant", "content": "Hello"}}`,
197+
want: &ai.ModelResponse{
198+
Message: &ai.Message{
199+
Role: ai.RoleModel,
200+
Content: []*ai.Part{
201+
ai.NewTextPart("Hello"),
202+
},
203+
},
204+
},
205+
wantReasoning: "",
206+
},
207+
}
208+
209+
for _, tt := range tests {
210+
t.Run(tt.name, func(t *testing.T) {
211+
got, err := translateChatResponse([]byte(tt.input))
212+
if (err != nil) != tt.wantErr {
213+
t.Errorf("translateChatResponse() error = %v, wantErr %v", err, tt.wantErr)
214+
return
215+
}
216+
if !tt.wantErr {
217+
if got.Reasoning() != tt.wantReasoning {
218+
t.Errorf("translateChatResponse() Reasoning = %q, want %q", got.Reasoning(), tt.wantReasoning)
219+
}
220+
if !equalContent(got.Message.Content, tt.want.Message.Content) {
221+
t.Errorf("translateChatResponse() got = %v, want %v", got.Message.Content, tt.want.Message.Content)
222+
}
223+
}
224+
})
225+
}
226+
}

0 commit comments

Comments
 (0)