diff --git a/dbos/admin_server_test.go b/dbos/admin_server_test.go index 25e2ff8c..42484e82 100644 --- a/dbos/admin_server_test.go +++ b/dbos/admin_server_test.go @@ -34,9 +34,6 @@ func TestAdminServer(t *testing.T) { } }() - // Give time for any startup processes - time.Sleep(100 * time.Millisecond) - // Verify admin server is not running client := &http.Client{Timeout: 1 * time.Second} _, err = client.Get(fmt.Sprintf("http://localhost:3001/%s", strings.TrimPrefix(_HEALTHCHECK_PATTERN, "GET /"))) diff --git a/dbos/conductor.go b/dbos/conductor.go new file mode 100644 index 00000000..a89f77fc --- /dev/null +++ b/dbos/conductor.go @@ -0,0 +1,955 @@ +package dbos + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "math" + "math/rand" + "net" + "net/url" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" +) + +const ( + _PING_INTERVAL = 20 * time.Second + _PING_TIMEOUT = 30 * time.Second // Should be slightly greater than server's executorPingWait (25s) + _INITIAL_RECONNECT_WAIT = 1 * time.Second + _MAX_RECONNECT_WAIT = 30 * time.Second + _HANDSHAKE_TIMEOUT = 10 * time.Second + _WRITE_DEADLINE = 5 * time.Second +) + +// ConductorConfig contains configuration for the conductor +type ConductorConfig struct { + url string + apiKey string + appName string +} + +// Conductor manages the WebSocket connection to the DBOS conductor service +type Conductor struct { + dbosCtx *dbosContext + logger *slog.Logger + + // Connection management + conn *websocket.Conn + needsReconnect atomic.Bool + wg sync.WaitGroup + stopOnce sync.Once + writeMu sync.Mutex // writeMu protects concurrent writes to the WebSocket connection (pings + handling messages) + + // Connection parameters + url url.URL + pingInterval time.Duration + pingTimeout time.Duration + reconnectWait time.Duration + + // pingCancel cancels the ping goroutine context + pingCancel context.CancelFunc +} + +// Launch starts the conductor main goroutine +func (c *Conductor) Launch() { + c.logger.Info("Launching conductor") + c.wg.Add(1) + go c.run() +} + +func NewConductor(dbosCtx *dbosContext, config ConductorConfig) (*Conductor, error) { + if config.apiKey == "" { + return nil, fmt.Errorf("conductor API key is required") + } + if config.url == "" { + return nil, fmt.Errorf("conductor URL is required") + } + + baseURL, err := url.Parse(config.url) + if err != nil { + return nil, fmt.Errorf("invalid conductor URL: %w", err) + } + + wsURL := url.URL{ + Scheme: baseURL.Scheme, + Host: baseURL.Host, + Path: baseURL.JoinPath("websocket", config.appName, config.apiKey).Path, + } + + c := &Conductor{ + dbosCtx: dbosCtx, + url: wsURL, + pingInterval: _PING_INTERVAL, + pingTimeout: _PING_TIMEOUT, + reconnectWait: _INITIAL_RECONNECT_WAIT, + logger: dbosCtx.logger.With("service", "conductor"), + } + + // Start with needsReconnect set to true so we connect on first run + c.needsReconnect.Store(true) + + return c, nil +} + +func (c *Conductor) Shutdown(timeout time.Duration) { + c.stopOnce.Do(func() { + if c.pingCancel != nil { + c.pingCancel() + } + + c.closeConn() + + done := make(chan struct{}) + go func() { + c.wg.Wait() + close(done) + }() + + select { + case <-done: + c.logger.Info("Conductor shut down") + case <-time.After(timeout): + c.logger.Warn("Timeout waiting for conductor to shut down", "timeout", timeout) + } + }) +} + +// reconnectWaitWithJitter adds random jitter to the reconnect wait time to prevent thundering herd +func (c *Conductor) reconnectWaitWithJitter() time.Duration { + // Add jitter: random value 
between 0.5 * wait and 1.5 * wait + jitter := 0.5 + rand.Float64() // #nosec G404 -- jitter for backoff doesn't need crypto-secure randomness + return time.Duration(float64(c.reconnectWait) * jitter) +} + +// closeConn closes the connection and signals that reconnection is needed +func (c *Conductor) closeConn() { + // Cancel ping goroutine first + if c.pingCancel != nil { + c.pingCancel() + c.pingCancel = nil + } + + // Acquire write mutex to ensure no concurrent writes during close + c.writeMu.Lock() + defer c.writeMu.Unlock() + + if c.conn != nil { + if err := c.conn.SetWriteDeadline(time.Now().Add(_WRITE_DEADLINE)); err != nil { + c.logger.Warn("Failed to set write deadline", "error", err) + } + err := c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "shutting down")) + if err != nil { + c.logger.Warn("Failed to send close message", "error", err) + } + err = c.conn.Close() + if err != nil { + c.logger.Warn("Failed to close connection", "error", err) + } + c.conn = nil + } + // Signal that we need to reconnect + c.needsReconnect.Store(true) +} + +func (c *Conductor) run() { + defer c.wg.Done() + + for { + // Check if the context has been cancelled + select { + case <-c.dbosCtx.Done(): + c.logger.Info("DBOS context done, stopping conductor", "cause", context.Cause(c.dbosCtx)) + c.closeConn() + return + default: + } + + // Connect if reconnection is needed + if c.needsReconnect.Load() { + if err := c.connect(); err != nil { + c.logger.Warn("Failed to connect to conductor", "error", err) + select { + case <-c.dbosCtx.Done(): + c.logger.Info("DBOS context done, stopping conductor", "cause", context.Cause(c.dbosCtx)) + return + case <-time.After(c.reconnectWaitWithJitter()): + // Exponential backoff with jitter up to max wait + if c.reconnectWait < _MAX_RECONNECT_WAIT { + c.reconnectWait *= 2 + if c.reconnectWait > _MAX_RECONNECT_WAIT { + c.reconnectWait = _MAX_RECONNECT_WAIT + } + } + continue + } + } + // Reset reconnect wait and clear reconnect flag on successful connection + c.reconnectWait = _INITIAL_RECONNECT_WAIT + c.needsReconnect.Store(false) + } + + // This shouldn't happen but check anyway + if c.conn == nil { + c.needsReconnect.Store(true) + continue + } + + // Read message (will timeout based on read deadline set in connect) + messageType, message, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + c.logger.Warn("Unexpected WebSocket close", "error", err) + } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + c.logger.Debug("Read deadline reached", "error", err) + } else { + c.logger.Debug("Connection closed", "error", err) + } + // Close connection to trigger reconnection + c.closeConn() + continue + } + + // Only accept text messages + if messageType != websocket.TextMessage { + c.logger.Warn("Received unexpected message type, forcing reconnection", "type", messageType) + c.closeConn() + continue + } + + ht := time.Now() + if err := c.handleMessage(message); err != nil { + c.logger.Error("Failed to handle message", "error", err) + } + c.logger.Debug("Handled message", "message", messageType, "latency_us", time.Since(ht).Microseconds()) + } +} + +func (c *Conductor) connect() error { + c.logger.Debug("Connecting to conductor") + + dialer := websocket.Dialer{ + HandshakeTimeout: _HANDSHAKE_TIMEOUT, + } + + conn, _, err := dialer.Dial(c.url.String(), nil) + if err != nil { + return fmt.Errorf("failed to dial 
conductor: %w", err) + } + + // Set initial read deadline + if err := conn.SetReadDeadline(time.Now().Add(c.pingTimeout)); err != nil { + cErr := conn.Close() + if cErr != nil { + c.logger.Warn("Failed to close connection", "error", cErr) + } + return fmt.Errorf("failed to set read deadline: %w", err) + } + + // Set pong handler to reset read deadline + conn.SetPongHandler(func(appData string) error { + c.logger.Debug("Received pong from conductor") + return conn.SetReadDeadline(time.Now().Add(c.pingTimeout)) + }) + + // Store the connection + c.conn = conn + + // Create a cancellable context for the ping goroutine + pingCtx, pingCancel := context.WithCancel(c.dbosCtx) + c.pingCancel = pingCancel + + // Start ping goroutine + c.wg.Add(1) + go func() { + defer c.wg.Done() + ticker := time.NewTicker(c.pingInterval) + defer ticker.Stop() + + for { + select { + case <-pingCtx.Done(): + c.logger.Debug("Exiting Conductor ping goroutine", "cause", context.Cause(pingCtx)) + return + case <-ticker.C: + if err := c.ping(); err != nil { + c.logger.Warn("Ping failed, signaling reconnection", "error", err) + // Signal that we need to reconnect and exit ping goroutine + c.needsReconnect.Store(true) + return + } + } + } + }() + + c.logger.Info("Connected to DBOS conductor") + return nil +} + +func (c *Conductor) ping() error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + + if c.conn == nil { + return fmt.Errorf("no connection") + } + + c.logger.Debug("Sending ping to conductor") + + if err := c.conn.SetWriteDeadline(time.Now().Add(_WRITE_DEADLINE)); err != nil { + c.logger.Warn("Failed to set write deadline for ping", "error", err) + } + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return fmt.Errorf("failed to send ping: %w", err) + } + if err := c.conn.SetWriteDeadline(time.Time{}); err != nil { + c.logger.Warn("Failed to clear write deadline", "error", err) + } + + return nil +} + +func (c *Conductor) handleMessage(data []byte) error { + var base baseMessage + if err := json.Unmarshal(data, &base); err != nil { + c.logger.Error("Failed to parse message", "error", err) + return fmt.Errorf("failed to parse base message: %w", err) + } + c.logger.Debug("Received message", "type", base.Type, "request_id", base.RequestID) + + switch base.Type { + case executorInfo: + return c.handleExecutorInfoRequest(data, base.RequestID) + case recoveryMessage: + return c.handleRecoveryRequest(data, base.RequestID) + case cancelWorkflowMessage: + return c.handleCancelWorkflowRequest(data, base.RequestID) + case resumeWorkflowMessage: + return c.handleResumeWorkflowRequest(data, base.RequestID) + case listWorkflowsMessage: + return c.handleListWorkflowsRequest(data, base.RequestID) + case listQueuedWorkflowsMessage: + return c.handleListQueuedWorkflowsRequest(data, base.RequestID) + case listStepsMessage: + return c.handleListStepsRequest(data, base.RequestID) + case getWorkflowMessage: + return c.handleGetWorkflowRequest(data, base.RequestID) + case forkWorkflowMessage: + return c.handleForkWorkflowRequest(data, base.RequestID) + case existPendingWorkflowsMessage: + return c.handleExistPendingWorkflowsRequest(data, base.RequestID) + case retentionMessage: + return c.handleRetentionRequest(data, base.RequestID) + default: + c.logger.Warn("Unknown message type", "type", base.Type) + return c.handleUnknownMessageType(base.RequestID, base.Type, "Unknown message type") + } +} + +func (c *Conductor) handleExecutorInfoRequest(data []byte, requestID string) error { + var req executorInfoRequest + if 
err := json.Unmarshal(data, &req); err != nil { + c.logger.Error("Failed to parse executor info request", "error", err) + return fmt.Errorf("failed to parse executor info request: %w", err) + } + c.logger.Debug("Handling executor info request", "request_id", req) + + hostname, err := os.Hostname() + if err != nil { + c.logger.Error("Failed to get hostname", "error", err) + return fmt.Errorf("failed to get hostname: %w", err) + } + + response := executorInfoResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: executorInfo, + RequestID: requestID, + }, + }, + ExecutorID: c.dbosCtx.GetExecutorID(), + ApplicationVersion: c.dbosCtx.GetApplicationVersion(), + Hostname: &hostname, + } + + return c.sendResponse(response, string(executorInfo)) +} + +func (c *Conductor) handleRecoveryRequest(data []byte, requestID string) error { + var req recoveryConductorRequest + if err := json.Unmarshal(data, &req); err != nil { + c.logger.Error("Failed to parse recovery request", "error", err) + return fmt.Errorf("failed to parse recovery request: %w", err) + } + c.logger.Debug("Handling recovery request", "executor_ids", req.ExecutorIDs, "request_id", requestID) + + success := true + var errorMsg *string + + _, err := recoverPendingWorkflows(c.dbosCtx, req.ExecutorIDs) + if err != nil { + c.logger.Error("Failed to recover pending workflows", "executor_ids", req.ExecutorIDs, "error", err) + errStr := fmt.Sprintf("failed to recover pending workflows: %v", err) + errorMsg = &errStr + success = false + } else { + c.logger.Info("Successfully recovered pending workflows", "executor_ids", req.ExecutorIDs) + } + + response := recoveryConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: recoveryMessage, + RequestID: requestID, + }, + ErrorMessage: errorMsg, + }, + Success: success, + } + + return c.sendResponse(response, string(recoveryMessage)) +} + +func (c *Conductor) handleCancelWorkflowRequest(data []byte, requestID string) error { + var req cancelWorkflowConductorRequest + if err := json.Unmarshal(data, &req); err != nil { + c.logger.Error("Failed to parse cancel workflow request", "error", err) + return fmt.Errorf("failed to parse cancel workflow request: %w", err) + } + c.logger.Debug("Handling cancel workflow request", "workflow_id", req.WorkflowID, "request_id", requestID) + + success := true + var errorMsg *string + + if err := c.dbosCtx.CancelWorkflow(c.dbosCtx, req.WorkflowID); err != nil { + c.logger.Error("Failed to cancel workflow", "workflow_id", req.WorkflowID, "error", err) + errStr := fmt.Sprintf("failed to cancel workflow: %v", err) + errorMsg = &errStr + success = false + } else { + c.logger.Info("Successfully cancelled workflow", "workflow_id", req.WorkflowID) + } + + response := cancelWorkflowConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: cancelWorkflowMessage, + RequestID: requestID, + }, + ErrorMessage: errorMsg, + }, + Success: success, + } + + return c.sendResponse(response, string(cancelWorkflowMessage)) +} + +func (c *Conductor) handleResumeWorkflowRequest(data []byte, requestID string) error { + var req resumeWorkflowConductorRequest + if err := json.Unmarshal(data, &req); err != nil { + c.logger.Error("Failed to parse resume workflow request", "error", err) + return fmt.Errorf("failed to parse resume workflow request: %w", err) + } + c.logger.Debug("Handling resume workflow request", "workflow_id", req.WorkflowID, "request_id", requestID) + + success := true + var errorMsg *string + + _, err := 
c.dbosCtx.ResumeWorkflow(c.dbosCtx, req.WorkflowID) + if err != nil { + c.logger.Error("Failed to resume workflow", "workflow_id", req.WorkflowID, "error", err) + errStr := fmt.Sprintf("failed to resume workflow: %v", err) + errorMsg = &errStr + success = false + } else { + c.logger.Info("Successfully resumed workflow", "workflow_id", req.WorkflowID) + } + + response := resumeWorkflowConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: resumeWorkflowMessage, + RequestID: requestID, + }, + ErrorMessage: errorMsg, + }, + Success: success, + } + + return c.sendResponse(response, string(resumeWorkflowMessage)) +} + +func (c *Conductor) handleRetentionRequest(data []byte, requestID string) error { + var req retentionConductorRequest + if err := json.Unmarshal(data, &req); err != nil { + c.logger.Error("Failed to parse retention request", "error", err) + return fmt.Errorf("failed to parse retention request: %w", err) + } + c.logger.Debug("Handling retention request", "request", req, "request_id", requestID) + + success := true + var errorMsg *string + + // Handle garbage collection if parameters are provided + if req.Body.GCCutoffEpochMs != nil || req.Body.GCRowsThreshold != nil { + var cutoffMs *int64 + if req.Body.GCCutoffEpochMs != nil { + ms := int64(*req.Body.GCCutoffEpochMs) + cutoffMs = &ms + } + + var rowsThreshold *int + if req.Body.GCRowsThreshold != nil { + rowsThreshold = req.Body.GCRowsThreshold + } + + input := garbageCollectWorkflowsInput{ + cutoffEpochTimestampMs: cutoffMs, + rowsThreshold: rowsThreshold, + } + + err := c.dbosCtx.systemDB.garbageCollectWorkflows(c.dbosCtx, input) + if err != nil { + c.logger.Error("Failed to garbage collect workflows", "error", err) + errStr := fmt.Sprintf("failed to garbage collect workflows: %v", err) + errorMsg = &errStr + success = false + } else { + c.logger.Info("Successfully garbage collected workflows", "cutoff_ms", cutoffMs, "rows_threshold", rowsThreshold) + } + } + + // Handle timeout enforcement if parameter is provided and garbage collection succeeded + if success && req.Body.TimeoutCutoffEpochMs != nil { + cutoffTime := time.UnixMilli(int64(*req.Body.TimeoutCutoffEpochMs)) + err := c.dbosCtx.systemDB.cancelAllBefore(c.dbosCtx, cutoffTime) + if err != nil { + c.logger.Error("Failed to timeout workflows", "cutoff_ms", *req.Body.TimeoutCutoffEpochMs, "error", err) + errStr := fmt.Sprintf("failed to timeout workflows: %v", err) + errorMsg = &errStr + success = false + } else { + c.logger.Info("Successfully timed out workflows", "cutoff_ms", *req.Body.TimeoutCutoffEpochMs) + } + } + + response := retentionConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: retentionMessage, + RequestID: requestID, + }, + ErrorMessage: errorMsg, + }, + Success: success, + } + + return c.sendResponse(response, string(retentionMessage)) +} + +func (c *Conductor) handleListWorkflowsRequest(data []byte, requestID string) error { + var req listWorkflowsConductorRequest + if err := json.Unmarshal(data, &req); err != nil { + c.logger.Error("Failed to parse list workflows request", "error", err) + return fmt.Errorf("failed to parse list workflows request: %w", err) + } + c.logger.Debug("Handling list workflows request", "request", req) + + var opts []ListWorkflowsOption + opts = append(opts, WithLoadInput(req.Body.LoadInput)) + opts = append(opts, WithLoadOutput(req.Body.LoadOutput)) + opts = append(opts, WithSortDesc(req.Body.SortDesc)) + if len(req.Body.WorkflowUUIDs) > 0 { + opts = append(opts, 
WithWorkflowIDs(req.Body.WorkflowUUIDs)) + } + if req.Body.WorkflowName != nil { + opts = append(opts, WithName(*req.Body.WorkflowName)) + } + if req.Body.AuthenticatedUser != nil { + opts = append(opts, WithUser(*req.Body.AuthenticatedUser)) + } + if req.Body.ApplicationVersion != nil { + opts = append(opts, WithAppVersion(*req.Body.ApplicationVersion)) + } + if req.Body.Limit != nil { + opts = append(opts, WithLimit(*req.Body.Limit)) + } + if req.Body.Offset != nil { + opts = append(opts, WithOffset(*req.Body.Offset)) + } + if req.Body.StartTime != nil { + opts = append(opts, WithStartTime(*req.Body.StartTime)) + } + if req.Body.EndTime != nil { + opts = append(opts, WithEndTime(*req.Body.EndTime)) + } + if req.Body.Status != nil { + opts = append(opts, WithStatus([]WorkflowStatusType{WorkflowStatusType(*req.Body.Status)})) + } + + workflows, err := c.dbosCtx.ListWorkflows(c.dbosCtx, opts...) + if err != nil { + c.logger.Error("Failed to list workflows", "error", err) + errorMsg := fmt.Sprintf("failed to list workflows: %v", err) + response := listWorkflowsConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: listWorkflowsMessage, + RequestID: requestID, + }, + ErrorMessage: &errorMsg, + }, + Output: []listWorkflowsConductorResponseBody{}, + } + return c.sendResponse(response, "list workflows response") + } + + formattedWorkflows := make([]listWorkflowsConductorResponseBody, len(workflows)) + for i, wf := range workflows { + formattedWorkflows[i] = formatListWorkflowsResponseBody(wf) + } + + response := listWorkflowsConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: listWorkflowsMessage, + RequestID: requestID, + }, + }, + Output: formattedWorkflows, + } + + return c.sendResponse(response, string(listWorkflowsMessage)) +} + +func (c *Conductor) handleListQueuedWorkflowsRequest(data []byte, requestID string) error { + var req listWorkflowsConductorRequest + if err := json.Unmarshal(data, &req); err != nil { + c.logger.Error("Failed to parse list queued workflows request", "error", err) + return fmt.Errorf("failed to parse list queued workflows request: %w", err) + } + c.logger.Debug("Handling list queued workflows request", "request", req) + + // Build functional options for ListWorkflows + var opts []ListWorkflowsOption + opts = append(opts, WithLoadInput(req.Body.LoadInput)) + opts = append(opts, WithLoadOutput(false)) // Don't load output for queued workflows + opts = append(opts, WithSortDesc(req.Body.SortDesc)) + + // Add status filter for queued workflows + queuedStatuses := make([]WorkflowStatusType, 0) + if req.Body.Status != nil { + // If a specific status is requested, use that status + status := WorkflowStatusType(*req.Body.Status) + if status != WorkflowStatusPending && status != WorkflowStatusEnqueued { + c.logger.Warn("Received unexpected filtering status for listing queued workflows", "status", status) + } + queuedStatuses = append(queuedStatuses, status) + } + if len(queuedStatuses) == 0 { + queuedStatuses = []WorkflowStatusType{WorkflowStatusPending, WorkflowStatusEnqueued} + } + opts = append(opts, WithStatus(queuedStatuses)) + + if req.Body.WorkflowName != nil { + opts = append(opts, WithName(*req.Body.WorkflowName)) + } + if req.Body.Limit != nil { + opts = append(opts, WithLimit(*req.Body.Limit)) + } + if req.Body.Offset != nil { + opts = append(opts, WithOffset(*req.Body.Offset)) + } + if req.Body.StartTime != nil { + opts = append(opts, WithStartTime(*req.Body.StartTime)) + } + if req.Body.EndTime != 
nil { + opts = append(opts, WithEndTime(*req.Body.EndTime)) + } + if req.Body.QueueName != nil { + opts = append(opts, WithQueueName(*req.Body.QueueName)) + } + + workflows, err := c.dbosCtx.ListWorkflows(c.dbosCtx, opts...) + if err != nil { + c.logger.Error("Failed to list queued workflows", "error", err) + errorMsg := fmt.Sprintf("failed to list queued workflows: %v", err) + response := listWorkflowsConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: listQueuedWorkflowsMessage, + RequestID: requestID, + }, + ErrorMessage: &errorMsg, + }, + Output: []listWorkflowsConductorResponseBody{}, + } + return c.sendResponse(response, string(listQueuedWorkflowsMessage)) + } + + // If no queue name was specified, only include workflows that have a queue name + var filteredWorkflows []WorkflowStatus + if req.Body.QueueName == nil { + for _, wf := range workflows { + if wf.QueueName != "" { + filteredWorkflows = append(filteredWorkflows, wf) + } + } + } else { + filteredWorkflows = workflows + } + + // Prepare response payload + formattedWorkflows := make([]listWorkflowsConductorResponseBody, len(filteredWorkflows)) + for i, wf := range filteredWorkflows { + formattedWorkflows[i] = formatListWorkflowsResponseBody(wf) + } + + response := listWorkflowsConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: listQueuedWorkflowsMessage, + RequestID: requestID, + }, + }, + Output: formattedWorkflows, + } + + return c.sendResponse(response, string(listQueuedWorkflowsMessage)) +} + +func (c *Conductor) handleListStepsRequest(data []byte, requestID string) error { + var req listStepsConductorRequest + if err := json.Unmarshal(data, &req); err != nil { + c.logger.Error("Failed to parse list steps request", "error", err) + return fmt.Errorf("failed to parse list steps request: %w", err) + } + c.logger.Debug("Handling list steps request", "request", req) + + // Get workflow steps using the existing systemDB method + steps, err := c.dbosCtx.systemDB.getWorkflowSteps(c.dbosCtx, req.WorkflowID) + if err != nil { + c.logger.Error("Failed to list workflow steps", "workflow_id", req.WorkflowID, "error", err) + errorMsg := fmt.Sprintf("failed to list workflow steps: %v", err) + response := listStepsConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: listStepsMessage, + RequestID: requestID, + }, + ErrorMessage: &errorMsg, + }, + Output: nil, + } + return c.sendResponse(response, string(listStepsMessage)) + } + + // Convert steps to response format + var formattedSteps *[]workflowStepsConductorResponseBody + if steps != nil { + stepsList := make([]workflowStepsConductorResponseBody, len(steps)) + for i, step := range steps { + stepsList[i] = formatWorkflowStepsResponseBody(step) + } + formattedSteps = &stepsList + } + + response := listStepsConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: listStepsMessage, + RequestID: requestID, + }, + }, + Output: formattedSteps, + } + + return c.sendResponse(response, string(listStepsMessage)) +} + +func (c *Conductor) handleGetWorkflowRequest(data []byte, requestID string) error { + var req getWorkflowConductorRequest + if err := json.Unmarshal(data, &req); err != nil { + c.logger.Error("Failed to parse get workflow request", "error", err) + return fmt.Errorf("failed to parse get workflow request: %w", err) + } + c.logger.Debug("Handling get workflow request", "workflow_id", req.WorkflowID) + + workflows, err := c.dbosCtx.ListWorkflows(c.dbosCtx, 
WithWorkflowIDs([]string{req.WorkflowID})) + if err != nil { + c.logger.Error("Failed to get workflow", "workflow_id", req.WorkflowID, "error", err) + errorMsg := fmt.Sprintf("failed to get workflow: %v", err) + response := getWorkflowConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: getWorkflowMessage, + RequestID: requestID, + }, + ErrorMessage: &errorMsg, + }, + Output: nil, + } + return c.sendResponse(response, "get workflow response") + } + + var formattedWorkflow *listWorkflowsConductorResponseBody + if len(workflows) > 0 { + formatted := formatListWorkflowsResponseBody(workflows[0]) + formattedWorkflow = &formatted + } + + response := getWorkflowConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: getWorkflowMessage, + RequestID: requestID, + }, + }, + Output: formattedWorkflow, + } + + return c.sendResponse(response, string(getWorkflowMessage)) +} + +func (c *Conductor) handleForkWorkflowRequest(data []byte, requestID string) error { + var req forkWorkflowConductorRequest + if err := json.Unmarshal(data, &req); err != nil { + c.logger.Error("Failed to parse fork workflow request", "error", err) + return fmt.Errorf("failed to parse fork workflow request: %w", err) + } + c.logger.Debug("Handling fork workflow request", "request", req) + + // Validate StartStep to prevent integer overflow + if req.Body.StartStep < 0 { + return fmt.Errorf("invalid StartStep: cannot be negative") + } + if req.Body.StartStep > math.MaxInt32/2 { + return fmt.Errorf("invalid StartStep: cannot be greater than %d", math.MaxInt32/2) + } + input := ForkWorkflowInput{ + OriginalWorkflowID: req.Body.WorkflowID, + StartStep: uint(req.Body.StartStep), // #nosec G115 -- validated above + } + + // Set optional fields + if req.Body.NewWorkflowID != nil { + input.ForkedWorkflowID = *req.Body.NewWorkflowID + } + if req.Body.ApplicationVersion != nil { + input.ApplicationVersion = *req.Body.ApplicationVersion + } + + // Execute the fork workflow + handle, err := c.dbosCtx.ForkWorkflow(c.dbosCtx, input) + var newWorkflowID *string + var errorMsg *string + + if err != nil { + c.logger.Error("Failed to fork workflow", "original_workflow_id", req.Body.WorkflowID, "error", err) + errStr := fmt.Sprintf("failed to fork workflow: %v", err) + errorMsg = &errStr + } else { + workflowID := handle.GetWorkflowID() + newWorkflowID = &workflowID + c.logger.Info("Successfully forked workflow", "original_workflow_id", req.Body.WorkflowID, "new_workflow_id", workflowID) + } + + response := forkWorkflowConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: forkWorkflowMessage, + RequestID: requestID, + }, + ErrorMessage: errorMsg, + }, + NewWorkflowID: newWorkflowID, + } + + return c.sendResponse(response, string(forkWorkflowMessage)) +} + +func (c *Conductor) handleExistPendingWorkflowsRequest(data []byte, requestID string) error { + var req existPendingWorkflowsConductorRequest + if err := json.Unmarshal(data, &req); err != nil { + c.logger.Error("Failed to parse exist pending workflows request", "error", err) + return fmt.Errorf("failed to parse exist pending workflows request: %w", err) + } + c.logger.Debug("Handling exist pending workflows request", "executor_id", req.ExecutorID, "application_version", req.ApplicationVersion) + + opts := []ListWorkflowsOption{ + WithStatus([]WorkflowStatusType{WorkflowStatusPending}), + WithLimit(1), // We only need to know if any exist, so limit to 1 for efficiency + 
WithExecutorIDs([]string{req.ExecutorID}), + WithAppVersion(req.ApplicationVersion), + } + + workflows, err := c.dbosCtx.ListWorkflows(c.dbosCtx, opts...) + var errorMsg *string + if err != nil { + c.logger.Error("Failed to check for pending workflows", "executor_id", req.ExecutorID, "application_version", req.ApplicationVersion, "error", err) + errStr := fmt.Sprintf("failed to check for pending workflows: %v", err) + errorMsg = &errStr + } + + response := existPendingWorkflowsConductorResponse{ + baseResponse: baseResponse{ + baseMessage: baseMessage{ + Type: existPendingWorkflowsMessage, + RequestID: requestID, + }, + ErrorMessage: errorMsg, + }, + Exist: len(workflows) > 0, + } + + return c.sendResponse(response, string(existPendingWorkflowsMessage)) +} + +func (c *Conductor) handleUnknownMessageType(requestID string, msgType messageType, errorMsg string) error { + if c.conn == nil { + return fmt.Errorf("no connection") + } + + response := baseResponse{ + baseMessage: baseMessage{ + Type: msgType, + RequestID: requestID, + }, + ErrorMessage: &errorMsg, + } + + return c.sendResponse(response, "unknown message type response") +} + +func (c *Conductor) sendResponse(response any, responseType string) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + + if c.conn == nil { + return fmt.Errorf("no connection") + } + + data, err := json.Marshal(response) + if err != nil { + return fmt.Errorf("failed to marshal %s: %w", responseType, err) + } + + c.logger.Debug("Sending response", "type", responseType, "len", len(data)) + + if err := c.conn.SetWriteDeadline(time.Now().Add(_WRITE_DEADLINE)); err != nil { + c.logger.Warn("Failed to set write deadline", "type", responseType, "error", err) + } + if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { + c.logger.Error("Failed to send response", "type", responseType, "error", err) + return fmt.Errorf("failed to send message: %w", err) + } + if err := c.conn.SetWriteDeadline(time.Time{}); err != nil { + c.logger.Warn("Failed to clear write deadline", "type", responseType, "error", err) + } + + return nil +} diff --git a/dbos/conductor_protocol.go b/dbos/conductor_protocol.go new file mode 100644 index 00000000..7f06065b --- /dev/null +++ b/dbos/conductor_protocol.go @@ -0,0 +1,321 @@ +package dbos + +import ( + "encoding/json" + "strconv" + "time" +) + +// messageType represents the type of message exchanged with the conductor +type messageType string + +const ( + executorInfo messageType = "executor_info" + recoveryMessage messageType = "recovery" + cancelWorkflowMessage messageType = "cancel" + resumeWorkflowMessage messageType = "resume" + listWorkflowsMessage messageType = "list_workflows" + listQueuedWorkflowsMessage messageType = "list_queued_workflows" + listStepsMessage messageType = "list_steps" + getWorkflowMessage messageType = "get_workflow" + forkWorkflowMessage messageType = "fork_workflow" + existPendingWorkflowsMessage messageType = "exist_pending_workflows" + retentionMessage messageType = "retention" +) + +// baseMessage represents the common structure of all conductor messages +type baseMessage struct { + Type messageType `json:"type"` + RequestID string `json:"request_id"` +} + +// baseResponse extends baseMessage with optional error handling +type baseResponse struct { + baseMessage + ErrorMessage *string `json:"error_message,omitempty"` +} + +// executorInfoRequest is sent by the conductor to request executor information +type executorInfoRequest struct { + baseMessage +} + +// executorInfoResponse is sent in 
response to executor info requests +type executorInfoResponse struct { + baseResponse + ExecutorID string `json:"executor_id"` + ApplicationVersion string `json:"application_version"` + Hostname *string `json:"hostname,omitempty"` +} + +// listWorkflowsConductorRequestBody contains filter parameters for listing workflows +type listWorkflowsConductorRequestBody struct { + WorkflowUUIDs []string `json:"workflow_uuids,omitempty"` + WorkflowName *string `json:"workflow_name,omitempty"` + AuthenticatedUser *string `json:"authenticated_user,omitempty"` + StartTime *time.Time `json:"start_time,omitempty"` + EndTime *time.Time `json:"end_time,omitempty"` + Status *string `json:"status,omitempty"` + ApplicationVersion *string `json:"application_version,omitempty"` + QueueName *string `json:"queue_name,omitempty"` + Limit *int `json:"limit,omitempty"` + Offset *int `json:"offset,omitempty"` + SortDesc bool `json:"sort_desc"` + LoadInput bool `json:"load_input"` + LoadOutput bool `json:"load_output"` +} + +// listWorkflowsConductorRequest is sent by the conductor to list workflows +type listWorkflowsConductorRequest struct { + baseMessage + Body listWorkflowsConductorRequestBody `json:"body"` +} + +// listWorkflowsConductorResponseBody represents a single workflow in the list response +type listWorkflowsConductorResponseBody struct { + WorkflowUUID string `json:"WorkflowUUID"` + Status *string `json:"Status,omitempty"` + WorkflowName *string `json:"WorkflowName,omitempty"` + WorkflowClassName *string `json:"WorkflowClassName,omitempty"` + WorkflowConfigName *string `json:"WorkflowConfigName,omitempty"` + AuthenticatedUser *string `json:"AuthenticatedUser,omitempty"` + AssumedRole *string `json:"AssumedRole,omitempty"` + AuthenticatedRoles *string `json:"AuthenticatedRoles,omitempty"` + Input *string `json:"Input,omitempty"` + Output *string `json:"Output,omitempty"` + Error *string `json:"Error,omitempty"` + CreatedAt *string `json:"CreatedAt,omitempty"` + UpdatedAt *string `json:"UpdatedAt,omitempty"` + QueueName *string `json:"QueueName,omitempty"` + ApplicationVersion *string `json:"ApplicationVersion,omitempty"` + ExecutorID *string `json:"ExecutorID,omitempty"` +} + +// listWorkflowsConductorResponse is sent in response to list workflows requests +type listWorkflowsConductorResponse struct { + baseResponse + Output []listWorkflowsConductorResponseBody `json:"output"` +} + +// formatListWorkflowsResponseBody converts WorkflowStatus to listWorkflowsConductorResponseBody for the conductor protocol +func formatListWorkflowsResponseBody(wf WorkflowStatus) listWorkflowsConductorResponseBody { + output := listWorkflowsConductorResponseBody{ + WorkflowUUID: wf.ID, + } + + // Convert status + if wf.Status != "" { + status := string(wf.Status) + output.Status = &status + } + + // Convert workflow name + if wf.Name != "" { + output.WorkflowName = &wf.Name + } + + // Copy optional fields + output.AuthenticatedUser = wf.AuthenticatedUser + output.AssumedRole = wf.AssumedRole + output.AuthenticatedRoles = wf.AuthenticatedRoles + + // Convert input/output to JSON strings if present + if wf.Input != nil { + inputJSON, err := json.Marshal(wf.Input) + if err == nil { + inputStr := string(inputJSON) + output.Input = &inputStr + } + } + if wf.Output != nil { + outputJSON, err := json.Marshal(wf.Output) + if err == nil { + outputStr := string(outputJSON) + output.Output = &outputStr + } + } + + // Convert error to string + if wf.Error != nil { + errorStr := wf.Error.Error() + output.Error = &errorStr + } + + // 
Convert timestamps to unix epochs + if !wf.CreatedAt.IsZero() { + createdStr := strconv.FormatInt(wf.CreatedAt.UnixMilli(), 10) + output.CreatedAt = &createdStr + } + if !wf.UpdatedAt.IsZero() { + updatedStr := strconv.FormatInt(wf.UpdatedAt.UnixMilli(), 10) + output.UpdatedAt = &updatedStr + } + + // Copy queue name + if wf.QueueName != "" { + output.QueueName = &wf.QueueName + } + + // Copy application version + if wf.ApplicationVersion != "" { + output.ApplicationVersion = &wf.ApplicationVersion + } + + // Copy executor ID + if wf.ExecutorID != "" { + output.ExecutorID = &wf.ExecutorID + } + + return output +} + +// listStepsConductorRequest is sent by the conductor to list workflow steps +type listStepsConductorRequest struct { + baseMessage + WorkflowID string `json:"workflow_id"` +} + +// workflowStepsConductorResponseBody represents a single workflow step in the list response +type workflowStepsConductorResponseBody struct { + FunctionID int `json:"function_id"` + FunctionName string `json:"function_name"` + Output *string `json:"output,omitempty"` + Error *string `json:"error,omitempty"` + ChildWorkflowID *string `json:"child_workflow_id,omitempty"` +} + +// listStepsConductorResponse is sent in response to list steps requests +type listStepsConductorResponse struct { + baseResponse + Output *[]workflowStepsConductorResponseBody `json:"output,omitempty"` +} + +// formatWorkflowStepsResponseBody converts stepInfo to workflowStepsConductorResponseBody for the conductor protocol +func formatWorkflowStepsResponseBody(step stepInfo) workflowStepsConductorResponseBody { + output := workflowStepsConductorResponseBody{ + FunctionID: step.StepID, + FunctionName: step.StepName, + } + + // Convert output to JSON string if present + if step.Output != nil { + outputJSON, err := json.Marshal(step.Output) + if err == nil { + outputStr := string(outputJSON) + output.Output = &outputStr + } + } + + // Convert error to string if present + if step.Error != nil { + errorStr := step.Error.Error() + output.Error = &errorStr + } + + // Set child workflow ID if present + if step.ChildWorkflowID != "" { + output.ChildWorkflowID = &step.ChildWorkflowID + } + + return output +} + +// getWorkflowConductorRequest is sent by the conductor to get a specific workflow +type getWorkflowConductorRequest struct { + baseMessage + WorkflowID string `json:"workflow_id"` +} + +// getWorkflowConductorResponse is sent in response to get workflow requests +type getWorkflowConductorResponse struct { + baseResponse + Output *listWorkflowsConductorResponseBody `json:"output,omitempty"` +} + +// forkWorkflowConductorRequestBody contains the fork workflow parameters +type forkWorkflowConductorRequestBody struct { + WorkflowID string `json:"workflow_id"` + StartStep int `json:"start_step"` + ApplicationVersion *string `json:"application_version,omitempty"` + NewWorkflowID *string `json:"new_workflow_id,omitempty"` +} + +// forkWorkflowConductorRequest is sent by the conductor to fork a workflow +type forkWorkflowConductorRequest struct { + baseMessage + Body forkWorkflowConductorRequestBody `json:"body"` +} + +// forkWorkflowConductorResponse is sent in response to fork workflow requests +type forkWorkflowConductorResponse struct { + baseResponse + NewWorkflowID *string `json:"new_workflow_id,omitempty"` +} + +// cancelWorkflowConductorRequest is sent by the conductor to cancel a workflow +type cancelWorkflowConductorRequest struct { + baseMessage + WorkflowID string `json:"workflow_id"` +} + +// cancelWorkflowConductorResponse 
is sent in response to cancel workflow requests +type cancelWorkflowConductorResponse struct { + baseResponse + Success bool `json:"success"` +} + +// recoveryConductorRequest is sent by the conductor to request recovery of pending workflows +type recoveryConductorRequest struct { + baseMessage + ExecutorIDs []string `json:"executor_ids"` +} + +// recoveryConductorResponse is sent in response to recovery requests +type recoveryConductorResponse struct { + baseResponse + Success bool `json:"success"` +} + +// existPendingWorkflowsConductorRequest is sent by the conductor to check for pending workflows +type existPendingWorkflowsConductorRequest struct { + baseMessage + ExecutorID string `json:"executor_id"` + ApplicationVersion string `json:"application_version"` +} + +// existPendingWorkflowsConductorResponse is sent in response to exist pending workflows requests +type existPendingWorkflowsConductorResponse struct { + baseResponse + Exist bool `json:"exist"` +} + +// resumeWorkflowConductorRequest is sent by the conductor to resume a workflow +type resumeWorkflowConductorRequest struct { + baseMessage + WorkflowID string `json:"workflow_id"` +} + +// resumeWorkflowConductorResponse is sent in response to resume workflow requests +type resumeWorkflowConductorResponse struct { + baseResponse + Success bool `json:"success"` +} + +// retentionConductorRequestBody contains retention policy parameters +type retentionConductorRequestBody struct { + GCCutoffEpochMs *int `json:"gc_cutoff_epoch_ms,omitempty"` + GCRowsThreshold *int `json:"gc_rows_threshold,omitempty"` + TimeoutCutoffEpochMs *int `json:"timeout_cutoff_epoch_ms,omitempty"` +} + +// retentionConductorRequest is sent by the conductor to enforce retention policies +type retentionConductorRequest struct { + baseMessage + Body retentionConductorRequestBody `json:"body"` +} + +// retentionConductorResponse is sent in response to retention requests +type retentionConductorResponse struct { + baseResponse + Success bool `json:"success"` +} diff --git a/dbos/conductor_test.go b/dbos/conductor_test.go new file mode 100644 index 00000000..9377cbf3 --- /dev/null +++ b/dbos/conductor_test.go @@ -0,0 +1,666 @@ +package dbos + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" +) + +// writeCommand represents a command to write to the WebSocket connection +type writeCommand struct { + messageType int + data []byte + response chan error // Channel to send back the result +} + +// mockWebSocketServer provides a controllable WebSocket server for testing +type mockWebSocketServer struct { + server *httptest.Server + upgrader websocket.Upgrader + connMu sync.Mutex // Only for connection assignment/reassignment + conn *websocket.Conn + closed atomic.Bool + messages chan []byte + pings chan struct{} + writeCmds chan writeCommand // Channel for write commands + stopHandler chan struct{} + ignorePings atomic.Bool // When true, don't respond with pongs +} + +func newMockWebSocketServer() *mockWebSocketServer { + m := &mockWebSocketServer{ + upgrader: websocket.Upgrader{}, + messages: make(chan []byte, 100), + pings: make(chan struct{}, 100), + writeCmds: make(chan writeCommand, 10), + stopHandler: make(chan struct{}), + } + + m.server = httptest.NewServer(http.HandlerFunc(m.handleWebSocket)) + return m +} + +func (m 
*mockWebSocketServer) handleWebSocket(w http.ResponseWriter, r *http.Request) { + // Check if we're closed + if m.closed.Load() { + http.Error(w, "Server closed", http.StatusServiceUnavailable) + return + } + + conn, err := m.upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + // Connection assignment - this is the only place we need mutex + m.connMu.Lock() + // Close any existing connection + if m.conn != nil { + m.conn.Close() + } + m.conn = conn + m.connMu.Unlock() + + // Ensure the connection gets cleared when this handler exits + defer func() { + m.connMu.Lock() + if m.conn == conn { + m.conn = nil + } + m.connMu.Unlock() + conn.Close() + }() + + // Handle connection lifecycle - this function owns all I/O on conn + + // We need to handle pings manually since we can't use the ping handler + // (it would cause concurrent writes with our main loop) + pingReceived := make(chan struct{}, 10) + + // Custom ping handler that just signals - no writing + conn.SetPingHandler(func(string) error { + select { + case m.pings <- struct{}{}: + default: + } + select { + case pingReceived <- struct{}{}: + default: + } + return nil + }) + + // Start dedicated read goroutine - only reads, never writes + readDone := make(chan error, 1) + go func() { + defer close(readDone) + for { + _, _, err := conn.ReadMessage() + if err != nil { + fmt.Printf("WebSocket read error: %v\n", err) + readDone <- err + return + } + } + }() + + // Main write loop - all writes happen here sequentially + for { + select { + case <-m.stopHandler: + fmt.Println("WebSocket connection closed by stop signal") + return + + case err := <-readDone: + fmt.Printf("WebSocket connection closed by read error: %v\n", err) + return + + case writeCmd := <-m.writeCmds: + // Handle write command + err := conn.WriteMessage(writeCmd.messageType, writeCmd.data) + if writeCmd.response != nil { + select { + case writeCmd.response <- err: + default: + } + } + if err != nil { + fmt.Printf("WebSocket write error: %v\n", err) + return + } + + case <-pingReceived: + // Handle ping response (send pong) + if !m.ignorePings.Load() { + err := conn.WriteMessage(websocket.PongMessage, nil) + if err != nil { + fmt.Printf("WebSocket pong write error: %v\n", err) + return + } + } + } + } +} + +func (m *mockWebSocketServer) getURL() string { + return "ws" + strings.TrimPrefix(m.server.URL, "http") +} + +func (m *mockWebSocketServer) close() { + m.closed.Store(true) + + // Signal handler to stop but don't block + select { + case m.stopHandler <- struct{}{}: + default: + } +} + +func (m *mockWebSocketServer) shutdown() { + m.close() + m.server.Close() +} + +func (m *mockWebSocketServer) restart() { + // Reset for new connections + m.closed.Store(false) + // Drain stop handler channel and write command channel + select { + case <-m.stopHandler: + default: + } + // Drain any pending write commands +drainLoop: + for { + select { + case cmd := <-m.writeCmds: + if cmd.response != nil { + select { + case cmd.response <- fmt.Errorf("server restarting"): + default: + } + } + default: + break drainLoop + } + } +} + +func (m *mockWebSocketServer) waitForConnection(timeout time.Duration) bool { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + m.connMu.Lock() + hasConn := m.conn != nil + m.connMu.Unlock() + if hasConn { + return true + } + time.Sleep(10 * time.Millisecond) + } + return false +} + +// sendBinaryMessage sends a binary WebSocket message to the connected client +func (m *mockWebSocketServer) sendBinaryMessage(data []byte) error { 
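+	// Note: all writes to the client connection are funneled through the handler's
+	// single write loop via the writeCmds channel; gorilla/websocket allows at most
+	// one concurrent writer per connection, so writing directly from this method
+	// could race with the pong replies the handler sends in response to pings.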
+ // Check if we have a connection without blocking + m.connMu.Lock() + hasConn := m.conn != nil + m.connMu.Unlock() + + if !hasConn { + return fmt.Errorf("no connection") + } + + // Send write command via channel + response := make(chan error, 1) + cmd := writeCommand{ + messageType: websocket.BinaryMessage, + data: data, + response: response, + } + + select { + case m.writeCmds <- cmd: + // Wait for response + select { + case err := <-response: + return err + case <-time.After(1 * time.Second): + return fmt.Errorf("write timeout") + } + case <-time.After(1 * time.Second): + return fmt.Errorf("write command queue full") + } +} + +// sendCloseMessage sends a WebSocket close message with specified code and reason +func (m *mockWebSocketServer) sendCloseMessage(code int, text string) error { + // Check if we have a connection without blocking + m.connMu.Lock() + hasConn := m.conn != nil + m.connMu.Unlock() + + if !hasConn { + return fmt.Errorf("no connection") + } + + // Format close message + message := websocket.FormatCloseMessage(code, text) + + // Send write command via channel + response := make(chan error, 1) + cmd := writeCommand{ + messageType: websocket.CloseMessage, + data: message, + response: response, + } + + select { + case m.writeCmds <- cmd: + // Wait for response + select { + case err := <-response: + // After sending close, close the connection from our side too + m.connMu.Lock() + if m.conn != nil { + m.conn.Close() + m.conn = nil + } + m.connMu.Unlock() + return err + case <-time.After(1 * time.Second): + return fmt.Errorf("write timeout") + } + case <-time.After(1 * time.Second): + return fmt.Errorf("write command queue full") + } +} + +// TestConductorReconnection tests various reconnection scenarios for the conductor +func TestConductorReconnection(t *testing.T) { + t.Run("ServerRestart", func(t *testing.T) { + defer goleak.VerifyNone(t) + + // Create and start mock server + mockServer := newMockWebSocketServer() + defer mockServer.shutdown() + + // Create conductor config + config := ConductorConfig{ + url: mockServer.getURL(), + apiKey: "test-key", + appName: "test-app", + } + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Create dbosContext + dbosCtx := &dbosContext{ + ctx: ctx, + logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})), + } + + // Create conductor + conductor, err := NewConductor(dbosCtx, config) + require.NoError(t, err) + + // Speed up intervals for testing + conductor.pingInterval = 100 * time.Millisecond + conductor.pingTimeout = 200 * time.Millisecond + conductor.reconnectWait = 100 * time.Millisecond + + // Launch conductor + conductor.Launch() + + // Wait for initial connection + assert.True(t, mockServer.waitForConnection(5*time.Second), "Should establish initial connection") + + // Collect initial pings + initialPings := 0 + timeout := time.After(1 * time.Second) + collectInitialPings: + for { + select { + case <-mockServer.pings: + initialPings++ + case <-timeout: + break collectInitialPings + } + } + assert.Greater(t, initialPings, 0, "Should receive initial pings") + fmt.Printf("Received %d initial pings\n", initialPings) + + // Close the server connection (simulate disconnect) + fmt.Println("Closing server connection") + mockServer.close() + + // Wait a bit for conductor to notice and start reconnecting + time.Sleep(500 * time.Millisecond) + + // Restart the server + fmt.Println("Restarting server") + 
mockServer.restart() + + // Wait for reconnection + assert.True(t, mockServer.waitForConnection(10*time.Second), "Should reconnect after server restart") + + // Collect pings after reconnection + reconnectPings := 0 + timeout2 := time.After(1 * time.Second) + collectReconnectPings: + for { + select { + case <-mockServer.pings: + reconnectPings++ + case <-timeout2: + break collectReconnectPings + } + } + assert.Greater(t, reconnectPings, 0, "Should receive pings after reconnection") + t.Logf("Received %d pings after reconnection", reconnectPings) + + // Cancel the context to trigger shutdown + cancel() + + // Give conductor time to clean up + time.Sleep(500 * time.Millisecond) + }) + + t.Run("TestBinaryMessage", func(t *testing.T) { + defer goleak.VerifyNone(t) + + // Create and start mock server + mockServer := newMockWebSocketServer() + defer mockServer.shutdown() + + // Create conductor config + config := ConductorConfig{ + url: mockServer.getURL(), + apiKey: "test-key", + appName: "test-app", + } + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Create dbosContext + dbosCtx := &dbosContext{ + ctx: ctx, + logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})), + } + + // Create conductor + conductor, err := NewConductor(dbosCtx, config) + require.NoError(t, err) + + // Speed up intervals for testing + conductor.pingInterval = 100 * time.Millisecond + conductor.pingTimeout = 200 * time.Millisecond + conductor.reconnectWait = 100 * time.Millisecond + + // Launch conductor + conductor.Launch() + + // Wait for initial connection + assert.True(t, mockServer.waitForConnection(5*time.Second), "Should establish initial connection") + + // Collect initial pings + initialPings := 0 + timeout := time.After(1 * time.Second) + collectInitialPings: + for { + select { + case <-mockServer.pings: + initialPings++ + case <-timeout: + break collectInitialPings + } + } + assert.Greater(t, initialPings, 0, "Should receive initial pings") + fmt.Printf("Received %d initial pings\n", initialPings) + + // Send binary message - conductor should disconnect and reconnect + fmt.Println("Sending binary message to trigger disconnect") + err = mockServer.sendBinaryMessage([]byte{0xDE, 0xAD, 0xBE, 0xEF}) + assert.NoError(t, err, "Should send binary message successfully") + + // Wait a bit for conductor to process the message and disconnect + time.Sleep(200 * time.Millisecond) + + // Wait for reconnection after binary message + assert.True(t, mockServer.waitForConnection(10*time.Second), "Should reconnect after receiving binary message") + + // Collect pings after reconnection + reconnectPings := 0 + timeout2 := time.After(1 * time.Second) + collectReconnectPings: + for { + select { + case <-mockServer.pings: + reconnectPings++ + case <-timeout2: + break collectReconnectPings + } + } + assert.Greater(t, reconnectPings, 0, "Should receive pings after reconnection from binary message") + t.Logf("Received %d pings after reconnection from binary message", reconnectPings) + + // Cancel the context to trigger shutdown + cancel() + + // Give conductor time to clean up + time.Sleep(500 * time.Millisecond) + }) + + // TestConductorPingTimeout tests that conductor reconnects when server stops responding to pings + t.Run("TestConductorPingTimeout", func(t *testing.T) { + defer goleak.VerifyNone(t) + + // Create and start mock server + mockServer := newMockWebSocketServer() + defer mockServer.shutdown() + + // 
Create conductor config + config := ConductorConfig{ + url: mockServer.getURL(), + apiKey: "test-key", + appName: "test-app", + } + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Create dbosContext + dbosCtx := &dbosContext{ + ctx: ctx, + logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})), + } + + // Create conductor + conductor, err := NewConductor(dbosCtx, config) + require.NoError(t, err) + + // Speed up intervals for testing + conductor.pingInterval = 100 * time.Millisecond + conductor.pingTimeout = 200 * time.Millisecond + conductor.reconnectWait = 100 * time.Millisecond + + // Launch conductor + conductor.Launch() + + // Wait for initial connection + assert.True(t, mockServer.waitForConnection(5*time.Second), "Should establish initial connection") + + // Collect initial pings + initialPings := 0 + timeout := time.After(1 * time.Second) + collectInitialPings: + for { + select { + case <-mockServer.pings: + initialPings++ + case <-timeout: + break collectInitialPings + } + } + assert.Greater(t, initialPings, 0, "Should receive initial pings") + fmt.Printf("Received %d initial pings\n", initialPings) + + // Tell server to stop responding to pings (no pongs) + fmt.Println("Server stopping pong responses") + mockServer.ignorePings.Store(true) + + // Wait for conductor to detect the dead connection (should timeout after pingTimeout) + // Conductor should detect no pong response and close the connection + // This will cause the handler to exit when ReadMessage fails + time.Sleep(conductor.pingTimeout + 100*time.Millisecond) + + // Resume responding to pings after timeout + // This allows the new connection handler to respond properly + fmt.Println("Server resuming pong responses") + mockServer.ignorePings.Store(false) + + // Wait for reconnection + assert.True(t, mockServer.waitForConnection(10*time.Second), "Should reconnect after ping timeout") + + // Collect pings after reconnection + reconnectPings := 0 + timeout2 := time.After(1 * time.Second) + collectReconnectPings: + for { + select { + case <-mockServer.pings: + reconnectPings++ + case <-timeout2: + break collectReconnectPings + } + } + assert.Greater(t, reconnectPings, 0, "Should receive pings after reconnection") + t.Logf("Received %d pings after reconnection", reconnectPings) + + // Cancel the context to trigger shutdown + cancel() + + // Give conductor time to clean up + time.Sleep(500 * time.Millisecond) + }) + + t.Run("CloseMessages", func(t *testing.T) { + defer goleak.VerifyNone(t) + + // Create and start mock server + mockServer := newMockWebSocketServer() + defer mockServer.shutdown() + + // Create conductor config + config := ConductorConfig{ + url: mockServer.getURL(), + apiKey: "test-key", + appName: "test-app", + } + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Create dbosContext + dbosCtx := &dbosContext{ + ctx: ctx, + logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})), + } + + // Create conductor + conductor, err := NewConductor(dbosCtx, config) + require.NoError(t, err) + + // Speed up intervals for testing + conductor.pingInterval = 100 * time.Millisecond + conductor.pingTimeout = 200 * time.Millisecond + conductor.reconnectWait = 100 * time.Millisecond + + // Launch conductor + conductor.Launch() + + // Wait for initial connection + assert.True(t, 
mockServer.waitForConnection(5*time.Second), "Should establish initial connection") + + // Test close message codes that should trigger reconnection + testCases := []struct { + code int + reason string + name string + }{ + {websocket.CloseGoingAway, "server going away", "CloseGoingAway"}, + {websocket.CloseAbnormalClosure, "abnormal closure", "CloseAbnormalClosure"}, + } + + for _, tc := range testCases { + t.Logf("Testing %s (code %d)", tc.name, tc.code) + + // Wait for stable connection before testing + assert.True(t, mockServer.waitForConnection(5*time.Second), "Should have stable connection before %s", tc.name) + time.Sleep(300 * time.Millisecond) // Give time for ping cycle to establish + + // Collect pings before sending close message + beforePings := 0 + timeout := time.After(200 * time.Millisecond) + collectBeforePings: + for { + select { + case <-mockServer.pings: + beforePings++ + case <-timeout: + break collectBeforePings + } + } + assert.Greater(t, beforePings, 0, "Should receive pings before %s", tc.name) + + // Send close message + err = mockServer.sendCloseMessage(tc.code, tc.reason) + assert.NoError(t, err, "Should send %s close message successfully", tc.name) + + // Wait for conductor to process and reconnect + time.Sleep(300 * time.Millisecond) + + // Wait for reconnection + assert.True(t, mockServer.waitForConnection(10*time.Second), "Should reconnect after %s", tc.name) + + // Verify pings after reconnection + afterPings := 0 + timeout2 := time.After(200 * time.Millisecond) + collectAfterPings: + for { + select { + case <-mockServer.pings: + afterPings++ + case <-timeout2: + break collectAfterPings + } + } + assert.Greater(t, afterPings, 0, "Should receive pings after reconnection from %s", tc.name) + } + + // Cancel the context to trigger shutdown + cancel() + + // Give conductor time to clean up + time.Sleep(500 * time.Millisecond) + }) +} diff --git a/dbos/dbos.go b/dbos/dbos.go index 8b08bfad..12c23f46 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -20,20 +20,24 @@ import ( "sync/atomic" "time" + "github.com/google/uuid" "github.com/robfig/cron/v3" ) const ( _DEFAULT_ADMIN_SERVER_PORT = 3001 + _DBOS_DOMAIN = "cloud.dbos.dev" ) // Config holds configuration parameters for initializing a DBOS context. // DatabaseURL and AppName are required. type Config struct { - DatabaseURL string // PostgreSQL connection string (required) - AppName string // Application name for identification (required) - Logger *slog.Logger // Custom logger instance (defaults to a new slog logger) - AdminServer bool // Enable Transact admin HTTP server + DatabaseURL string // PostgreSQL connection string (required) + AppName string // Application name for identification (required) + Logger *slog.Logger // Custom logger instance (defaults to a new slog logger) + AdminServer bool // Enable Transact admin HTTP server (disabled by default) + ConductorURL string // DBOS conductor service URL (optional) + ConductorAPIKey string // DBOS conductor API key (optional) } // processConfig enforces mandatory fields and applies defaults. 
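Supplying ConductorAPIKey is what opts an application into the conductor: NewDBOSContext then assigns a random executor UUID and, when ConductorURL is empty, derives the endpoint from DBOS_DOMAIN (default cloud.dbos.dev). A minimal sketch of the wiring from application code follows; the module import path, environment variable names, and application name are placeholders, not part of this change:

package main

import (
	"log"
	"os"
	"time"

	"github.com/dbos-inc/dbos-transact-golang/dbos" // import path assumed for illustration
)

func main() {
	ctx, err := dbos.NewDBOSContext(dbos.Config{
		DatabaseURL:     os.Getenv("DBOS_DATABASE_URL"),      // placeholder connection string
		AppName:         "my-app",
		AdminServer:     true,                                 // optional admin HTTP server
		ConductorAPIKey: os.Getenv("DBOS_CONDUCTOR_API_KEY"), // a non-empty key enables the conductor client
		// ConductorURL left empty: defaults to wss://cloud.dbos.dev/conductor/v1alpha1,
		// or wss://$DBOS_DOMAIN/conductor/v1alpha1 when DBOS_DOMAIN is set.
	})
	if err != nil {
		log.Fatal(err)
	}
	if err := ctx.Launch(); err != nil { // starts the queue runner and conductor client, then runs local recovery
		log.Fatal(err)
	}
	defer ctx.Shutdown(30 * time.Second)
}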
@@ -47,10 +51,12 @@ func processConfig(inputConfig *Config) (*Config, error) { } dbosConfig := &Config{ - DatabaseURL: inputConfig.DatabaseURL, - AppName: inputConfig.AppName, - Logger: inputConfig.Logger, - AdminServer: inputConfig.AdminServer, + DatabaseURL: inputConfig.DatabaseURL, + AppName: inputConfig.AppName, + Logger: inputConfig.Logger, + AdminServer: inputConfig.AdminServer, + ConductorURL: inputConfig.ConductorURL, + ConductorAPIKey: inputConfig.ConductorAPIKey, } // Load defaults @@ -112,6 +118,9 @@ type dbosContext struct { // Queue runner queueRunner *queueRunner + // Conductor client + conductor *Conductor + // Application metadata applicationVersion string applicationID string @@ -277,6 +286,7 @@ func NewDBOSContext(inputConfig Config) (DBOSContext, error) { // Set global logger initExecutor.logger = config.Logger + initExecutor.logger.Info("Initializing DBOS context", "app_name", config.AppName) // Register types we serialize with gob var t time.Time @@ -286,23 +296,15 @@ func NewDBOSContext(inputConfig Config) (DBOSContext, error) { initExecutor.applicationVersion = os.Getenv("DBOS__APPVERSION") if initExecutor.applicationVersion == "" { initExecutor.applicationVersion = computeApplicationVersion() - initExecutor.logger.Info("DBOS__APPVERSION not set, using computed hash") } initExecutor.executorID = os.Getenv("DBOS__VMID") if initExecutor.executorID == "" { initExecutor.executorID = "local" - initExecutor.logger.Info("DBOS__VMID not set, using default", "executor_id", initExecutor.executorID) } initExecutor.applicationID = os.Getenv("DBOS__APPID") - initExecutor.logger = initExecutor.logger.With( - //"app_version", initExecutor.applicationVersion, // This is really verbose... - "executor_id", initExecutor.executorID, - //"app_id", initExecutor.applicationID, // This should stay internal - ) - // Create the system database systemDB, err := newSystemDatabase(initExecutor, config.DatabaseURL, initExecutor.logger) if err != nil { @@ -312,9 +314,32 @@ func NewDBOSContext(inputConfig Config) (DBOSContext, error) { initExecutor.logger.Info("System database initialized") // Initialize the queue runner and register DBOS internal queue - initExecutor.queueRunner = newQueueRunner() + initExecutor.queueRunner = newQueueRunner(initExecutor.logger) NewWorkflowQueue(initExecutor, _DBOS_INTERNAL_QUEUE_NAME) + // Initialize conductor if API key is provided + if config.ConductorAPIKey != "" { + initExecutor.executorID = uuid.NewString() + if config.ConductorURL == "" { + dbosDomain := os.Getenv("DBOS_DOMAIN") + if dbosDomain == "" { + dbosDomain = _DBOS_DOMAIN + } + config.ConductorURL = fmt.Sprintf("wss://%s/conductor/v1alpha1", dbosDomain) + } + conductorConfig := ConductorConfig{ + url: config.ConductorURL, + apiKey: config.ConductorAPIKey, + appName: config.AppName, + } + conductor, err := NewConductor(initExecutor, conductorConfig) + if err != nil { + return nil, newInitializationError(fmt.Sprintf("failed to initialize conductor: %v", err)) + } + initExecutor.conductor = conductor + initExecutor.logger.Info("Conductor initialized") + } + return initExecutor, nil } @@ -356,6 +381,12 @@ func (c *dbosContext) Launch() error { c.logger.Info("Workflow scheduler started") } + // Start the conductor if it has been initialized + if c.conductor != nil { + c.conductor.Launch() + c.logger.Info("Conductor started") + } + // Run a round of recovery on the local executor recoveryHandles, err := recoverPendingWorkflows(c, []string{c.executorID}) if err != nil { @@ -379,8 +410,9 @@ func (c 
*dbosContext) Launch() error { // 2. Waits for the queue runner to complete processing // 3. Stops the workflow scheduler and waits for scheduled jobs to finish // 4. Shuts down the system database connection pool and notification listener -// 5. Shuts down the admin server -// 6. Marks the context as not launched +// 5. Shuts down conductor +// 6. Shuts down the admin server +// 7. Marks the context as not launched // // Each step respects the provided timeout. If any component doesn't shut down within the timeout, // a warning is logged and the shutdown continues to the next component. @@ -413,7 +445,6 @@ func (c *dbosContext) Shutdown(timeout time.Duration) { select { case <-c.queueRunner.completionChan: c.logger.Info("Queue runner completed") - c.queueRunner = nil case <-time.After(timeout): c.logger.Warn("Timeout waiting for queue runner to complete", "timeout", timeout) } @@ -433,6 +464,12 @@ func (c *dbosContext) Shutdown(timeout time.Duration) { } } + // Shutdown the conductor + if c.conductor != nil { + c.logger.Info("Shutting down conductor") + c.conductor.Shutdown(timeout) + } + // Shutdown the admin server if c.adminServer != nil && c.launched.Load() { c.logger.Info("Shutting down admin server") @@ -442,14 +479,12 @@ func (c *dbosContext) Shutdown(timeout time.Duration) { } else { c.logger.Info("Admin server shutdown complete") } - c.adminServer = nil } // Close the system database if c.systemDB != nil { c.logger.Info("Shutting down system database") c.systemDB.shutdown(c, timeout) - c.systemDB = nil } c.launched.Store(false) diff --git a/dbos/dbos_test.go b/dbos/dbos_test.go index 98ef59b4..16d7e685 100644 --- a/dbos/dbos_test.go +++ b/dbos/dbos_test.go @@ -22,7 +22,7 @@ func TestConfigValidationErrorTypes(t *testing.T) { require.NoError(t, err) defer func() { if ctx != nil { - ctx.Shutdown(1*time.Minute) + ctx.Shutdown(1 * time.Minute) } }() // Clean up executor diff --git a/dbos/logger_test.go b/dbos/logger_test.go index dd305dbd..5ee5787e 100644 --- a/dbos/logger_test.go +++ b/dbos/logger_test.go @@ -23,7 +23,7 @@ func TestLogger(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { if dbosCtx != nil { - dbosCtx.Shutdown(10*time.Second) + dbosCtx.Shutdown(10 * time.Second) } }) @@ -56,7 +56,7 @@ func TestLogger(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { if dbosCtx != nil { - dbosCtx.Shutdown(10*time.Second) + dbosCtx.Shutdown(10 * time.Second) } }) diff --git a/dbos/queue.go b/dbos/queue.go index 290ba7f3..20825a7d 100644 --- a/dbos/queue.go +++ b/dbos/queue.go @@ -5,6 +5,7 @@ import ( "context" "encoding/base64" "encoding/gob" + "log/slog" "math" "math/rand" "time" @@ -132,6 +133,8 @@ func NewWorkflowQueue(dbosCtx DBOSContext, name string, options ...QueueOption) } type queueRunner struct { + logger *slog.Logger + // Queue runner iteration parameters baseInterval float64 minInterval float64 @@ -148,7 +151,7 @@ type queueRunner struct { completionChan chan struct{} } -func newQueueRunner() *queueRunner { +func newQueueRunner(logger *slog.Logger) *queueRunner { return &queueRunner{ baseInterval: 1.0, minInterval: 1.0, @@ -159,6 +162,7 @@ func newQueueRunner() *queueRunner { jitterMax: 1.05, workflowQueueRegistry: make(map[string]WorkflowQueue), completionChan: make(chan struct{}, 1), + logger: logger.With("service", "queue_runner"), } } @@ -193,31 +197,31 @@ func (qr *queueRunner) run(ctx *dbosContext) { hasBackoffError = true } } else { - ctx.logger.Error("Error dequeuing workflows from queue", "queue_name", queueName, "error", err) + 
qr.logger.Error("Error dequeuing workflows from queue", "queue_name", queueName, "error", err) } continue } if len(dequeuedWorkflows) > 0 { - ctx.logger.Debug("Dequeued workflows from queue", "queue_name", queueName, "workflows", dequeuedWorkflows) + qr.logger.Debug("Dequeued workflows from queue", "queue_name", queueName, "workflows", dequeuedWorkflows) } for _, workflow := range dequeuedWorkflows { // Find the workflow in the registry wfName, ok := ctx.workflowCustomNametoFQN.Load(workflow.name) if !ok { - ctx.logger.Error("Workflow not found in registry", "workflow_name", workflow.name) + qr.logger.Error("Workflow not found in registry", "workflow_name", workflow.name) continue } registeredWorkflowAny, exists := ctx.workflowRegistry.Load(wfName.(string)) if !exists { - ctx.logger.Error("workflow function not found in registry", "workflow_name", workflow.name) + qr.logger.Error("workflow function not found in registry", "workflow_name", workflow.name) continue } registeredWorkflow, ok := registeredWorkflowAny.(workflowRegistryEntry) if !ok { - ctx.logger.Error("invalid workflow registry entry type", "workflow_name", workflow.name) + qr.logger.Error("invalid workflow registry entry type", "workflow_name", workflow.name) continue } @@ -226,20 +230,20 @@ func (qr *queueRunner) run(ctx *dbosContext) { if len(workflow.input) > 0 { inputBytes, err := base64.StdEncoding.DecodeString(workflow.input) if err != nil { - ctx.logger.Error("failed to decode input for workflow", "workflow_id", workflow.id, "error", err) + qr.logger.Error("failed to decode input for workflow", "workflow_id", workflow.id, "error", err) continue } buf := bytes.NewBuffer(inputBytes) dec := gob.NewDecoder(buf) if err := dec.Decode(&input); err != nil { - ctx.logger.Error("failed to decode input for workflow", "workflow_id", workflow.id, "error", err) + qr.logger.Error("failed to decode input for workflow", "workflow_id", workflow.id, "error", err) continue } } _, err := registeredWorkflow.wrappedFunction(ctx, input, WithWorkflowID(workflow.id)) if err != nil { - ctx.logger.Error("Error running queued workflow", "error", err) + qr.logger.Error("Error running queued workflow", "error", err) } } } @@ -260,7 +264,7 @@ func (qr *queueRunner) run(ctx *dbosContext) { // Sleep with jittered interval, but allow early exit on context cancellation select { case <-ctx.Done(): - ctx.logger.Info("Queue runner stopping due to context cancellation", "cause", context.Cause(ctx)) + qr.logger.Info("Queue runner stopping due to context cancellation", "cause", context.Cause(ctx)) qr.completionChan <- struct{}{} return case <-time.After(sleepDuration): diff --git a/dbos/queues_test.go b/dbos/queues_test.go index 8b28fe6d..f5315431 100644 --- a/dbos/queues_test.go +++ b/dbos/queues_test.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "os" + "reflect" + "runtime" "sync" "sync/atomic" "testing" @@ -63,6 +65,29 @@ func TestWorkflowQueues(t *testing.T) { // Register workflows with dbosContext RegisterWorkflow(dbosCtx, queueWorkflow) + // Custom name workflows + queueWorkflowCustomName := func(ctx DBOSContext, input string) (string, error) { + return input, nil + } + RegisterWorkflow(dbosCtx, queueWorkflowCustomName, WithWorkflowName("custom-name")) + + queueWorkflowCustomNameEnqueingAnotherCustomNameWorkflow := func(ctx DBOSContext, input string) (string, error) { + // Start a child workflow + childHandle, err := RunAsWorkflow(ctx, queueWorkflowCustomName, input+"-enqueued", WithQueue(queue.Name)) + if err != nil { + return "", fmt.Errorf("failed to start 
child workflow: %v", err) + } + + // Get result from child workflow + childResult, err := childHandle.GetResult() + if err != nil { + return "", fmt.Errorf("failed to get child result: %v", err) + } + + return childResult, nil + } + RegisterWorkflow(dbosCtx, queueWorkflowCustomNameEnqueingAnotherCustomNameWorkflow, WithWorkflowName("custom-name-enqueuing")) + // Queue deduplication test workflows var dedupWorkflowEvent *Event childWorkflow := func(ctx DBOSContext, var1 string) (string, error) { @@ -133,6 +158,24 @@ func TestWorkflowQueues(t *testing.T) { } RegisterWorkflow(dbosCtx, enqueueWorkflowDLQ, WithMaxRetries(dlqMaxRetries)) + // Create a workflow that enqueues another workflow to test step tracking + workflowEnqueuesAnother := func(ctx DBOSContext, input string) (string, error) { + // Enqueue a child workflow + childHandle, err := RunAsWorkflow(ctx, queueWorkflow, input+"-child", WithQueue(queue.Name)) + if err != nil { + return "", fmt.Errorf("failed to enqueue child workflow: %v", err) + } + + // Get result from the child workflow + childResult, err := childHandle.GetResult() + if err != nil { + return "", fmt.Errorf("failed to get child result: %v", err) + } + + return childResult, nil + } + RegisterWorkflow(dbosCtx, workflowEnqueuesAnother) + err := dbosCtx.Launch() require.NoError(t, err) @@ -147,6 +190,26 @@ func TestWorkflowQueues(t *testing.T) { require.NoError(t, err) assert.Equal(t, "test-input", res) + // List steps: the workflow should have 1 step + steps, err := dbosCtx.(*dbosContext).systemDB.getWorkflowSteps(dbosCtx, handle.GetWorkflowID()) + require.NoError(t, err) + assert.Len(t, steps, 1) + assert.Equal(t, 0, steps[0].StepID) + + require.True(t, queueEntriesAreCleanedUp(dbosCtx), "expected queue entries to be cleaned up after global concurrency test") + }) + + t.Run("EnqueueWorkflowCustomName", func(t *testing.T) { + handle, err := RunAsWorkflow(dbosCtx, queueWorkflowCustomName, "test-input", WithQueue(queue.Name)) + require.NoError(t, err) + + _, ok := handle.(*workflowPollingHandle[string]) + require.True(t, ok, "expected handle to be of type workflowPollingHandle, got %T", handle) + + res, err := handle.GetResult() + require.NoError(t, err) + assert.Equal(t, "test-input", res) + require.True(t, queueEntriesAreCleanedUp(dbosCtx), "expected queue entries to be cleaned up after global concurrency test") }) @@ -161,10 +224,19 @@ func TestWorkflowQueues(t *testing.T) { expectedResult := "test-input-child" assert.Equal(t, expectedResult, res) + // List steps: the workflow should have 2 steps (Start the child and GetResult) + steps, err := dbosCtx.(*dbosContext).systemDB.getWorkflowSteps(dbosCtx, handle.GetWorkflowID()) + require.NoError(t, err) + assert.Len(t, steps, 2) + assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(queueWorkflow).Pointer()).Name(), steps[0].StepName) + assert.Equal(t, 0, steps[0].StepID) + assert.Equal(t, "DBOS.getResult", steps[1].StepName) + assert.Equal(t, 1, steps[1].StepID) + require.True(t, queueEntriesAreCleanedUp(dbosCtx), "expected queue entries to be cleaned up after global concurrency test") }) - t.Run("WorkflowEnqueuesAnotherWorkflow", func(t *testing.T) { + t.Run("WorkflowEnqueuesAnother", func(t *testing.T) { handle, err := RunAsWorkflow(dbosCtx, queueWorkflowThatEnqueues, "test-input", WithQueue(queue.Name)) require.NoError(t, err) @@ -175,9 +247,67 @@ func TestWorkflowQueues(t *testing.T) { expectedResult := "test-input-enqueued" assert.Equal(t, expectedResult, res) + // List steps: the workflow should have 2 steps (Start the 
child and GetResult) + steps, err := dbosCtx.(*dbosContext).systemDB.getWorkflowSteps(dbosCtx, handle.GetWorkflowID()) + require.NoError(t, err) + assert.Len(t, steps, 2) + assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(queueWorkflow).Pointer()).Name(), steps[0].StepName) + assert.Equal(t, 0, steps[0].StepID) + assert.Equal(t, "DBOS.getResult", steps[1].StepName) + assert.Equal(t, 1, steps[1].StepID) + require.True(t, queueEntriesAreCleanedUp(dbosCtx), "expected queue entries to be cleaned up after global concurrency test") }) + t.Run("CustomNameWorkflowEnqueuesAnotherCustomNameWorkflow", func(t *testing.T) { + handle, err := RunAsWorkflow(dbosCtx, queueWorkflowCustomNameEnqueingAnotherCustomNameWorkflow, "test-input", WithQueue(queue.Name)) + require.NoError(t, err) + + res, err := handle.GetResult() + require.NoError(t, err) + + // Expected result: enqueued workflow returns "test-input-enqueued" + expectedResult := "test-input-enqueued" + assert.Equal(t, expectedResult, res) + + // List steps: the workflow should have 2 steps (Start the child and GetResult) + steps, err := dbosCtx.(*dbosContext).systemDB.getWorkflowSteps(dbosCtx, handle.GetWorkflowID()) + require.NoError(t, err) + assert.Len(t, steps, 2) + assert.Equal(t, "custom-name", steps[0].StepName) + assert.Equal(t, 0, steps[0].StepID) + assert.Equal(t, "DBOS.getResult", steps[1].StepName) + assert.Equal(t, 1, steps[1].StepID) + + require.True(t, queueEntriesAreCleanedUp(dbosCtx), "expected queue entries to be cleaned up after global concurrency test") + }) + + t.Run("EnqueuedWorkflowEnqueuesAnother", func(t *testing.T) { + // Run the pre-registered workflow that enqueues another workflow + // Enqueue the parent workflow to a queue + handle, err := RunAsWorkflow(dbosCtx, workflowEnqueuesAnother, "test-input", WithQueue(queue.Name)) + require.NoError(t, err) + + res, err := handle.GetResult() + require.NoError(t, err) + + // Expected result: child workflow returns "test-input-child" + expectedResult := "test-input-child" + assert.Equal(t, expectedResult, res) + + // Check that the parent workflow (the one we ran directly) has 2 steps: + // one for enqueueing the child and one for calling GetResult + steps, err := dbosCtx.(*dbosContext).systemDB.getWorkflowSteps(dbosCtx, handle.GetWorkflowID()) + require.NoError(t, err) + assert.Len(t, steps, 2) + assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(queueWorkflow).Pointer()).Name(), steps[0].StepName) + assert.Equal(t, 0, steps[0].StepID) + assert.Equal(t, "DBOS.getResult", steps[1].StepName) + assert.Equal(t, 1, steps[1].StepID) + + require.True(t, queueEntriesAreCleanedUp(dbosCtx), "expected queue entries to be cleaned up after workflow enqueues another workflow test") + }) + t.Run("DynamicRegistration", func(t *testing.T) { // Attempting to register a queue after launch should panic defer func() { diff --git a/dbos/system_database.go b/dbos/system_database.go index 6b0562aa..24ef13e2 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -243,7 +243,7 @@ func newSystemDatabase(ctx context.Context, databaseURL string, logger *slog.Log notificationListenerConnection: notificationListenerConnection, notificationsMap: notificationsMap, notificationLoopDone: make(chan struct{}), - logger: logger, + logger: logger.With("service", "system_database"), }, nil } @@ -820,16 +820,13 @@ func (s *sysDB) garbageCollectWorkflows(ctx context.Context, input garbageCollec var rowsBasedCutoff int64 err := s.pool.QueryRow(ctx, query, *input.rowsThreshold-1).Scan(&rowsBasedCutoff) - if 
err != nil { - if err == pgx.ErrNoRows { - // Not enough rows to apply threshold, no garbage collection needed - return nil - } + if err != nil && err != pgx.ErrNoRows { return fmt.Errorf("failed to query cutoff timestamp by rows threshold: %w", err) } - - // Use the more restrictive cutoff (higher timestamp = more recent = less deletion) - if cutoffTimestamp == nil || rowsBasedCutoff > *cutoffTimestamp { + // If we don't have a provided cutoffTimestamp and found one in the database + // Or if the found cutoffTimestamp is more restrictive (higher timestamp = more recent = less deletion) + // Use the cutoff timestamp found in the database + if rowsBasedCutoff > 0 && cutoffTimestamp == nil || (cutoffTimestamp != nil && rowsBasedCutoff > *cutoffTimestamp) { cutoffTimestamp = &rowsBasedCutoff } } @@ -2174,14 +2171,14 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, input dequeueWorkflowsInpu input.executorID, time.Now().UnixMilli(), id).Scan(&retWorkflow.name, &inputString) + if err != nil { + return nil, fmt.Errorf("failed to update workflow %s during dequeue: %w", id, err) + } if inputString != nil && len(*inputString) > 0 { retWorkflow.input = *inputString } - if err != nil { - return nil, fmt.Errorf("failed to update workflow %s during dequeue: %w", id, err) - } retWorkflows = append(retWorkflows, retWorkflow) } diff --git a/dbos/workflow.go b/dbos/workflow.go index c26469a0..9fc218c7 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -710,15 +710,6 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o return nil, err } - // Return a polling handle if: we are enqueueing, the workflow is already in a terminal state (success or error), - if len(params.queueName) > 0 || insertStatusResult.status == WorkflowStatusSuccess || insertStatusResult.status == WorkflowStatusError { - // Commit the transaction to update the number of attempts and/or enact the enqueue - if err := tx.Commit(uncancellableCtx); err != nil { - return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) - } - return newWorkflowPollingHandle[any](uncancellableCtx, workflowStatus.ID), nil - } - // Record child workflow relationship if this is a child workflow if isChildWorkflow { // Get the step ID that was used for generating the child workflow ID @@ -736,6 +727,15 @@ func (c *dbosContext) RunAsWorkflow(_ DBOSContext, fn WorkflowFunc, input any, o } } + // Return a polling handle if: we are enqueueing, the workflow is already in a terminal state (success or error), + if len(params.queueName) > 0 || insertStatusResult.status == WorkflowStatusSuccess || insertStatusResult.status == WorkflowStatusError { + // Commit the transaction to update the number of attempts and/or enact the enqueue + if err := tx.Commit(uncancellableCtx); err != nil { + return nil, newWorkflowExecutionError(workflowID, fmt.Sprintf("failed to commit transaction: %v", err)) + } + return newWorkflowPollingHandle[any](uncancellableCtx, workflowStatus.ID), nil + } + // Channel to receive the outcome from the goroutine // The buffer size of 1 allows the goroutine to send the outcome without blocking // In addition it allows the channel to be garbage collected @@ -1728,6 +1728,7 @@ type ListWorkflowsOptions struct { loadInput bool loadOutput bool queueName string + executorIDs []string } // ListWorkflowsOption is a functional option for configuring workflow listing parameters. 
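Aside on the garbageCollectWorkflows hunk above: the new condition relies on Go's operator precedence (&& binds tighter than ||), which makes the cutoff selection easy to misread. An equivalent standalone form of that check, assuming cutoff values are positive epoch timestamps; pickCutoff is a purely illustrative helper, not code from the PR.

// pickCutoff returns the more restrictive (more recent) of the caller-provided
// cutoff and the cutoff derived from the rows threshold. A zero rowsBasedCutoff
// means the rows-threshold query matched nothing (pgx.ErrNoRows) and is ignored.
func pickCutoff(cutoffTimestamp *int64, rowsBasedCutoff int64) *int64 {
	if rowsBasedCutoff > 0 && (cutoffTimestamp == nil || rowsBasedCutoff > *cutoffTimestamp) {
		return &rowsBasedCutoff
	}
	return cutoffTimestamp
}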
@@ -1903,6 +1904,18 @@ func WithQueueName(queueName string) ListWorkflowsOption { } } +// WithExecutorIDs filters workflows by the specified executor IDs. +// +// Example: +// +// workflows, err := dbos.ListWorkflows(ctx, +// dbos.WithExecutorIDs([]string{"executor-123", "executor-456"})) +func WithExecutorIDs(executorIDs []string) ListWorkflowsOption { + return func(p *ListWorkflowsOptions) { + p.executorIDs = executorIDs + } +} + // ListWorkflows retrieves a list of workflows based on the provided filters. // // The function supports filtering by workflow IDs, status, time ranges, names, application versions, @@ -1976,6 +1989,7 @@ func (c *dbosContext) ListWorkflows(_ DBOSContext, opts ...ListWorkflowsOption) loadInput: params.loadInput, loadOutput: params.loadOutput, queueName: params.queueName, + executorIDs: params.executorIDs, } // Call the context method to list workflows diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 4ed3bab7..64b3fb60 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -1695,7 +1695,7 @@ func TestSendRecv(t *testing.T) { assert.Equal(t, "--", result, "expected receive workflow result to be '--' (timeout)") }) - t.Run("TestSendRecv", func(t *testing.T) { + t.Run("TestConcurrentRecvs", func(t *testing.T) { // Test concurrent receivers - only 1 should timeout, others should get errors receiveTopic := "concurrent-recv-topic" @@ -1760,6 +1760,15 @@ func TestSendRecv(t *testing.T) { for err := range errors { t.Logf("Receiver error (expected): %v", err) + + // Check that the error is of the expected type + dbosErr, ok := err.(*DBOSError) + require.True(t, ok, "expected error to be of type *DBOSError, got %T", err) + require.Equal(t, ConflictingIDError, dbosErr.Code, "expected error code to be ConflictingIDError, got %v", dbosErr.Code) + require.Equal(t, "concurrent-recv-wfid", dbosErr.WorkflowID, "expected workflow ID to be 'concurrent-recv-wfid', got %s", dbosErr.WorkflowID) + require.True(t, dbosErr.IsBase, "expected error to have IsBase=true") + require.Contains(t, dbosErr.Message, "Conflicting workflow ID concurrent-recv-wfid", "expected error message to contain conflicting workflow ID") + errorCount++ } diff --git a/go.mod b/go.mod index 54633efa..6483f33c 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect diff --git a/go.sum b/go.sum index c88bce93..30a855eb 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/golang-migrate/migrate/v4 v4.18.3 h1:EYGkoOsvgHHfm5U/naS1RP/6PL/Xv3S4 github.com/golang-migrate/migrate/v4 v4.18.3/go.mod h1:99BKpIi6ruaaXRM1A77eqZ+FWPQ3cfRa+ZVy5bmWMaY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
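Taken together, the new ConductorURL and ConductorAPIKey fields make the conductor opt-in: it is only initialized when an API key is provided, and the URL falls back to wss://cloud.dbos.dev/conductor/v1alpha1 (or DBOS_DOMAIN, if set). A minimal usage sketch follows; the import path and environment variable names are placeholders, not something defined by this PR.

package main

import (
	"log"
	"os"
	"time"

	"github.com/dbos-inc/dbos-transact-golang/dbos" // placeholder import path
)

func main() {
	ctx, err := dbos.NewDBOSContext(dbos.Config{
		DatabaseURL: os.Getenv("DBOS_DATABASE_URL"), // illustrative env var name
		AppName:     "my-app",
		AdminServer: true,
		// Providing an API key enables the conductor; ConductorURL is optional
		// and defaults to the DBOS cloud endpoint described above.
		ConductorAPIKey: os.Getenv("DBOS_CONDUCTOR_API_KEY"), // illustrative env var name
	})
	if err != nil {
		log.Fatal(err)
	}

	// Register workflows here, before Launch, as the tests above do.

	if err := ctx.Launch(); err != nil {
		log.Fatal(err)
	}
	defer ctx.Shutdown(30 * time.Second)
}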