Skip to content

Commit 6d414d0

Browse files
committed
fix: context cancel not working during node runner execution
1 parent 1dc00e4 commit 6d414d0

File tree

6 files changed

+160
-120
lines changed

6 files changed

+160
-120
lines changed

backend/api/handler/coze/workflow_service_test.go

Lines changed: 43 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -106,33 +106,30 @@ import (
106106
"github.com/coze-dev/coze-studio/backend/types/errno"
107107
)
108108

109-
var (
110-
publishPatcher *mockey.Mocker
111-
)
112-
113109
func TestMain(m *testing.M) {
114110
callbacks.AppendGlobalHandlers(service.GetTokenCallbackHandler())
115111
service.RegisterAllNodeAdaptors()
116112
os.Exit(m.Run())
117113
}
118114

119115
type wfTestRunner struct {
120-
t *testing.T
121-
h *server.Hertz
122-
ctrl *gomock.Controller
123-
idGen *mock.MockIDGenerator
124-
appVarS *mockvar.MockStore
125-
userVarS *mockvar.MockStore
126-
varGetter *mockvar.MockVariablesMetaGetter
127-
modelManage *mockmodel.MockManager
128-
plugin *mockPlugin.MockPluginService
129-
tos *storageMock.MockStorage
130-
knowledge *knowledgemock.MockKnowledge
131-
database *databasemock.MockDatabase
132-
pluginSrv *pluginmock.MockPluginService
133-
internalModel *testutil.UTChatModel
134-
ctx context.Context
135-
closeFn func()
116+
t *testing.T
117+
h *server.Hertz
118+
ctrl *gomock.Controller
119+
idGen *mock.MockIDGenerator
120+
appVarS *mockvar.MockStore
121+
userVarS *mockvar.MockStore
122+
varGetter *mockvar.MockVariablesMetaGetter
123+
modelManage *mockmodel.MockManager
124+
plugin *mockPlugin.MockPluginService
125+
tos *storageMock.MockStorage
126+
knowledge *knowledgemock.MockKnowledge
127+
database *databasemock.MockDatabase
128+
pluginSrv *pluginmock.MockPluginService
129+
internalModel *testutil.UTChatModel
130+
publishPatcher *mockey.Mocker
131+
ctx context.Context
132+
closeFn func()
136133
}
137134

138135
var req2URL = map[reflect.Type]string{
@@ -256,7 +253,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
256253
workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, utChatModel)
257254
mockey.Mock(appworkflow.GetWorkflowDomainSVC).Return(service.NewWorkflowService(workflowRepo)).Build()
258255
mockey.Mock(workflow2.GetRepository).Return(workflowRepo).Build()
259-
publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
256+
publishPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
260257

261258
mockCU := mockCrossUser.NewMockUser(ctrl)
262259
mockCU.EXPECT().GetUserSpaceList(gomock.Any(), gomock.Any()).Return([]*crossuser.EntitySpace{
@@ -305,9 +302,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
305302
}, nil).Build()
306303

307304
f := func() {
308-
if publishPatcher != nil {
309-
publishPatcher.UnPatch()
310-
}
305+
publishPatcher.UnPatch()
311306
m.UnPatch()
312307
m1.UnPatch()
313308
m2.UnPatch()
@@ -320,22 +315,23 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
320315
}
321316

322317
return &wfTestRunner{
323-
t: t,
324-
h: h,
325-
ctrl: ctrl,
326-
idGen: mockIDGen,
327-
appVarS: mockGlobalAppVarStore,
328-
userVarS: mockGlobalUserVarStore,
329-
varGetter: mockVarGetter,
330-
modelManage: mockModelManage,
331-
plugin: mPlugin,
332-
tos: mockTos,
333-
knowledge: mockKwOperator,
334-
database: mockDatabaseOperator,
335-
internalModel: utChatModel,
336-
ctx: context.Background(),
337-
closeFn: f,
338-
pluginSrv: mockPluginSrv,
318+
t: t,
319+
h: h,
320+
ctrl: ctrl,
321+
idGen: mockIDGen,
322+
appVarS: mockGlobalAppVarStore,
323+
userVarS: mockGlobalUserVarStore,
324+
varGetter: mockVarGetter,
325+
modelManage: mockModelManage,
326+
plugin: mPlugin,
327+
tos: mockTos,
328+
knowledge: mockKwOperator,
329+
database: mockDatabaseOperator,
330+
internalModel: utChatModel,
331+
ctx: context.Background(),
332+
closeFn: f,
333+
pluginSrv: mockPluginSrv,
334+
publishPatcher: publishPatcher,
339335
}
340336
}
341337

@@ -4147,14 +4143,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
41474143

41484144
}
41494145

4150-
if publishPatcher != nil {
4151-
publishPatcher.UnPatch()
4152-
}
4153-
localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build()
4154-
defer func() {
4155-
localPatcher.UnPatch()
4156-
publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
4157-
}()
4146+
defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
41584147

41594148
appID := "7513788954458456064"
41604149
appIDInt64, _ := strconv.ParseInt(appID, 10, 64)
@@ -4265,14 +4254,8 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
42654254
return nil
42664255

42674256
}
4268-
if publishPatcher != nil {
4269-
publishPatcher.UnPatch()
4270-
}
4271-
localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build()
4272-
defer func() {
4273-
localPatcher.UnPatch()
4274-
publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
4275-
}()
4257+
4258+
defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
42764259

42774260
defer mockey.Mock((*appknowledge.KnowledgeApplicationService).CopyKnowledge).Return(&modelknowledge.CopyKnowledgeResponse{
42784261
TargetKnowledgeID: 100100,
@@ -4313,6 +4296,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
43134296
func TestMoveWorkflowAppToLibrary(t *testing.T) {
43144297
mockey.PatchConvey("test move workflow", t, func() {
43154298
r := newWfTestRunner(t)
4299+
r.publishPatcher.UnPatch()
43164300
defer r.closeFn()
43174301
vars := map[string]*vo.TypeInfo{
43184302
"app_v1": {
@@ -4354,14 +4338,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
43544338

43554339
}
43564340

4357-
if publishPatcher != nil {
4358-
publishPatcher.UnPatch()
4359-
}
4360-
localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build()
4361-
defer func() {
4362-
localPatcher.UnPatch()
4363-
publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
4364-
}()
4341+
defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
43654342

43664343
defer mockey.Mock((*appknowledge.KnowledgeApplicationService).MoveKnowledgeToLibrary).Return(nil).Build().UnPatch()
43674344
defer mockey.Mock((*appmemory.DatabaseApplicationService).MoveDatabaseToLibrary).Return(&appmemory.MoveDatabaseToLibraryResponse{}, nil).Build().UnPatch()
@@ -4479,6 +4456,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
44794456
func TestDuplicateWorkflowsByAppID(t *testing.T) {
44804457
mockey.PatchConvey("test duplicate work", t, func() {
44814458
r := newWfTestRunner(t)
4459+
r.publishPatcher.UnPatch()
44824460
defer r.closeFn()
44834461

44844462
vars := map[string]*vo.TypeInfo{
@@ -4516,14 +4494,7 @@ func TestDuplicateWorkflowsByAppID(t *testing.T) {
45164494

45174495
}
45184496

4519-
if publishPatcher != nil {
4520-
publishPatcher.UnPatch()
4521-
}
4522-
localPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build()
4523-
defer func() {
4524-
localPatcher.UnPatch()
4525-
publishPatcher = mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
4526-
}()
4497+
defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
45274498

45284499
appIDInt64 := int64(7513788954458456064)
45294500

backend/domain/workflow/internal/compose/node_runner.go

Lines changed: 55 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ import (
3636
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
3737
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
3838
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
39+
exec "github.com/coze-dev/coze-studio/backend/pkg/execute"
3940
"github.com/coze-dev/coze-studio/backend/pkg/logs"
4041
"github.com/coze-dev/coze-studio/backend/pkg/safego"
4142
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
@@ -614,52 +615,61 @@ func (r *nodeRunner[O]) postProcess(ctx context.Context, output map[string]any)
614615

615616
func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
616617
var n int64
617-
for {
618-
select {
619-
case <-ctx.Done():
620-
return nil, ctx.Err()
621-
default:
622-
}
618+
var invokeOutput map[string]any
623619

624-
output, err = r.i(ctx, input, opts...)
620+
for {
621+
err = exec.RunWithContextDone(ctx, func() error {
622+
var invokeErr error
623+
invokeOutput, invokeErr = r.i(ctx, input, opts...)
624+
if invokeErr != nil {
625+
return invokeErr
626+
}
627+
return nil
628+
})
625629
if err != nil {
626-
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
630+
if _, ok := compose.IsInterruptRerunError(err); ok {
627631
r.interrupted = true
628632
return nil, err
629633
}
630634

631635
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
636+
632637
if r.maxRetry > n {
633638
n++
634639
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
635640
exeCtx.CurrentRetryCount++
636641
}
637642
continue
638643
}
644+
639645
return nil, err
640646
}
647+
return invokeOutput, nil
641648

642-
return output, nil
643649
}
644650
}
645651

646652
func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) {
647653
var n int64
654+
var streamOutput *schema.StreamReader[map[string]any]
655+
648656
for {
649-
select {
650-
case <-ctx.Done():
651-
return nil, ctx.Err()
652-
default:
653-
}
657+
err = exec.RunWithContextDone(ctx, func() error {
658+
var streamErr error
659+
streamOutput, streamErr = r.s(ctx, input, opts...)
660+
if streamErr != nil {
661+
return streamErr
662+
}
663+
return nil
664+
})
654665

655-
output, err = r.s(ctx, input, opts...)
656666
if err != nil {
657-
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
667+
if _, ok := compose.IsInterruptRerunError(err); ok {
658668
r.interrupted = true
659669
return nil, err
660670
}
661671

662-
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
672+
logs.CtxErrorf(ctx, "[stream] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
663673
if r.maxRetry > n {
664674
n++
665675
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -669,8 +679,8 @@ func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts .
669679
}
670680
return nil, err
671681
}
682+
return streamOutput, nil
672683

673-
return output, nil
674684
}
675685
}
676686

@@ -680,40 +690,42 @@ func (r *nodeRunner[O]) collect(ctx context.Context, input *schema.StreamReader[
680690
}
681691

682692
copied := input.Copy(int(r.maxRetry))
683-
684693
var n int64
694+
685695
defer func() {
686696
for i := n + 1; i < r.maxRetry; i++ {
687697
copied[i].Close()
688698
}
689699
}()
690-
700+
var collectOutput map[string]any
691701
for {
692-
select {
693-
case <-ctx.Done():
694-
return nil, ctx.Err()
695-
default:
696-
}
702+
err = exec.RunWithContextDone(ctx, func() error {
703+
var collectErr error
704+
collectOutput, collectErr = r.c(ctx, copied[n], opts...)
705+
if collectErr != nil {
706+
return collectErr
707+
}
708+
return nil
709+
})
697710

698-
output, err = r.c(ctx, copied[n], opts...)
699711
if err != nil {
700-
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
712+
if _, ok := compose.IsInterruptRerunError(err); ok {
701713
r.interrupted = true
702714
return nil, err
703715
}
704716

705-
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
717+
logs.CtxErrorf(ctx, "[collect] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
706718
if r.maxRetry > n {
707719
n++
708720
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
709721
exeCtx.CurrentRetryCount++
710722
}
711723
continue
712724
}
725+
713726
return nil, err
714727
}
715-
716-
return output, nil
728+
return collectOutput, nil
717729
}
718730
}
719731

@@ -731,21 +743,22 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
731743
}
732744
}()
733745

746+
var transformOutput *schema.StreamReader[map[string]any]
734747
for {
735-
select {
736-
case <-ctx.Done():
737-
return nil, ctx.Err()
738-
default:
739-
}
740-
741-
output, err = r.t(ctx, copied[n], opts...)
748+
err = exec.RunWithContextDone(ctx, func() error {
749+
var transformErr error
750+
transformOutput, transformErr = r.t(ctx, copied[n], opts...)
751+
if transformErr != nil {
752+
return transformErr
753+
}
754+
return nil
755+
})
742756
if err != nil {
743-
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
757+
if _, ok := compose.IsInterruptRerunError(err); ok {
744758
r.interrupted = true
745759
return nil, err
746760
}
747-
748-
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
761+
logs.CtxErrorf(ctx, "[transform] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
749762
if r.maxRetry > n {
750763
n++
751764
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
@@ -756,7 +769,8 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
756769
return nil, err
757770
}
758771

759-
return output, nil
772+
return transformOutput, nil
773+
760774
}
761775
}
762776

0 commit comments

Comments
 (0)