Skip to content

Commit b9080db

Browse files
feat(model/ark): Refactor callback handling for Responses API (#649)
Moves the callback initialization logic from the generic ChatModel to the specific ResponsesAPIChatModel. This change ensures that EnsureRunInfo is only called for the Responses API flow, where it is actually needed. The GetType method was also added to ResponsesAPIChatModel to support this.
1 parent f658a8b commit b9080db

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

components/model/ark/chatmodel.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ import (
2727
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
2828
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
2929

30-
"github.com/cloudwego/eino/callbacks"
31-
"github.com/cloudwego/eino/components"
3230
fmodel "github.com/cloudwego/eino/components/model"
3331
"github.com/cloudwego/eino/schema"
3432
)
@@ -389,9 +387,6 @@ type CacheInfo struct {
389387

390388
func (cm *ChatModel) Generate(ctx context.Context, in []*schema.Message, opts ...fmodel.Option) (
391389
outMsg *schema.Message, err error) {
392-
393-
ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
394-
395390
ok, err := cm.callByResponsesAPI(opts...)
396391
if err != nil {
397392
return nil, err
@@ -406,8 +401,6 @@ func (cm *ChatModel) Generate(ctx context.Context, in []*schema.Message, opts ..
406401
func (cm *ChatModel) Stream(ctx context.Context, in []*schema.Message, opts ...fmodel.Option) (
407402
outStream *schema.StreamReader[*schema.Message], err error) {
408403

409-
ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
410-
411404
ok, err := cm.callByResponsesAPI(opts...)
412405
if err != nil {
413406
return nil, err

components/model/ark/responses_api.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828

2929
"github.com/bytedance/sonic"
3030
"github.com/cloudwego/eino/callbacks"
31+
"github.com/cloudwego/eino/components"
3132
"github.com/cloudwego/eino/components/model"
3233
"github.com/cloudwego/eino/schema"
3334
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
@@ -219,8 +220,14 @@ type cacheConfig struct {
219220
ExpireAt *int64
220221
}
221222

223+
func (cm *ResponsesAPIChatModel) GetType() string {
224+
return "ResponsesAPI"
225+
}
226+
222227
func (cm *ResponsesAPIChatModel) Generate(ctx context.Context, input []*schema.Message,
223228
opts ...model.Option) (outMsg *schema.Message, err error) {
229+
ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
230+
224231
options, specOptions, err := cm.getOptions(opts)
225232
if err != nil {
226233
return nil, err
@@ -288,7 +295,7 @@ func (cm *ResponsesAPIChatModel) Generate(ctx context.Context, input []*schema.M
288295

289296
func (cm *ResponsesAPIChatModel) Stream(ctx context.Context, input []*schema.Message,
290297
opts ...model.Option) (outStream *schema.StreamReader[*schema.Message], err error) {
291-
298+
ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
292299
options, specOptions, err := cm.getOptions(opts)
293300
if err != nil {
294301
return nil, err

0 commit comments

Comments
 (0)