Skip to content

Commit 287d182

Browse files
authored
[pipeline] bug fix: fatal error: concurrent map writes (#1026)
1 parent 05a26ef commit 287d182

File tree

7 files changed

+173
-16
lines changed

7 files changed

+173
-16
lines changed

pkg/apiserver/controller/pipeline/callback_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@ import (
2121
"io/ioutil"
2222
"testing"
2323

24+
"github.com/agiledragon/gomonkey/v2"
25+
log "github.com/sirupsen/logrus"
2426
"github.com/stretchr/testify/assert"
2527

2628
"github.com/PaddlePaddle/PaddleFlow/pkg/apiserver/common"
2729
"github.com/PaddlePaddle/PaddleFlow/pkg/apiserver/models"
2830
"github.com/PaddlePaddle/PaddleFlow/pkg/common/logger"
31+
"github.com/PaddlePaddle/PaddleFlow/pkg/pipeline"
2932
"github.com/PaddlePaddle/PaddleFlow/pkg/storage/driver"
3033
)
3134

@@ -73,3 +76,33 @@ func TestGetJobByRun(t *testing.T) {
7376
assert.Nil(t, err)
7477
assert.Equal(t, "job-run-post", jobView.JobID)
7578
}
79+
80+
func TestUpdateRunByWfEvent(t *testing.T) {
81+
driver.InitMockDB()
82+
ctx := &logger.RequestContext{UserName: MockRootUser}
83+
run := getMockRunWithoutRuntime()
84+
runID, err := models.CreateRun(ctx.Logging(), &run)
85+
assert.Nil(t, err)
86+
87+
wfMap.Store(runID, "abc")
88+
event := &pipeline.WorkflowEvent{
89+
Event: pipeline.WfEventRunUpdate,
90+
Extra: map[string]interface{}{
91+
common.WfEventKeyRunID: runID,
92+
common.WfEventKeyStatus: common.StatusRunSucceeded,
93+
common.WfEventKeyStartTime: "2022-09-09 10:00:09",
94+
},
95+
Message: "mesg",
96+
}
97+
98+
patch5 := gomonkey.ApplyFunc(models.GetRunByID, func(logEntry *log.Entry, runID string) (models.Run, error) {
99+
return run, nil
100+
})
101+
defer patch5.Reset()
102+
103+
_, flag := UpdateRunByWfEvent(runID, event)
104+
assert.True(t, flag)
105+
106+
_, ok := wfMap.Load(runID)
107+
assert.False(t, ok)
108+
}

pkg/apiserver/controller/pipeline/callbacks.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func UpdateRunByWfEvent(id string, event interface{}) (int64, bool) {
109109
}
110110
if common.IsRunFinalStatus(status) {
111111
logging.Debugf("run[%s] has reached final status[%s]", runID, status)
112-
delete(wfMap, runID)
112+
wfMap.Delete(runID)
113113
}
114114
startTime, ok := wfEvent.Extra[common.WfEventKeyStartTime].(string)
115115
if !ok {
@@ -152,7 +152,7 @@ func UpdateRunByWfEvent(id string, event interface{}) (int64, bool) {
152152

153153
if common.IsRunFinalStatus(status) {
154154
logging.Debugf("run[%s] has reached final status[%s]", runID, status)
155-
delete(wfMap, runID)
155+
wfMap.Delete(runID)
156156

157157
// 给scheduler发concurrency channel信号
158158
if prevRun.ScheduleID != "" {

pkg/apiserver/controller/pipeline/run.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"errors"
2323
"fmt"
2424
"strings"
25+
"sync"
2526
"time"
2627

2728
log "github.com/sirupsen/logrus"
@@ -43,7 +44,7 @@ import (
4344
"github.com/PaddlePaddle/PaddleFlow/pkg/trace_logger"
4445
)
4546

46-
var wfMap = make(map[string]*pipeline.Workflow, 0)
47+
var wfMap = sync.Map{}
4748

4849
const (
4950
JsonFsOptions = "fs_options" // 由于在获取BodyMap的FsOptions前已经转为下划线形式,因此这里为fs_options
@@ -913,12 +914,14 @@ func StopRun(logEntry *log.Entry, userName, runID string, request UpdateRunReque
913914
return err
914915
}
915916

916-
wf, exist := wfMap[runID]
917+
wfInterface, exist := wfMap.Load(runID)
917918
if !exist {
918919
err := fmt.Errorf("run[%s]'s workflow ptr is lost", runID)
919920
logEntry.Errorln(err.Error())
920921
return err
921922
}
923+
924+
wf := wfInterface.(*pipeline.Workflow)
922925
run.RunOptions.StopForce = request.StopForce
923926
runUpdate := models.Run{
924927
Status: common.StatusRunTerminating,
@@ -1121,8 +1124,8 @@ func StartWf(run models.Run, wfPtr *pipeline.Workflow) error {
11211124
logEntry.Errorf("StartWf failed, error: %s", err.Error())
11221125
return err
11231126
}
1124-
wfMap[run.ID] = wfPtr
11251127

1128+
wfMap.Store(run.ID, wfPtr)
11261129
if err := models.UpdateRunStatus(logEntry, run.ID, common.StatusRunPending); err != nil {
11271130
return err
11281131
}
@@ -1216,7 +1219,7 @@ func RestartWf(run models.Run, isResume bool) (string, error) {
12161219
if isResume {
12171220
wfPtr.Resume(entryPointDagView, run.PostProcess, run.Status, run.RunOptions.StopForce)
12181221
} else {
1219-
wfMap[run.ID] = wfPtr
1222+
wfMap.Store(run.ID, wfPtr)
12201223
if err := models.UpdateRunStatus(logEntry, run.ID, common.StatusRunPending); err != nil {
12211224
return "", err
12221225
}
@@ -1246,7 +1249,7 @@ func newWorkflowByRun(run models.Run) (*pipeline.Workflow, error) {
12461249
}
12471250
// 如果此时没有runID的话,那么在后续有runID之后,需要:1. 填充wfMap 2. 初始化wf.runtime
12481251
if run.ID != "" {
1249-
wfMap[run.ID] = wfPtr
1252+
wfMap.Store(run.ID, wfPtr)
12501253
}
12511254
return wfPtr, nil
12521255
}

pkg/apiserver/controller/pipeline/run_test.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ import (
2020
"encoding/base64"
2121
"encoding/json"
2222
"fmt"
23+
"reflect"
2324
"testing"
2425

2526
"github.com/agiledragon/gomonkey/v2"
27+
log "github.com/sirupsen/logrus"
2628
"github.com/stretchr/testify/assert"
2729
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
2830
"k8s.io/apimachinery/pkg/runtime"
@@ -289,6 +291,13 @@ func TestNewWorkflowByRun(t *testing.T) {
289291
}
290292
_, err = newMockWorkflowByRun(run2)
291293
assert.Nil(t, err)
294+
295+
run2.ID = "1445"
296+
_, err = newMockWorkflowByRun(run2)
297+
assert.Nil(t, err)
298+
299+
_, ok := wfMap.Load(run2.ID)
300+
assert.True(t, ok)
292301
}
293302

294303
func TestCreateRunByJson(t *testing.T) {
@@ -363,5 +372,118 @@ func TestCreateRun(t *testing.T) {
363372
}
364373
_, err = CreateRun(ctx, &createRunRequest, map[string]string{})
365374
assert.Nil(t, err)
375+
}
376+
377+
func TestStopRun(t *testing.T) {
378+
driver.InitMockDB()
379+
run := getMockRunWithoutRuntime()
380+
ctx := &logger.RequestContext{UserName: MockRootUser}
381+
runID, err := models.CreateRun(ctx.Logging(), &run)
382+
assert.Nil(t, err)
383+
384+
wfMap.Store(runID, &pipeline.Workflow{})
385+
386+
var wf *pipeline.Workflow
387+
patch := gomonkey.ApplyMethod(reflect.TypeOf(wf), "Stop", func(*pipeline.Workflow, bool) {
388+
return
389+
})
390+
defer patch.Reset()
391+
392+
var r *models.Run
393+
patch2 := gomonkey.ApplyMethod(reflect.TypeOf(r), "Encode", func(*models.Run) error {
394+
return nil
395+
})
396+
defer patch2.Reset()
366397

398+
patch3 := gomonkey.ApplyFunc(GetRunByID, func(logEntry *log.Entry, userName string, runID string) (models.Run, error) {
399+
return run, nil
400+
})
401+
defer patch3.Reset()
402+
403+
req := UpdateRunRequest{StopForce: false}
404+
err = StopRun(ctx.Logging(), "root", runID, req)
405+
assert.Nil(t, err)
406+
}
407+
408+
func TestStartWf(t *testing.T) {
409+
run := models.Run{
410+
ID: "run=00001",
411+
}
412+
wfptr := &pipeline.Workflow{}
413+
patch := gomonkey.ApplyMethod(reflect.TypeOf(wfptr), "NewWorkflowRuntime", func(*pipeline.Workflow) error {
414+
return nil
415+
})
416+
defer patch.Reset()
417+
418+
patch2 := gomonkey.ApplyFunc(models.UpdateRunStatus, func(logEntry *log.Entry, runID string, status string) error {
419+
return nil
420+
})
421+
defer patch2.Reset()
422+
423+
patch3 := gomonkey.ApplyMethod(reflect.TypeOf(wfptr), "Start", func(*pipeline.Workflow) {
424+
return
425+
})
426+
defer patch3.Reset()
427+
428+
StartWf(run, wfptr)
429+
_, ok := wfMap.Load(run.ID)
430+
assert.True(t, ok)
431+
}
432+
433+
func TestRestartWf(t *testing.T) {
434+
run := models.Run{
435+
ID: "run=00001",
436+
}
437+
wfptr := &pipeline.Workflow{}
438+
439+
patch := gomonkey.ApplyFunc(newWorkflowByRun, func(run models.Run) (*pipeline.Workflow, error) {
440+
return &pipeline.Workflow{}, nil
441+
})
442+
defer patch.Reset()
443+
444+
patch2 := gomonkey.ApplyFunc(models.UpdateRunStatus, func(logEntry *log.Entry, runID string, status string) error {
445+
return nil
446+
})
447+
defer patch2.Reset()
448+
449+
patch3 := gomonkey.ApplyMethod(reflect.TypeOf(wfptr), "Restart", func(*pipeline.Workflow, *schema.DagView, schema.PostProcessView) {
450+
return
451+
})
452+
defer patch3.Reset()
453+
454+
patch4 := gomonkey.ApplyFunc(models.GetRunJobsOfRun, func(logEntry *log.Entry, runID string) ([]models.RunJob, error) {
455+
return nil, nil
456+
})
457+
defer patch4.Reset()
458+
459+
patch5 := gomonkey.ApplyFunc(models.GetRunDagsOfRun, func(logEntry *log.Entry, runID string) ([]models.RunDag, error) {
460+
return nil, nil
461+
})
462+
defer patch5.Reset()
463+
464+
var r *models.Run
465+
patch6 := gomonkey.ApplyMethod(reflect.TypeOf(r), "Encode", func(*models.Run) error {
466+
return nil
467+
})
468+
defer patch6.Reset()
469+
470+
patch7 := gomonkey.ApplyFunc(models.CreateRun, func(logEntry *log.Entry, run *models.Run) (string, error) {
471+
return "", nil
472+
})
473+
defer patch7.Reset()
474+
475+
patch8 := gomonkey.ApplyFunc(models.CreateRunDag, func(logEntry *log.Entry, runDag *models.RunDag) (int64, error) {
476+
return 234, nil
477+
})
478+
defer patch8.Reset()
479+
480+
patch9 := gomonkey.ApplyMethod(reflect.TypeOf(r), "InitRuntime", func(_ *models.Run, jobs []models.RunJob, dags []models.RunDag) error {
481+
return nil
482+
})
483+
defer patch9.Reset()
484+
485+
id, err := RestartWf(run, false)
486+
assert.Nil(t, err)
487+
_, ok := wfMap.Load(id)
488+
assert.True(t, ok)
367489
}

pkg/pipeline/dagRuntime.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ func (drt *DagRuntime) Resume(dagView *schema.DagView) {
483483
for _, name := range sorted {
484484
views, ok := dagView.EntryPoints[name]
485485
if !ok {
486-
// 说面当前节点还没有运行,在此处不进行处理
486+
// 说明当前节点还没有运行,在此处不进行处理
487487
continue
488488
}
489489

@@ -734,7 +734,7 @@ func (drt *DagRuntime) scheduleSubComponentAccordingView(dagView *schema.DagView
734734
continue
735735
}
736736

737-
// restart 时,所有子节点rumtine都处于终态,可以分成三类:
737+
// restart 时,所有子节点runtime都处于终态,可以分成三类:
738738
// succeeded, skipped: 对于这类runtime无需重启,在 subruntime 中记录即可
739739
// failed, terminated: 需要重启
740740
// cancelled: 分两种情况:

pkg/pipeline/job.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func NewPaddleFlowJob(name, image, userName string, eventChannel chan<- Workflow
8888
}
8989

9090
func NewPaddleFlowJobWithJobView(view *schema.JobView, image string, eventChannel chan<- WorkflowEvent,
91-
mainFS *schema.FsMount, extraFS []schema.FsMount) *PaddleFlowJob {
91+
mainFS *schema.FsMount, extraFS []schema.FsMount, userName string) *PaddleFlowJob {
9292
pfj := PaddleFlowJob{
9393
BaseJob: BaseJob{
9494
ID: view.JobID,
@@ -106,6 +106,7 @@ func NewPaddleFlowJobWithJobView(view *schema.JobView, image string, eventChanne
106106
eventChannel: eventChannel,
107107
mainFS: mainFS,
108108
extraFS: extraFS,
109+
userName: userName,
109110
}
110111

111112
pfj.Status = common.StatusRunRunning
@@ -211,7 +212,7 @@ func (pfj *PaddleFlowJob) Start() (string, error) {
211212
}
212213

213214
if pfj.ID == "" {
214-
err = fmt.Errorf("watch paddleflow job[%s] failed, job not started, id is empty!", pfj.Job().Name)
215+
err = fmt.Errorf("watch paddleflow job[%s] failed, job not started, id is empty", pfj.Job().Name)
215216
return "", err
216217
}
217218

pkg/pipeline/stepRuntime.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,8 @@ func (srt *StepRuntime) Start() {
193193
srt.Execute()
194194
}
195195

196-
// Restart: 根据 jobView 来重启step
197-
// 如果 jobView 中的状态为 Succeeded, 则直接返回,无需重启
198-
// 如果 jobView 中的状态为 Running, 则进入监听即可
199-
// 否则 创建一个新的job并开始调度执行
196+
// Restart: 根据 jobView 来重新运行step
197+
// 只有step的状态为failed,terminated的时候,才会重新运行
200198
func (srt *StepRuntime) Restart(view *schema.JobView) {
201199
srt.logger.Infof("begin to restart step[%s]", srt.name)
202200
defer srt.processJobLock.Unlock()
@@ -223,7 +221,7 @@ func (srt *StepRuntime) Resume(view *schema.JobView) {
223221
defer srt.catchPanic()
224222

225223
srt.job = NewPaddleFlowJobWithJobView(view, srt.getWorkFlowStep().DockerEnv,
226-
srt.receiveEventChildren, srt.runConfig.mainFS, srt.getWorkFlowStep().ExtraFS)
224+
srt.receiveEventChildren, srt.runConfig.mainFS, srt.getWorkFlowStep().ExtraFS, srt.userName)
227225

228226
srt.pk = view.PK
229227
err := srt.updateStatus(view.Status)

0 commit comments

Comments
 (0)