Skip to content

Commit 6f9ec3e

Browse files
committed
move cancel function setup out of RunWorkflow transaction
1 parent d795789 commit 6f9ec3e

File tree

1 file changed

+40
-38
lines changed

1 file changed

+40
-38
lines changed

dbos/workflow.go

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -732,11 +732,8 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt
732732
Priority: int(params.priority),
733733
}
734734

735-
var stopFunc func() bool
736-
cancelFuncCompleted := make(chan struct{})
737-
var workflowCtx DBOSContext
738-
outcomeChan := make(chan workflowOutcome[any], 1)
739735
var earlyReturnPollingHandle *workflowPollingHandle[any]
736+
var insertStatusResult *insertWorkflowResult
740737

741738
// Init status and record child workflow relationship in a single transaction
742739
err := retry(c, func() error {
@@ -752,7 +749,7 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt
752749
maxRetries: params.maxRetries,
753750
tx: tx,
754751
}
755-
insertStatusResult, err := c.systemDB.insertWorkflowStatus(uncancellableCtx, insertInput)
752+
insertStatusResult, err = c.systemDB.insertWorkflowStatus(uncancellableCtx, insertInput)
756753
if err != nil {
757754
c.logger.Error("failed to insert workflow status", "error", err, "workflow_id", workflowID)
758755
return err
@@ -785,43 +782,11 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt
785782
return nil
786783
}
787784

788-
// Create workflow state to track step execution
789-
wfState := &workflowState{
790-
workflowID: workflowID,
791-
stepID: -1, // Steps are O-indexed
792-
}
793-
794-
workflowCtx = WithValue(c, workflowStateKey, wfState)
795-
796-
// If the workflow has a timeout but no deadline, compute the deadline from the timeout.
797-
// Else use the durable deadline.
798-
durableDeadline := time.Time{}
799-
if insertStatusResult.timeout > 0 && insertStatusResult.workflowDeadline.IsZero() {
800-
durableDeadline = time.Now().Add(insertStatusResult.timeout)
801-
} else if !insertStatusResult.workflowDeadline.IsZero() {
802-
durableDeadline = insertStatusResult.workflowDeadline
803-
}
804-
805-
if !durableDeadline.IsZero() {
806-
workflowCtx, _ = WithTimeout(workflowCtx, time.Until(durableDeadline))
807-
// Register a cancel function that cancels the workflow in the DB as soon as the context is cancelled
808-
dbosCancelFunction := func() {
809-
c.logger.Info("Cancelling workflow", "workflow_id", workflowID)
810-
err = retry(c, func() error {
811-
return c.systemDB.cancelWorkflow(uncancellableCtx, workflowID)
812-
}, withRetrierLogger(c.logger))
813-
if err != nil {
814-
c.logger.Error("Failed to cancel workflow", "error", err)
815-
}
816-
close(cancelFuncCompleted)
817-
}
818-
stopFunc = context.AfterFunc(workflowCtx, dbosCancelFunction)
819-
}
820-
821785
// Commit the transaction. This must happen before we start the goroutine to ensure the workflow is found by steps in the database
822786
if err := tx.Commit(uncancellableCtx); err != nil {
823787
return newWorkflowExecutionError(workflowID, fmt.Errorf("failed to commit transaction: %w", err))
824788
}
789+
825790
return nil
826791
}, withRetrierLogger(c.logger))
827792
if err != nil {
@@ -831,6 +796,43 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt
831796
return earlyReturnPollingHandle, nil
832797
}
833798

799+
outcomeChan := make(chan workflowOutcome[any], 1)
800+
801+
// Create workflow state to track step execution
802+
wfState := &workflowState{
803+
workflowID: workflowID,
804+
stepID: -1, // Steps are O-indexed
805+
}
806+
807+
workflowCtx := WithValue(c, workflowStateKey, wfState)
808+
809+
// If the workflow has a timeout but no deadline, compute the deadline from the timeout.
810+
// Else use the durable deadline.
811+
durableDeadline := time.Time{}
812+
if insertStatusResult.timeout > 0 && insertStatusResult.workflowDeadline.IsZero() {
813+
durableDeadline = time.Now().Add(insertStatusResult.timeout)
814+
} else if !insertStatusResult.workflowDeadline.IsZero() {
815+
durableDeadline = insertStatusResult.workflowDeadline
816+
}
817+
818+
var stopFunc func() bool
819+
cancelFuncCompleted := make(chan struct{})
820+
if !durableDeadline.IsZero() {
821+
workflowCtx, _ = WithTimeout(workflowCtx, time.Until(durableDeadline))
822+
// Register a cancel function that cancels the workflow in the DB as soon as the context is cancelled
823+
dbosCancelFunction := func() {
824+
c.logger.Info("Cancelling workflow", "workflow_id", workflowID)
825+
err = retry(c, func() error {
826+
return c.systemDB.cancelWorkflow(uncancellableCtx, workflowID)
827+
}, withRetrierLogger(c.logger))
828+
if err != nil {
829+
c.logger.Error("Failed to cancel workflow", "error", err)
830+
}
831+
close(cancelFuncCompleted)
832+
}
833+
stopFunc = context.AfterFunc(workflowCtx, dbosCancelFunction)
834+
}
835+
834836
// Run the function in a goroutine
835837
c.workflowsWg.Add(1)
836838
go func() {

0 commit comments

Comments
 (0)