diff --git a/re/api/endpoints.go b/re/api/endpoints.go index fa8c3b356..7c87b94b6 100644 --- a/re/api/endpoints.go +++ b/re/api/endpoints.go @@ -203,3 +203,41 @@ func disableRuleEndpoint(s re.Service) endpoint.Endpoint { return updateRuleStatusRes{Rule: rule}, err } } + +func abortRuleExecutionEndpoint(s re.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(abortRuleExecutionReq) + if err := req.validate(); err != nil { + return abortRuleExecutionRes{}, err + } + err := s.AbortRuleExecution(ctx, session, req.id) + if err != nil { + return abortRuleExecutionRes{}, err + } + return abortRuleExecutionRes{}, nil + } +} + +func getRuleExecutionStatusEndpoint(s re.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(getRuleExecutionStatusReq) + if err := req.validate(); err != nil { + return getRuleExecutionStatusRes{}, err + } + status, err := s.GetRuleExecutionStatus(ctx, session, req.id) + if err != nil { + return getRuleExecutionStatusRes{}, err + } + return getRuleExecutionStatusRes{RuleExecutionStatus: status}, nil + } +} diff --git a/re/api/requests.go b/re/api/requests.go index cdd7a9c6a..9da67ca9e 100644 --- a/re/api/requests.go +++ b/re/api/requests.go @@ -134,3 +134,27 @@ func (req deleteRuleReq) validate() error { return nil } + +type abortRuleExecutionReq struct { + id string +} + +func (req abortRuleExecutionReq) validate() error { + if req.id == "" { + return apiutil.ErrMissingID + } + + return nil +} + +type getRuleExecutionStatusReq struct { + id string +} + +func (req getRuleExecutionStatusReq) validate() error { + if req.id == "" { + return apiutil.ErrMissingID + } + + return nil +} diff --git 
a/re/api/responses.go b/re/api/responses.go index d13753dc1..7f6d53009 100644 --- a/re/api/responses.go +++ b/re/api/responses.go @@ -18,6 +18,8 @@ var ( _ supermq.Response = (*rulesPageRes)(nil) _ supermq.Response = (*updateRuleRes)(nil) _ supermq.Response = (*deleteRuleRes)(nil) + _ supermq.Response = (*abortRuleExecutionRes)(nil) + _ supermq.Response = (*getRuleExecutionStatusRes)(nil) ) type pageRes struct { @@ -136,3 +138,33 @@ func (res deleteRuleRes) Headers() map[string]string { func (res deleteRuleRes) Empty() bool { return true } + +type abortRuleExecutionRes struct{} + +func (res abortRuleExecutionRes) Code() int { + return http.StatusAccepted +} + +func (res abortRuleExecutionRes) Headers() map[string]string { + return map[string]string{} +} + +func (res abortRuleExecutionRes) Empty() bool { + return true +} + +type getRuleExecutionStatusRes struct { + re.RuleExecutionStatus `json:",inline"` +} + +func (res getRuleExecutionStatusRes) Code() int { + return http.StatusOK +} + +func (res getRuleExecutionStatusRes) Headers() map[string]string { + return map[string]string{} +} + +func (res getRuleExecutionStatusRes) Empty() bool { + return false +} diff --git a/re/api/transport.go b/re/api/transport.go index 54f0a41f9..afafe9547 100644 --- a/re/api/transport.go +++ b/re/api/transport.go @@ -23,8 +23,9 @@ import ( ) const ( - ruleIdKey = "ruleID" - inputChannelKey = "input_channel" + ruleIdKey = "ruleID" + inputChannelKey = "input_channel" + lastRunStatusKey = "last_run_status" ) // MakeHandler creates an HTTP handler for the service endpoints. 
@@ -99,6 +100,20 @@ func MakeHandler(svc re.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, l api.EncodeResponse, opts..., ), "disable_rule").ServeHTTP) + + r.Post("/abort", otelhttp.NewHandler(kithttp.NewServer( + abortRuleExecutionEndpoint(svc), + decodeAbortRuleExecutionRequest, + api.EncodeResponse, + opts..., + ), "abort_rule_execution").ServeHTTP) + + r.Get("/execution-status", otelhttp.NewHandler(kithttp.NewServer( + getRuleExecutionStatusEndpoint(svc), + decodeGetRuleExecutionStatusRequest, + api.EncodeResponse, + opts..., + ), "get_rule_execution_status").ServeHTTP) }) }) }) @@ -198,6 +213,10 @@ func decodeListRulesRequest(_ context.Context, r *http.Request) (any, error) { if err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } + lrs, err := apiutil.ReadStringQuery(r, lastRunStatusKey, re.NeverRun) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } dir, err := apiutil.ReadStringQuery(r, api.DirKey, "desc") if err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) @@ -210,6 +229,10 @@ func decodeListRulesRequest(_ context.Context, r *http.Request) (any, error) { if err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } + lrst, err := re.ToExecutionStatus(lrs) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } tag, err := apiutil.ReadStringQuery(r, api.TagKey, "") if err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) @@ -217,14 +240,15 @@ func decodeListRulesRequest(_ context.Context, r *http.Request) (any, error) { return listRulesReq{ PageMeta: re.PageMeta{ - Offset: offset, - Limit: limit, - Name: name, - InputChannel: ic, - Status: st, - Dir: dir, - Order: order, - Tag: tag, + Offset: offset, + Limit: limit, + Name: name, + InputChannel: ic, + Status: st, + LastRunStatus: lrst, + Dir: dir, + Order: order, + Tag: tag, }, }, nil } @@ -234,3 +258,15 @@ func decodeDeleteRuleRequest(_ context.Context, r *http.Request) (any, error) { return 
deleteRuleReq{id: id}, nil } + +func decodeAbortRuleExecutionRequest(_ context.Context, r *http.Request) (any, error) { + id := chi.URLParam(r, ruleIdKey) + + return abortRuleExecutionReq{id: id}, nil +} + +func decodeGetRuleExecutionStatusRequest(_ context.Context, r *http.Request) (any, error) { + id := chi.URLParam(r, ruleIdKey) + + return getRuleExecutionStatusReq{id: id}, nil +} diff --git a/re/events/events.go b/re/events/events.go index ffc60228d..9e3fd8dff 100644 --- a/re/events/events.go +++ b/re/events/events.go @@ -21,6 +21,7 @@ const ( ruleUpdateSchedule = rulePrefix + "update_schedule" ruleEnable = rulePrefix + "enable" ruleDisable = rulePrefix + "disable" + ruleAbort = rulePrefix + "abort" ruleRemove = rulePrefix + "remove" ) @@ -33,6 +34,7 @@ var ( _ events.Event = (*updateRuleScheduleEvent)(nil) _ events.Event = (*enableRuleEvent)(nil) _ events.Event = (*disableRuleEvent)(nil) + _ events.Event = (*abortRuleExecutionEvent)(nil) _ events.Event = (*removeRuleEvent)(nil) ) @@ -187,3 +189,15 @@ func (rre removeRuleEvent) Encode() (map[string]any, error) { val["operation"] = ruleRemove return val, nil } + +type abortRuleExecutionEvent struct { + id string + baseRuleEvent +} + +func (aree abortRuleExecutionEvent) Encode() (map[string]any, error) { + val := aree.baseRuleEvent.Encode() + val["id"] = aree.id + val["operation"] = ruleAbort + return val, nil +} diff --git a/re/events/streams.go b/re/events/streams.go index dfeece402..c3f15148a 100644 --- a/re/events/streams.go +++ b/re/events/streams.go @@ -24,6 +24,7 @@ const ( UpdateScheduleStream = supermqPrefix + ruleUpdateSchedule EnableStream = supermqPrefix + ruleEnable DisableStream = supermqPrefix + ruleDisable + AbortStream = supermqPrefix + ruleAbort RemoveStream = supermqPrefix + ruleRemove ) @@ -183,6 +184,25 @@ func (es *eventStore) DisableRule(ctx context.Context, session authn.Session, id return rule, nil } +func (es *eventStore) AbortRuleExecution(ctx context.Context, session authn.Session, id 
string) error { + err := es.svc.AbortRuleExecution(ctx, session, id) + if err != nil { + return err + } + event := abortRuleExecutionEvent{ + id: id, + baseRuleEvent: newBaseRuleEvent(session, middleware.GetReqID(ctx)), + } + if err := es.Publish(ctx, AbortStream, event); err != nil { + return err + } + return nil +} + +func (es *eventStore) GetRuleExecutionStatus(ctx context.Context, session authn.Session, id string) (re.RuleExecutionStatus, error) { + return es.svc.GetRuleExecutionStatus(ctx, session, id) +} + func (es *eventStore) StartScheduler(ctx context.Context) error { return es.svc.StartScheduler(ctx) } diff --git a/re/execution_status.go b/re/execution_status.go new file mode 100644 index 000000000..aefc32be8 --- /dev/null +++ b/re/execution_status.go @@ -0,0 +1,101 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re + +import ( + "encoding/json" + "strings" + + svcerr "github.com/absmach/supermq/pkg/errors/service" +) + +// ExecutionStatus represents the last run status of a rule execution. +type ExecutionStatus uint8 + +// Possible execution status values. +const ( + // NeverRunStatus represents a rule that has never been executed. + NeverRunStatus ExecutionStatus = iota + // SuccessStatus represents a successful rule execution. + SuccessStatus + // FailureStatus represents a failed rule execution. + FailureStatus + // PartialSuccessStatus represents a rule execution with partial success. + PartialSuccessStatus + // QueuedStatus represents a rule that is queued for execution. + QueuedStatus + // InProgressStatus represents a rule that is currently being executed. + InProgressStatus + // AbortedStatus represents a rule execution that was aborted. + AbortedStatus + // UnknownExecutionStatus represents an unknown execution status. + UnknownExecutionStatus +) + +// String representation of the possible execution status values. 
+const ( + NeverRun = "never_run" + Success = "success" + Failure = "failure" + PartialSuccess = "partial_success" + Queued = "queued" + InProgress = "in_progress" + Aborted = "aborted" + UnknownExecution = "unknown" +) + +func (es ExecutionStatus) String() string { + switch es { + case NeverRunStatus: + return NeverRun + case SuccessStatus: + return Success + case FailureStatus: + return Failure + case PartialSuccessStatus: + return PartialSuccess + case QueuedStatus: + return Queued + case InProgressStatus: + return InProgress + case AbortedStatus: + return Aborted + default: + return UnknownExecution + } +} + +// ToExecutionStatus converts string value to a valid execution status. +func ToExecutionStatus(status string) (ExecutionStatus, error) { + switch status { + case NeverRun: + return NeverRunStatus, nil + case Success: + return SuccessStatus, nil + case Failure: + return FailureStatus, nil + case PartialSuccess: + return PartialSuccessStatus, nil + case Queued: + return QueuedStatus, nil + case InProgress: + return InProgressStatus, nil + case Aborted: + return AbortedStatus, nil + case "", UnknownExecution: + return UnknownExecutionStatus, nil + } + return UnknownExecutionStatus, svcerr.ErrInvalidStatus +} + +func (es ExecutionStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(es.String()) +} + +func (es *ExecutionStatus) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), "\"") + val, err := ToExecutionStatus(str) + *es = val + return err +} diff --git a/re/execution_status_test.go b/re/execution_status_test.go new file mode 100644 index 000000000..d4ad404b3 --- /dev/null +++ b/re/execution_status_test.go @@ -0,0 +1,275 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExecutionStatusString(t *testing.T) { + cases := []struct { + desc string + status ExecutionStatus + want string + }{ + { + desc: "Never 
Run status", + status: NeverRunStatus, + want: NeverRun, + }, + { + desc: "Success status", + status: SuccessStatus, + want: Success, + }, + { + desc: "Failure status", + status: FailureStatus, + want: Failure, + }, + { + desc: "Partial Success status", + status: PartialSuccessStatus, + want: PartialSuccess, + }, + { + desc: "Queued status", + status: QueuedStatus, + want: Queued, + }, + { + desc: "In Progress status", + status: InProgressStatus, + want: InProgress, + }, + { + desc: "Aborted status", + status: AbortedStatus, + want: Aborted, + }, + { + desc: "Unknown status", + status: UnknownExecutionStatus, + want: UnknownExecution, + }, + { + desc: "Invalid status (out of range)", + status: ExecutionStatus(99), + want: UnknownExecution, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + got := tc.status.String() + assert.Equal(t, tc.want, got) + }) + } +} + +func TestToExecutionStatus(t *testing.T) { + cases := []struct { + desc string + status string + want ExecutionStatus + wantErr bool + }{ + { + desc: "Never Run status", + status: NeverRun, + want: NeverRunStatus, + }, + { + desc: "Success status", + status: Success, + want: SuccessStatus, + }, + { + desc: "Failure status", + status: Failure, + want: FailureStatus, + }, + { + desc: "Partial Success status", + status: PartialSuccess, + want: PartialSuccessStatus, + }, + { + desc: "Queued status", + status: Queued, + want: QueuedStatus, + }, + { + desc: "In Progress status", + status: InProgress, + want: InProgressStatus, + }, + { + desc: "Aborted status", + status: Aborted, + want: AbortedStatus, + }, + { + desc: "Unknown status string", + status: UnknownExecution, + want: UnknownExecutionStatus, + }, + { + desc: "Empty string defaults to Unknown", + status: "", + want: UnknownExecutionStatus, + }, + { + desc: "Invalid status", + status: "invalid", + want: UnknownExecutionStatus, + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + got, 
err := ToExecutionStatus(tc.status) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.want, got) + }) + } +} + +func TestExecutionStatusMarshalJSON(t *testing.T) { + cases := []struct { + desc string + status ExecutionStatus + want string + }{ + { + desc: "Never Run status", + status: NeverRunStatus, + want: `"never_run"`, + }, + { + desc: "Success status", + status: SuccessStatus, + want: `"success"`, + }, + { + desc: "Failure status", + status: FailureStatus, + want: `"failure"`, + }, + { + desc: "Partial Success status", + status: PartialSuccessStatus, + want: `"partial_success"`, + }, + { + desc: "Queued status", + status: QueuedStatus, + want: `"queued"`, + }, + { + desc: "In Progress status", + status: InProgressStatus, + want: `"in_progress"`, + }, + { + desc: "Aborted status", + status: AbortedStatus, + want: `"aborted"`, + }, + { + desc: "Unknown status", + status: UnknownExecutionStatus, + want: `"unknown"`, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + got, err := tc.status.MarshalJSON() + assert.NoError(t, err) + assert.Equal(t, tc.want, string(got)) + }) + } +} + +func TestExecutionStatusUnmarshalJSON(t *testing.T) { + cases := []struct { + desc string + data string + want ExecutionStatus + wantErr bool + }{ + { + desc: "Never Run status", + data: `"never_run"`, + want: NeverRunStatus, + }, + { + desc: "Success status", + data: `"success"`, + want: SuccessStatus, + }, + { + desc: "Failure status", + data: `"failure"`, + want: FailureStatus, + }, + { + desc: "Partial Success status", + data: `"partial_success"`, + want: PartialSuccessStatus, + }, + { + desc: "Queued status", + data: `"queued"`, + want: QueuedStatus, + }, + { + desc: "In Progress status", + data: `"in_progress"`, + want: InProgressStatus, + }, + { + desc: "Aborted status", + data: `"aborted"`, + want: AbortedStatus, + }, + { + desc: "Unknown status string", + data: `"unknown"`, + want: 
UnknownExecutionStatus, + }, + { + desc: "Empty string", + data: `""`, + want: UnknownExecutionStatus, + }, + { + desc: "Invalid status", + data: `"invalid"`, + want: UnknownExecutionStatus, + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + var status ExecutionStatus + err := status.UnmarshalJSON([]byte(tc.data)) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.want, status) + }) + } +} diff --git a/re/golang.go b/re/golang.go index a7a831eeb..081500c28 100644 --- a/re/golang.go +++ b/re/golang.go @@ -31,6 +31,12 @@ type message struct { } func (re *re) processGo(ctx context.Context, details []slog.Attr, r Rule, msg *messaging.Message) pkglog.RunInfo { + select { + case <-ctx.Done(): + return pkglog.RunInfo{Level: slog.LevelError, Details: details, Message: "rule execution was cancelled"} + default: + } + i := golang.New(golang.Options{}) if err := i.Use(stdlib.Symbols); err != nil { return pkglog.RunInfo{Level: slog.LevelError, Details: details, Message: err.Error()} diff --git a/re/handlers.go b/re/handlers.go index ae4274da6..ffbd2a6ca 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -33,6 +33,11 @@ func (re *re) Handle(msg *messaging.Message) error { if n := len(msg.Payload); n > maxPayload { return errors.New(pldExceededFmt + strconv.Itoa(n)) } + + if re.workerMgr == nil { + return errors.New("worker manager not initialized") + } + // Skip filtering by message topic and fetch all non-scheduled rules instead. // It's cleaner and more efficient to match wildcards in Go, but we can // revisit this if it ever becomes a performance bottleneck. 
@@ -50,9 +55,24 @@ func (re *re) Handle(msg *messaging.Message) error { for _, r := range page.Rules { if matchSubject(msg.Subtopic, r.InputTopic) { - go func(ctx context.Context) { - re.runInfo <- re.process(ctx, r, msg) - }(ctx) + if workerStatus := re.workerMgr.GetWorkerStatus(r.ID); workerStatus != nil { + if processing, ok := workerStatus["processing"].(bool); ok && processing { + re.updateRuleExecutionStatus(ctx, r.ID, QueuedStatus, nil) + } + } + + if !re.workerMgr.SendMessage(msg, r) { + re.runInfo <- pkglog.RunInfo{ + Level: slog.LevelWarn, + Message: fmt.Sprintf("failed to send message to worker for rule %s", r.ID), + Details: []slog.Attr{ + slog.String("rule_id", r.ID), + slog.String("rule_name", r.Name), + slog.String("channel", msg.Channel), + slog.Time("time", time.Now().UTC()), + }, + } + } } } @@ -86,15 +106,55 @@ func (re *re) process(ctx context.Context, r Rule, msg *messaging.Message) pkglo slog.String("rule_name", r.Name), slog.Time("exec_time", time.Now().UTC()), } + + re.updateRuleExecutionStatus(ctx, r.ID, InProgressStatus, nil) + + var result pkglog.RunInfo switch r.Logic.Type { case GoType: - return re.processGo(ctx, details, r, msg) + result = re.processGo(ctx, details, r, msg) default: - return re.processLua(ctx, details, r, msg) + result = re.processLua(ctx, details, r, msg) } + + var execStatus ExecutionStatus + var errorMsg string + switch result.Level { + case slog.LevelInfo: + execStatus = SuccessStatus + case slog.LevelWarn: + if strings.Contains(result.Message, "logic returned false") || strings.Contains(result.Message, "no outputs") { + execStatus = SuccessStatus + } else { + execStatus = PartialSuccessStatus + errorMsg = result.Message + } + case slog.LevelError: + execStatus = FailureStatus + errorMsg = result.Message + default: + execStatus = FailureStatus + errorMsg = result.Message + } + + var execError error + if errorMsg != "" { + execError = errors.New(errorMsg) + } + + re.updateRuleExecutionStatus(ctx, r.ID, execStatus, 
execError) + + return result } func (re *re) handleOutput(ctx context.Context, o Runnable, r Rule, msg *messaging.Message, val any) error { + // Check if context is cancelled before handling output + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + switch o := o.(type) { case *outputs.Alarm: o.AlarmsPub = re.alarmsPub @@ -117,7 +177,46 @@ func (re *re) handleOutput(ctx context.Context, o Runnable, r Rule, msg *messagi } func (re *re) StartScheduler(ctx context.Context) error { + re.workerMgr = NewWorkerManager(ctx, re) + + workerMgr := re.workerMgr + + go func() { + for { + select { + case <-ctx.Done(): + return + case err := <-workerMgr.ErrorChan(): + if err != nil { + re.runInfo <- pkglog.RunInfo{ + Level: slog.LevelError, + Message: fmt.Sprintf("worker management error: %s", err), + Details: []slog.Attr{slog.Time("time", time.Now().UTC())}, + } + } + } + } + }() + + pm := PageMeta{ + Status: EnabledStatus, + } + page, err := re.repo.ListRules(ctx, pm) + if err == nil { + re.workerMgr.RefreshWorkers(ctx, page.Rules) + } + defer re.ticker.Stop() + defer func() { + if stopErr := re.workerMgr.StopAll(); stopErr != nil { + re.runInfo <- pkglog.RunInfo{ + Level: slog.LevelError, + Message: fmt.Sprintf("failed to stop worker manager: %s", stopErr), + Details: []slog.Attr{slog.Time("time", time.Now().UTC())}, + } + } + }() + for { select { case <-ctx.Done(): @@ -142,21 +241,45 @@ func (re *re) StartScheduler(ctx context.Context) error { } for _, r := range page.Rules { - go func(rule Rule) { - if _, err := re.repo.UpdateRuleDue(ctx, rule.ID, rule.Schedule.NextDue()); err != nil { - re.runInfo <- pkglog.RunInfo{Level: slog.LevelError, Message: fmt.Sprintf("failed to update rule: %s", err), Details: []slog.Attr{slog.Time("time", time.Now().UTC())}} - return + msg := &messaging.Message{ + Domain: r.DomainID, + Channel: r.InputChannel, + Subtopic: r.InputTopic, + Protocol: protocol, + Created: due.Unix(), + } + + if workerStatus := 
re.workerMgr.GetWorkerStatus(r.ID); workerStatus != nil { + if processing, ok := workerStatus["processing"].(bool); ok && processing { + re.updateRuleExecutionStatus(ctx, r.ID, QueuedStatus, nil) } + } + //nolint:contextcheck + if !re.workerMgr.SendMessage(msg, r) { + re.runInfo <- pkglog.RunInfo{ + Level: slog.LevelWarn, + Message: fmt.Sprintf("failed to send scheduled message to worker for rule %s", r.ID), + Details: []slog.Attr{ + slog.String("rule_id", r.ID), + slog.String("rule_name", r.Name), + slog.String("channel", msg.Channel), + slog.Time("scheduled_time", due), + }, + } + } - msg := &messaging.Message{ - Domain: rule.DomainID, - Channel: rule.InputChannel, - Subtopic: rule.InputTopic, - Protocol: protocol, - Created: due.Unix(), + go func(ruleID string, nextDue time.Time) { + if _, err := re.repo.UpdateRuleDue(ctx, ruleID, nextDue); err != nil { + re.runInfo <- pkglog.RunInfo{ + Level: slog.LevelError, + Message: fmt.Sprintf("failed to update rule due time: %s", err), + Details: []slog.Attr{ + slog.String("rule_id", ruleID), + slog.Time("time", time.Now().UTC()), + }, + } } - re.runInfo <- re.process(ctx, rule, msg) - }(r) + }(r.ID, r.Schedule.NextDue()) } // Reset due, it will reset in the page meta as well. 
due = time.Now().UTC() diff --git a/re/lua.go b/re/lua.go index 21e26592f..332125c1c 100644 --- a/re/lua.go +++ b/re/lua.go @@ -32,6 +32,12 @@ import ( const payloadKey = "payload" func (re *re) processLua(ctx context.Context, details []slog.Attr, r Rule, msg *messaging.Message) pkglog.RunInfo { + select { + case <-ctx.Done(): + return pkglog.RunInfo{Level: slog.LevelError, Details: details, Message: "rule execution was cancelled"} + default: + } + l := lua.NewState() defer l.Close() preload(l) diff --git a/re/middleware/authorization.go b/re/middleware/authorization.go index 3dfe56fe6..8cffc4068 100644 --- a/re/middleware/authorization.go +++ b/re/middleware/authorization.go @@ -258,6 +258,46 @@ func (am *authorizationMiddleware) DisableRule(ctx context.Context, session auth return am.svc.DisableRule(ctx, session, id) } +func (am *authorizationMiddleware) AbortRuleExecution(ctx context.Context, session authn.Session, id string) error { + if err := am.authorize(ctx, smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Object: session.DomainID, + ObjectType: policies.DomainType, + Permission: policies.MembershipPermission, + }); err != nil { + return errors.Wrap(errDomainUpdateRules, err) + } + + params := map[string]any{ + "entity_id": id, + } + + if err := am.callOut(ctx, session, re.OpAbortRuleExecution, params); err != nil { + return err + } + + return am.svc.AbortRuleExecution(ctx, session, id) +} + +func (am *authorizationMiddleware) GetRuleExecutionStatus(ctx context.Context, session authn.Session, id string) (re.RuleExecutionStatus, error) { + if err := am.authorize(ctx, smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Object: session.DomainID, + ObjectType: policies.DomainType, + Permission: policies.MembershipPermission, + }); err != nil { + return 
re.RuleExecutionStatus{}, errors.Wrap(errDomainViewRules, err) + } + + return am.svc.GetRuleExecutionStatus(ctx, session, id) +} + func (am *authorizationMiddleware) StartScheduler(ctx context.Context) error { return am.svc.StartScheduler(ctx) } diff --git a/re/middleware/logging.go b/re/middleware/logging.go index 948e528ed..305a86f5b 100644 --- a/re/middleware/logging.go +++ b/re/middleware/logging.go @@ -200,6 +200,44 @@ func (lm *loggingMiddleware) DisableRule(ctx context.Context, session authn.Sess return lm.svc.DisableRule(ctx, session, id) } +func (lm *loggingMiddleware) AbortRuleExecution(ctx context.Context, session authn.Session, id string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.String("rule_id", id), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Abort rule execution failed", args...) + return + } + lm.logger.Info("Abort rule execution completed successfully", args...) + }(time.Now()) + return lm.svc.AbortRuleExecution(ctx, session, id) +} + +func (lm *loggingMiddleware) GetRuleExecutionStatus(ctx context.Context, session authn.Session, id string) (res re.RuleExecutionStatus, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("rule", + slog.String("id", id), + slog.Bool("worker_running", res.WorkerRunning), + slog.Int("queue_length", res.QueueLength), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Get rule execution status failed", args...) + return + } + lm.logger.Info("Get rule execution status completed successfully", args...) 
+ }(time.Now()) + return lm.svc.GetRuleExecutionStatus(ctx, session, id) +} + func (lm *loggingMiddleware) StartScheduler(ctx context.Context) (err error) { defer func(begin time.Time) { args := []any{ @@ -238,5 +276,5 @@ func (lm *loggingMiddleware) Handle(msg *messaging.Message) (err error) { } func (lm *loggingMiddleware) Cancel() error { - return lm.Cancel() + return lm.svc.Cancel() } diff --git a/re/mocks/repository.go b/re/mocks/repository.go index a6458c5db..c0138fe09 100644 --- a/re/mocks/repository.go +++ b/re/mocks/repository.go @@ -369,6 +369,63 @@ func (_c *Repository_UpdateRuleDue_Call) RunAndReturn(run func(ctx context.Conte return _c } +// UpdateRuleExecutionStatus provides a mock function for the type Repository +func (_mock *Repository) UpdateRuleExecutionStatus(ctx context.Context, r re.Rule) error { + ret := _mock.Called(ctx, r) + + if len(ret) == 0 { + panic("no return value specified for UpdateRuleExecutionStatus") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, re.Rule) error); ok { + r0 = returnFunc(ctx, r) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_UpdateRuleExecutionStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRuleExecutionStatus' +type Repository_UpdateRuleExecutionStatus_Call struct { + *mock.Call +} + +// UpdateRuleExecutionStatus is a helper method to define mock.On call +// - ctx context.Context +// - r re.Rule +func (_e *Repository_Expecter) UpdateRuleExecutionStatus(ctx interface{}, r interface{}) *Repository_UpdateRuleExecutionStatus_Call { + return &Repository_UpdateRuleExecutionStatus_Call{Call: _e.mock.On("UpdateRuleExecutionStatus", ctx, r)} +} + +func (_c *Repository_UpdateRuleExecutionStatus_Call) Run(run func(ctx context.Context, r re.Rule)) *Repository_UpdateRuleExecutionStatus_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } 
+ var arg1 re.Rule + if args[1] != nil { + arg1 = args[1].(re.Rule) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateRuleExecutionStatus_Call) Return(err error) *Repository_UpdateRuleExecutionStatus_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_UpdateRuleExecutionStatus_Call) RunAndReturn(run func(ctx context.Context, r re.Rule) error) *Repository_UpdateRuleExecutionStatus_Call { + _c.Call.Return(run) + return _c +} + // UpdateRuleSchedule provides a mock function for the type Repository func (_mock *Repository) UpdateRuleSchedule(ctx context.Context, r re.Rule) (re.Rule, error) { ret := _mock.Called(ctx, r) diff --git a/re/mocks/service.go b/re/mocks/service.go index 5a3d500f8..2b26f7c4e 100644 --- a/re/mocks/service.go +++ b/re/mocks/service.go @@ -43,6 +43,69 @@ func (_m *Service) EXPECT() *Service_Expecter { return &Service_Expecter{mock: &_m.Mock} } +// AbortRuleExecution provides a mock function for the type Service +func (_mock *Service) AbortRuleExecution(ctx context.Context, session authn.Session, id string) error { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for AbortRuleExecution") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_AbortRuleExecution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AbortRuleExecution' +type Service_AbortRuleExecution_Call struct { + *mock.Call +} + +// AbortRuleExecution is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) AbortRuleExecution(ctx interface{}, session interface{}, id interface{}) *Service_AbortRuleExecution_Call { + return &Service_AbortRuleExecution_Call{Call: _e.mock.On("AbortRuleExecution", ctx, session, id)} 
+} + +func (_c *Service_AbortRuleExecution_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_AbortRuleExecution_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_AbortRuleExecution_Call) Return(err error) *Service_AbortRuleExecution_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_AbortRuleExecution_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) error) *Service_AbortRuleExecution_Call { + _c.Call.Return(run) + return _c +} + // AddRule provides a mock function for the type Service func (_mock *Service) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { ret := _mock.Called(ctx, session, r) @@ -303,6 +366,78 @@ func (_c *Service_EnableRule_Call) RunAndReturn(run func(ctx context.Context, se return _c } +// GetRuleExecutionStatus provides a mock function for the type Service +func (_mock *Service) GetRuleExecutionStatus(ctx context.Context, session authn.Session, id string) (re.RuleExecutionStatus, error) { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for GetRuleExecutionStatus") + } + + var r0 re.RuleExecutionStatus + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) (re.RuleExecutionStatus, error)); ok { + return returnFunc(ctx, session, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) re.RuleExecutionStatus); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Get(0).(re.RuleExecutionStatus) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = 
returnFunc(ctx, session, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_GetRuleExecutionStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRuleExecutionStatus' +type Service_GetRuleExecutionStatus_Call struct { + *mock.Call +} + +// GetRuleExecutionStatus is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) GetRuleExecutionStatus(ctx interface{}, session interface{}, id interface{}) *Service_GetRuleExecutionStatus_Call { + return &Service_GetRuleExecutionStatus_Call{Call: _e.mock.On("GetRuleExecutionStatus", ctx, session, id)} +} + +func (_c *Service_GetRuleExecutionStatus_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_GetRuleExecutionStatus_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_GetRuleExecutionStatus_Call) Return(ruleExecutionStatus re.RuleExecutionStatus, err error) *Service_GetRuleExecutionStatus_Call { + _c.Call.Return(ruleExecutionStatus, err) + return _c +} + +func (_c *Service_GetRuleExecutionStatus_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) (re.RuleExecutionStatus, error)) *Service_GetRuleExecutionStatus_Call { + _c.Call.Return(run) + return _c +} + // Handle provides a mock function for the type Service func (_mock *Service) Handle(msg *messaging.Message) error { ret := _mock.Called(msg) diff --git a/re/postgres/init.go b/re/postgres/init.go index 99eb347ff..549b02c78 100644 --- a/re/postgres/init.go +++ b/re/postgres/init.go @@ -50,6 +50,21 @@ func Migration() *migrate.MemoryMigrationSource { 
`ALTER TABLE rules DROP COLUMN tags;`, }, }, + { + Id: "rules_03", + Up: []string{ + `ALTER TABLE rules ADD COLUMN last_run_status SMALLINT NOT NULL DEFAULT 6 CHECK (last_run_status >= 0);`, // 6 = NeverRunStatus + `ALTER TABLE rules ADD COLUMN last_run_time TIMESTAMP;`, + `ALTER TABLE rules ADD COLUMN last_run_error_message TEXT;`, + `ALTER TABLE rules ADD COLUMN execution_count BIGINT NOT NULL DEFAULT 0;`, + }, + Down: []string{ + `ALTER TABLE rules DROP COLUMN last_run_status;`, + `ALTER TABLE rules DROP COLUMN last_run_time;`, + `ALTER TABLE rules DROP COLUMN last_run_error_message;`, + `ALTER TABLE rules DROP COLUMN execution_count;`, + }, + }, }, } } diff --git a/re/postgres/repository.go b/re/postgres/repository.go index 3596a9f7d..242a656d3 100644 --- a/re/postgres/repository.go +++ b/re/postgres/repository.go @@ -28,11 +28,14 @@ func NewRepository(db postgres.Database) re.Repository { func (repo *PostgresRepository) AddRule(ctx context.Context, r re.Rule) (re.Rule, error) { q := ` INSERT INTO rules (id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status) + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count) VALUES (:id, :name, :domain_id, :tags, :metadata, :input_channel, :input_topic, :logic_type, :logic_value, - :outputs, :start_datetime, :time, :recurring, :recurring_period, :created_at, :created_by, :updated_at, :updated_by, :status) + :outputs, :start_datetime, :time, :recurring, :recurring_period, :created_at, :created_by, :updated_at, :updated_by, :status, + :last_run_status, :last_run_time, :last_run_error_message, :execution_count) RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, start_datetime, time, 
recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count; ` dbr, err := ruleToDb(r) if err != nil { @@ -62,7 +65,8 @@ func (repo *PostgresRepository) AddRule(ctx context.Context, r re.Rule) (re.Rule func (repo *PostgresRepository) ViewRule(ctx context.Context, id string) (re.Rule, error) { q := ` SELECT id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, outputs, - start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status + start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count FROM rules WHERE id = $1; ` @@ -87,11 +91,31 @@ func (repo *PostgresRepository) UpdateRuleStatus(ctx context.Context, r re.Rule) SET status = :status, updated_at = :updated_at, updated_by = :updated_by WHERE id = :id RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status;` + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count;` return repo.update(ctx, r, q) } +func (repo *PostgresRepository) UpdateRuleExecutionStatus(ctx context.Context, r re.Rule) error { + q := `UPDATE rules + SET last_run_status = :last_run_status, last_run_time = :last_run_time, + last_run_error_message = :last_run_error_message, execution_count = :execution_count + WHERE id = :id;` + + dbr, err := ruleToDb(r) + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + _, err = 
repo.DB.NamedExecContext(ctx, q, dbr) + if err != nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + + return nil +} + func (repo *PostgresRepository) UpdateRule(ctx context.Context, r re.Rule) (re.Rule, error) { var query []string var upq string @@ -119,7 +143,8 @@ func (repo *PostgresRepository) UpdateRule(ctx context.Context, r re.Rule) (re.R UPDATE rules SET %s updated_at = :updated_at, updated_by = :updated_by WHERE id = :id RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count; `, upq) return repo.update(ctx, r, q) @@ -129,7 +154,8 @@ func (repo *PostgresRepository) UpdateRuleTags(ctx context.Context, r re.Rule) ( q := `UPDATE rules SET tags = :tags, updated_at = :updated_at, updated_by = :updated_by WHERE id = :id AND status = :status RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status;` + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count;` r.Status = re.EnabledStatus return repo.update(ctx, r, q) @@ -141,7 +167,8 @@ func (repo *PostgresRepository) UpdateRuleSchedule(ctx context.Context, r re.Rul SET start_datetime = :start_datetime, time = :time, recurring = :recurring, recurring_period = :recurring_period, updated_at = :updated_at, updated_by = :updated_by WHERE id = :id RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, 
start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count; ` return repo.update(ctx, r, q) } @@ -223,7 +250,8 @@ func (repo *PostgresRepository) ListRules(ctx context.Context, pm re.PageMeta) ( q := fmt.Sprintf(` SELECT id, name, domain_id, tags, input_channel, input_topic, logic_type, logic_value, outputs, - start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status + start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count FROM rules r %s %s %s; `, pq, orderClause, pgData) rows, err := repo.DB.NamedQueryContext(ctx, q, pm) @@ -266,7 +294,8 @@ func (repo *PostgresRepository) UpdateRuleDue(ctx context.Context, id string, du UPDATE rules SET time = :time, updated_at = :updated_at WHERE id = :id RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count; ` dbr := dbRule{ ID: id, @@ -304,6 +333,9 @@ func pageRulesQuery(pm re.PageMeta) string { if pm.Status != re.AllStatus { query = append(query, "r.status = :status") } + if pm.LastRunStatus != re.NeverRunStatus { + query = append(query, "r.last_run_status = :last_run_status") + } if pm.Domain != "" { query = append(query, "r.domain_id = :domain_id") } diff --git a/re/postgres/rule.go b/re/postgres/rule.go index 3027e3d66..8655a772e 100644 --- a/re/postgres/rule.go +++ 
b/re/postgres/rule.go @@ -16,25 +16,29 @@ import ( // dbRule represents the database structure for a Rule. type dbRule struct { - ID string `db:"id"` - Name string `db:"name"` - DomainID string `db:"domain_id"` - Tags pgtype.TextArray `db:"tags,omitempty"` - Metadata []byte `db:"metadata,omitempty"` - InputChannel string `db:"input_channel"` - InputTopic sql.NullString `db:"input_topic"` - LogicType re.ScriptType `db:"logic_type"` - LogicValue string `db:"logic_value"` - Outputs []byte `db:"outputs"` - StartDateTime sql.NullTime `db:"start_datetime"` - Time sql.NullTime `db:"time"` - Recurring schedule.Recurring `db:"recurring"` - RecurringPeriod uint `db:"recurring_period"` - Status re.Status `db:"status"` - CreatedAt time.Time `db:"created_at"` - CreatedBy string `db:"created_by"` - UpdatedAt time.Time `db:"updated_at"` - UpdatedBy string `db:"updated_by"` + ID string `db:"id"` + Name string `db:"name"` + DomainID string `db:"domain_id"` + Tags pgtype.TextArray `db:"tags,omitempty"` + Metadata []byte `db:"metadata,omitempty"` + InputChannel string `db:"input_channel"` + InputTopic sql.NullString `db:"input_topic"` + LogicType re.ScriptType `db:"logic_type"` + LogicValue string `db:"logic_value"` + Outputs []byte `db:"outputs"` + StartDateTime sql.NullTime `db:"start_datetime"` + Time sql.NullTime `db:"time"` + Recurring schedule.Recurring `db:"recurring"` + RecurringPeriod uint `db:"recurring_period"` + Status re.Status `db:"status"` + LastRunStatus re.ExecutionStatus `db:"last_run_status"` + LastRunTime sql.NullTime `db:"last_run_time"` + LastRunErrorMessage sql.NullString `db:"last_run_error_message"` + ExecutionCount uint64 `db:"execution_count"` + CreatedAt time.Time `db:"created_at"` + CreatedBy string `db:"created_by"` + UpdatedAt time.Time `db:"updated_at"` + UpdatedBy string `db:"updated_by"` } func ruleToDb(r re.Rule) (dbRule, error) { @@ -55,6 +59,13 @@ func ruleToDb(r re.Rule) (dbRule, error) { if !r.Schedule.Time.IsZero() { t.Valid = true } + + 
lastRunTime := sql.NullTime{} + if r.LastRunTime != nil && !r.LastRunTime.IsZero() { + lastRunTime.Time = *r.LastRunTime + lastRunTime.Valid = true + } + var tags pgtype.TextArray if err := tags.Set(r.Tags); err != nil { return dbRule{}, err @@ -66,25 +77,29 @@ func ruleToDb(r re.Rule) (dbRule, error) { } return dbRule{ - ID: r.ID, - Name: r.Name, - DomainID: r.DomainID, - Tags: tags, - Metadata: metadata, - InputChannel: r.InputChannel, - InputTopic: toNullString(r.InputTopic), - LogicType: r.Logic.Type, - LogicValue: r.Logic.Value, - Outputs: outputs, - StartDateTime: start, - Time: t, - Recurring: r.Schedule.Recurring, - RecurringPeriod: r.Schedule.RecurringPeriod, - Status: r.Status, - CreatedAt: r.CreatedAt, - CreatedBy: r.CreatedBy, - UpdatedAt: r.UpdatedAt, - UpdatedBy: r.UpdatedBy, + ID: r.ID, + Name: r.Name, + DomainID: r.DomainID, + Tags: tags, + Metadata: metadata, + InputChannel: r.InputChannel, + InputTopic: toNullString(r.InputTopic), + LogicType: r.Logic.Type, + LogicValue: r.Logic.Value, + Outputs: outputs, + StartDateTime: start, + Time: t, + Recurring: r.Schedule.Recurring, + RecurringPeriod: r.Schedule.RecurringPeriod, + Status: r.Status, + LastRunStatus: r.LastRunStatus, + LastRunTime: lastRunTime, + LastRunErrorMessage: toNullString(r.LastRunErrorMessage), + ExecutionCount: r.ExecutionCount, + CreatedAt: r.CreatedAt, + CreatedBy: r.CreatedBy, + UpdatedAt: r.UpdatedAt, + UpdatedBy: r.UpdatedBy, }, nil } @@ -108,6 +123,11 @@ func dbToRule(dto dbRule) (re.Rule, error) { } } + var lastRunTime *time.Time + if dto.LastRunTime.Valid { + lastRunTime = &dto.LastRunTime.Time + } + return re.Rule{ ID: dto.ID, Name: dto.Name, @@ -127,11 +147,15 @@ func dbToRule(dto dbRule) (re.Rule, error) { Recurring: dto.Recurring, RecurringPeriod: dto.RecurringPeriod, }, - Status: dto.Status, - CreatedAt: dto.CreatedAt, - CreatedBy: dto.CreatedBy, - UpdatedAt: dto.UpdatedAt, - UpdatedBy: dto.UpdatedBy, + Status: dto.Status, + LastRunStatus: dto.LastRunStatus, + 
LastRunTime: lastRunTime, + LastRunErrorMessage: fromNullString(dto.LastRunErrorMessage), + ExecutionCount: dto.ExecutionCount, + CreatedAt: dto.CreatedAt, + CreatedBy: dto.CreatedBy, + UpdatedAt: dto.UpdatedAt, + UpdatedBy: dto.UpdatedBy, }, nil } diff --git a/re/rule.go b/re/rule.go index 695dd28fc..38b948e32 100644 --- a/re/rule.go +++ b/re/rule.go @@ -30,6 +30,7 @@ const ( OpRemoveRule = "OpRemoveRule" OpEnableRule = "OpEnableRule" OpDisableRule = "OpDisableRule" + OpAbortRuleExecution = "OpAbortRuleExecution" ) type ( @@ -54,32 +55,38 @@ var outputRegistry = map[outputs.OutputType]func() Runnable{ } type Rule struct { - ID string `json:"id"` - Name string `json:"name"` - DomainID string `json:"domain"` - Metadata Metadata `json:"metadata,omitempty"` - Tags []string `json:"tags,omitempty"` - InputChannel string `json:"input_channel"` - InputTopic string `json:"input_topic"` - Logic Script `json:"logic"` - Outputs Outputs `json:"outputs,omitempty"` - Schedule schedule.Schedule `json:"schedule,omitempty"` - Status Status `json:"status"` - CreatedAt time.Time `json:"created_at"` - CreatedBy string `json:"created_by"` - UpdatedAt time.Time `json:"updated_at"` - UpdatedBy string `json:"updated_by"` + ID string `json:"id"` + Name string `json:"name"` + DomainID string `json:"domain"` + Metadata Metadata `json:"metadata,omitempty"` + Tags []string `json:"tags,omitempty"` + InputChannel string `json:"input_channel"` + InputTopic string `json:"input_topic"` + Logic Script `json:"logic"` + Outputs Outputs `json:"outputs,omitempty"` + Schedule schedule.Schedule `json:"schedule,omitempty"` + Status Status `json:"status"` + LastRunStatus ExecutionStatus `json:"last_run_status"` + LastRunTime *time.Time `json:"last_run_time,omitempty"` + LastRunErrorMessage string `json:"last_run_error_message,omitempty"` + ExecutionCount uint64 `json:"execution_count"` + CreatedAt time.Time `json:"created_at"` + CreatedBy string `json:"created_by"` + UpdatedAt time.Time `json:"updated_at"` 
+ UpdatedBy string `json:"updated_by"` } // EventEncode converts a Rule struct to map[string]any at event producer. func (r Rule) EventEncode() (map[string]any, error) { m := map[string]any{ - "id": r.ID, - "name": r.Name, - "created_at": r.CreatedAt.Format(time.RFC3339Nano), - "created_by": r.CreatedBy, - "schedule": r.Schedule.EventEncode(), - "status": r.Status.String(), + "id": r.ID, + "name": r.Name, + "created_at": r.CreatedAt.Format(time.RFC3339Nano), + "created_by": r.CreatedBy, + "schedule": r.Schedule.EventEncode(), + "status": r.Status.String(), + "last_run_status": r.LastRunStatus.String(), + "execution_count": r.ExecutionCount, } if r.Name != "" { @@ -98,6 +105,14 @@ func (r Rule) EventEncode() (map[string]any, error) { m["updated_by"] = r.UpdatedBy } + if r.LastRunTime != nil { + m["last_run_time"] = r.LastRunTime.Format(time.RFC3339Nano) + } + + if r.LastRunErrorMessage != "" { + m["last_run_error_message"] = r.LastRunErrorMessage + } + if len(r.Metadata) > 0 { m["metadata"] = r.Metadata } @@ -175,6 +190,7 @@ type PageMeta struct { Scheduled *bool `json:"scheduled,omitempty"` OutputChannel string `json:"output_channel,omitempty" db:"output_channel"` Status Status `json:"status,omitempty" db:"status"` + LastRunStatus ExecutionStatus `json:"last_run_status,omitempty" db:"last_run_status"` Domain string `json:"domain_id,omitempty" db:"domain_id"` Tag string `json:"tag,omitempty"` ScheduledBefore *time.Time `json:"scheduled_before,omitempty" db:"scheduled_before"` // Filter rules scheduled before this time @@ -185,11 +201,12 @@ type PageMeta struct { // EventEncode converts a PageMeta struct to map[string]any. 
func (pm PageMeta) EventEncode() map[string]any { m := map[string]any{ - "total": pm.Total, - "offset": pm.Offset, - "limit": pm.Limit, - "status": pm.Status.String(), - "domain_id": pm.Domain, + "total": pm.Total, + "offset": pm.Offset, + "limit": pm.Limit, + "status": pm.Status.String(), + "last_run_status": pm.LastRunStatus.String(), + "domain_id": pm.Domain, } if pm.Dir != "" { @@ -233,6 +250,12 @@ type Page struct { Rules []Rule `json:"rules"` } +type RuleExecutionStatus struct { + Rule Rule `json:"rule"` + WorkerRunning bool `json:"worker_running"` + QueueLength int `json:"queue_length"` +} + type Service interface { messaging.MessageHandler AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, error) @@ -244,6 +267,8 @@ type Service interface { RemoveRule(ctx context.Context, session authn.Session, id string) error EnableRule(ctx context.Context, session authn.Session, id string) (Rule, error) DisableRule(ctx context.Context, session authn.Session, id string) (Rule, error) + AbortRuleExecution(ctx context.Context, session authn.Session, id string) error + GetRuleExecutionStatus(ctx context.Context, session authn.Session, id string) (RuleExecutionStatus, error) StartScheduler(ctx context.Context) error } @@ -256,6 +281,7 @@ type Repository interface { UpdateRuleSchedule(ctx context.Context, r Rule) (Rule, error) RemoveRule(ctx context.Context, id string) error UpdateRuleStatus(ctx context.Context, r Rule) (Rule, error) + UpdateRuleExecutionStatus(ctx context.Context, r Rule) error ListRules(ctx context.Context, pm PageMeta) (Page, error) UpdateRuleDue(ctx context.Context, id string, due time.Time) (Rule, error) } diff --git a/re/service.go b/re/service.go index 241fe6ea6..68ce5e81e 100644 --- a/re/service.go +++ b/re/service.go @@ -5,11 +5,14 @@ package re import ( "context" + "fmt" + "log/slog" "time" grpcReadersV1 "github.com/absmach/magistrala/api/grpc/readers/v1" "github.com/absmach/magistrala/pkg/emailer" pkglog 
"github.com/absmach/magistrala/pkg/logger" + "github.com/absmach/magistrala/pkg/schedule" "github.com/absmach/magistrala/pkg/ticker" "github.com/absmach/supermq" "github.com/absmach/supermq/pkg/authn" @@ -28,10 +31,11 @@ type re struct { ticker ticker.Ticker email emailer.Emailer readers grpcReadersV1.ReadersServiceClient + workerMgr *WorkerManager } func NewService(repo Repository, runInfo chan pkglog.RunInfo, idp supermq.IDProvider, rePubSub messaging.PubSub, writersPub, alarmsPub messaging.Publisher, tck ticker.Ticker, emailer emailer.Emailer, readers grpcReadersV1.ReadersServiceClient) Service { - return &re{ + reEngine := &re{ repo: repo, idp: idp, runInfo: runInfo, @@ -42,6 +46,26 @@ func NewService(repo Repository, runInfo chan pkglog.RunInfo, idp supermq.IDProv email: emailer, readers: readers, } + return reEngine +} + +func shouldCreateWorker(rule Rule) bool { + if rule.Status != EnabledStatus { + return false + } + + if rule.Schedule.Recurring == schedule.None { + return true + } + + now := time.Now().UTC() + dueTime := rule.Schedule.Time + + if dueTime.IsZero() || dueTime.Before(now) { + return true + } + + return dueTime.Sub(now) <= time.Hour } func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, error) { @@ -55,6 +79,8 @@ func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, r.CreatedBy = session.UserID r.DomainID = session.DomainID r.Status = EnabledStatus + r.LastRunStatus = NeverRunStatus + r.ExecutionCount = 0 if !r.Schedule.StartDateTime.IsZero() { r.Schedule.StartDateTime = now @@ -66,6 +92,10 @@ func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, return Rule{}, errors.Wrap(svcerr.ErrCreateEntity, err) } + if shouldCreateWorker(rule) { + re.workerMgr.AddWorker(ctx, rule) + } + return rule, nil } @@ -86,6 +116,12 @@ func (re *re) UpdateRule(ctx context.Context, session authn.Session, r Rule) (Ru return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } + if 
shouldCreateWorker(rule) { + re.workerMgr.UpdateWorker(ctx, rule) + } else { + re.workerMgr.RemoveWorker(rule.ID) + } + return rule, nil } @@ -108,6 +144,12 @@ func (re *re) UpdateRuleSchedule(ctx context.Context, session authn.Session, r R return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } + if shouldCreateWorker(rule) { + re.workerMgr.UpdateWorker(ctx, rule) + } else { + re.workerMgr.RemoveWorker(rule.ID) + } + return rule, nil } @@ -125,6 +167,8 @@ func (re *re) RemoveRule(ctx context.Context, session authn.Session, id string) return errors.Wrap(svcerr.ErrRemoveEntity, err) } + re.workerMgr.RemoveWorker(id) + return nil } @@ -143,6 +187,11 @@ func (re *re) EnableRule(ctx context.Context, session authn.Session, id string) if err != nil { return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } + + if shouldCreateWorker(rule) { + re.workerMgr.AddWorker(ctx, rule) + } + return rule, nil } @@ -161,9 +210,113 @@ func (re *re) DisableRule(ctx context.Context, session authn.Session, id string) if err != nil { return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } + + re.workerMgr.RemoveWorker(id) + return rule, nil } func (re *re) Cancel() error { + return re.workerMgr.StopAll() +} + +func (re *re) AbortRuleExecution(ctx context.Context, session authn.Session, id string) error { + rule, err := re.repo.ViewRule(ctx, id) + if err != nil { + return errors.Wrap(svcerr.ErrViewEntity, err) + } + + if rule.LastRunStatus != InProgressStatus && rule.LastRunStatus != QueuedStatus { + return errors.Wrap(errors.New(fmt.Sprintf("cannot abort rule with status '%s': rule must be in 'in_progress' or 'queued' status", + rule.LastRunStatus.String())), svcerr.ErrMalformedEntity) + } + + if re.workerMgr != nil { + workerStatus := re.workerMgr.GetWorkerStatus(id) + if workerStatus == nil { + return errors.Wrap(errors.New("no active worker found for this rule"), svcerr.ErrNotFound) + } + + running, ok := workerStatus["running"].(bool) + if !ok || !running { + return 
errors.Wrap(errors.New("cannot abort: worker is not currently running"), svcerr.ErrMalformedEntity) + } + + queueLen, _ := workerStatus["queue_length"].(int) + if queueLen == 0 && rule.LastRunStatus != InProgressStatus { + return errors.Wrap(errors.New("cannot abort: no execution in progress"), svcerr.ErrMalformedEntity) + } + + re.workerMgr.AbortRule(id) + } + return nil } + +func (re *re) GetRuleExecutionStatus(ctx context.Context, session authn.Session, id string) (RuleExecutionStatus, error) { + rule, err := re.repo.ViewRule(ctx, id) + if err != nil { + return RuleExecutionStatus{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + + status := RuleExecutionStatus{ + Rule: rule, + WorkerRunning: false, + QueueLength: 0, + } + + if re.workerMgr != nil { + workerStatus := re.workerMgr.GetWorkerStatus(id) + if workerStatus != nil { + if running, ok := workerStatus["running"].(bool); ok { + status.WorkerRunning = running + } + if queueLen, ok := workerStatus["queue_length"].(int); ok { + status.QueueLength = queueLen + } + } + } + + return status, nil +} + +func (re *re) updateRuleExecutionStatus(ctx context.Context, ruleID string, status ExecutionStatus, err error) { + now := time.Now().UTC() + rule := Rule{ + ID: ruleID, + LastRunStatus: status, + LastRunTime: &now, + } + + if err != nil { + rule.LastRunErrorMessage = err.Error() + } + + currentRule, viewErr := re.repo.ViewRule(ctx, ruleID) + if viewErr != nil { + switch status { + case SuccessStatus, PartialSuccessStatus, FailureStatus: + rule.ExecutionCount = 1 + default: + rule.ExecutionCount = 0 + } + } else { + rule.ExecutionCount = currentRule.ExecutionCount + + switch status { + case SuccessStatus, PartialSuccessStatus, FailureStatus: + rule.ExecutionCount = currentRule.ExecutionCount + 1 + } + } + + if err := re.repo.UpdateRuleExecutionStatus(ctx, rule); err != nil { + re.runInfo <- pkglog.RunInfo{ + Level: slog.LevelWarn, + Message: fmt.Sprintf("failed to update rule execution status: %s", err), + Details: 
[]slog.Attr{ + slog.String("rule_id", ruleID), + slog.String("status", status.String()), + }, + } + } +} diff --git a/re/service_test.go b/re/service_test.go index b5fb8c3de..c8af8ca3d 100644 --- a/re/service_test.go +++ b/re/service_test.go @@ -54,11 +54,56 @@ func newService(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.R pubsub := pubsubmocks.NewPubSub(t) readersSvc := new(readmocks.ReadersServiceClient) e := new(emocks.Emailer) - return re.NewService(repo, runInfo, idProvider, pubsub, pubsub, pubsub, mockTicker, e, readersSvc), repo, pubsub, mockTicker + + svc := re.NewService(repo, runInfo, idProvider, pubsub, pubsub, pubsub, mockTicker, e, readersSvc) + + repocall1 := repo.On("ListRules", mock.Anything, mock.Anything).Return(re.Page{Rules: []re.Rule{}}, nil) + defer repocall1.Unset() + + tickCh := make(chan time.Time) + mockTicker.On("Tick").Return((<-chan time.Time)(tickCh)) + mockTicker.On("Stop").Return() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + + schedulerDone := make(chan struct{}) + go func() { + _ = svc.StartScheduler(ctx) + close(schedulerDone) + }() + + t.Cleanup(func() { + cancel() + + // Wait for scheduler to stop (can take time for worker cleanup) + <-schedulerDone + + time.Sleep(100 * time.Millisecond) + }) + time.Sleep(50 * time.Millisecond) + + return svc, repo, pubsub, mockTicker +} + +func newServiceForSchedulerTest(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.Repository, *pubsubmocks.PubSub, *tmocks.Ticker) { + repo := new(mocks.Repository) + mockTicker := new(tmocks.Ticker) + idProvider := uuid.NewMock() + pubsub := pubsubmocks.NewPubSub(t) + readersSvc := new(readmocks.ReadersServiceClient) + e := new(emocks.Emailer) + + svc := re.NewService(repo, runInfo, idProvider, pubsub, pubsub, pubsub, mockTicker, e, readersSvc) + + tickCh := make(chan time.Time) + mockTicker.On("Tick").Return((<-chan time.Time)(tickCh)) + mockTicker.On("Stop").Return() + + return svc, repo, 
pubsub, mockTicker } func TestAddRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) ruleName := namegen.Generate() now := time.Now().Add(time.Hour) cases := []struct { @@ -133,7 +178,7 @@ func TestAddRule(t *testing.T) { } func TestViewRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) now := time.Now().Add(time.Hour) cases := []struct { @@ -191,7 +236,7 @@ func TestViewRule(t *testing.T) { } func TestUpdateRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) newName := namegen.Generate() now := time.Now().Add(time.Hour) @@ -276,7 +321,7 @@ func TestUpdateRule(t *testing.T) { } func TestUpdateRuleTags(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) cases := []struct { desc string @@ -332,7 +377,7 @@ func TestUpdateRuleTags(t *testing.T) { } func TestListRules(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) numRules := 50 now := time.Now().Add(time.Hour) var rules []re.Rule @@ -418,7 +463,10 @@ func TestListRules(t *testing.T) { DomainID: domainID, }, pageMeta: re.PageMeta{}, - err: svcerr.ErrViewEntity, + res: re.Page{ + Rules: []re.Rule{}, + }, + err: svcerr.ErrViewEntity, }, } @@ -437,7 +485,7 @@ func TestListRules(t *testing.T) { } func TestRemoveRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) cases := []struct { desc string @@ -477,7 +525,7 @@ func TestRemoveRule(t *testing.T) { } func TestEnableRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + 
svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) now := time.Now() @@ -536,7 +584,7 @@ func TestEnableRule(t *testing.T) { } func TestDisableRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) now := time.Now() @@ -594,8 +642,69 @@ func TestDisableRule(t *testing.T) { } } +func TestAbortRuleExecution(t *testing.T) { + cases := []struct { + desc string + session authn.Session + id string + rule re.Rule + repoErr error + err error + }{ + { + desc: "abort rule execution with wrong status", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: ruleID, + rule: re.Rule{ + ID: ruleID, + LastRunStatus: re.SuccessStatus, + }, + repoErr: nil, + err: svcerr.ErrMalformedEntity, + }, + { + desc: "abort rule execution with never_run status", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: ruleID, + rule: re.Rule{ + ID: ruleID, + LastRunStatus: re.NeverRunStatus, + }, + repoErr: nil, + err: svcerr.ErrMalformedEntity, + }, + { + desc: "abort rule execution with non-existent rule", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: "non-existent-rule", + rule: re.Rule{}, + repoErr: svcerr.ErrNotFound, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) + repo.On("ViewRule", mock.Anything, tc.id).Return(tc.rule, tc.repoErr) + err := svc.AbortRuleExecution(context.Background(), tc.session, tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repo.AssertExpectations(t) + }) + } +} + func TestHandle(t *testing.T) { - svc, repo, pubmocks, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, pubmocks, _ := newService(t, make(chan pkglog.RunInfo, 100)) now := time.Now() scheduled := false @@ -619,7 +728,7 @@ 
func TestHandle(t *testing.T) { listErr: nil, }, { - desc: "consume message with rules", + desc: "consume message with enabled rules", message: &messaging.Message{ Channel: inputChannel, Created: now.Unix(), @@ -646,6 +755,34 @@ func TestHandle(t *testing.T) { }, listErr: nil, }, + { + desc: "consume message with disabled rules", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.DisabledStatus, + Logic: re.Script{ + Type: re.ScriptType(0), + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, } for _, tc := range cases { @@ -657,23 +794,36 @@ func TestHandle(t *testing.T) { err = tc.listErr } }) - repoCall1 := pubmocks.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(tc.publishErr) + + enabledRulesCount := 0 + for _, rule := range tc.page.Rules { + if rule.Status == re.EnabledStatus { + enabledRulesCount++ + } + } + + if enabledRulesCount > 0 { + repoCall1 := pubmocks.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(tc.publishErr).Times(enabledRulesCount) + defer repoCall1.Unset() + } err = svc.Handle(tc.message) - assert.Nil(t, err) - assert.True(t, errors.Contains(err, tc.listErr), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.listErr, err)) + assert.NoError(t, err, "Handle should not return errors with worker architecture") + + if tc.listErr != nil { + assert.True(t, errors.Contains(err, tc.listErr), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.listErr, err)) + } repoCall.Unset() - repoCall1.Unset() }) } } func TestStartScheduler(t *testing.T) { now := time.Now().Truncate(time.Minute) - ri := make(chan pkglog.RunInfo) - svc, repo, _, ticker := newService(t, ri) + ri := make(chan pkglog.RunInfo, 100) + svc, repo, _, 
ticker := newServiceForSchedulerTest(t, ri) ctxCases := []struct { desc string diff --git a/re/worker.go b/re/worker.go new file mode 100644 index 000000000..5fce93426 --- /dev/null +++ b/re/worker.go @@ -0,0 +1,649 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/absmach/supermq/pkg/messaging" + "golang.org/x/sync/errgroup" +) + +// WorkerMessage represents a message to be processed by a rule worker. +type WorkerMessage struct { + Message *messaging.Message + Rule Rule +} + +// RuleWorker manages execution of a single rule in its own goroutine. +type RuleWorker struct { + rule Rule + engine *re + msgChan chan WorkerMessage + ctx context.Context + cancel context.CancelFunc + running int32 + maxQueueSize int +} + +// NewRuleWorker creates a new rule worker for the given rule. +func NewRuleWorker(rule Rule, engine *re) *RuleWorker { + return &RuleWorker{ + rule: rule, + engine: engine, + msgChan: make(chan WorkerMessage, 100), + running: 0, // 0 = not running, 1 = running + maxQueueSize: 100, + } +} + +// Start begins the worker goroutine for processing messages. +func (w *RuleWorker) Start(ctx context.Context) { + if !atomic.CompareAndSwapInt32(&w.running, 0, 1) { + return + } + + w.ctx, w.cancel = context.WithCancel(ctx) + go func() { + defer atomic.StoreInt32(&w.running, 0) + w.run(w.ctx) + }() +} + +// Stop stops the worker goroutine and waits for it to finish. +func (w *RuleWorker) Stop() error { + if !atomic.CompareAndSwapInt32(&w.running, 1, 0) { + return nil + } + + if w.cancel != nil { + w.cancel() + } + + return nil +} + +// AbortExecution aborts the current execution if the worker is running. 
+func (w *RuleWorker) AbortExecution(ctx context.Context) { + if atomic.LoadInt32(&w.running) == 1 { + if w.cancel != nil { + w.cancel() + } + + w.engine.updateRuleExecutionStatus(ctx, w.rule.ID, AbortedStatus, fmt.Errorf("rule execution manually aborted")) + } +} + +// Send sends a message to the worker for processing. +func (w *RuleWorker) Send(msg WorkerMessage) bool { + if atomic.LoadInt32(&w.running) == 0 { + return false + } + + queueLen := len(w.msgChan) + if queueLen >= w.maxQueueSize { + return false + } + + if queueLen > 0 { + w.engine.updateRuleExecutionStatus(context.Background(), msg.Rule.ID, QueuedStatus, nil) + } + + select { + case w.msgChan <- msg: + return true + default: + return false + } +} + +// IsRunning returns true if the worker is currently running. +func (w *RuleWorker) IsRunning() bool { + return atomic.LoadInt32(&w.running) == 1 +} + +// GetQueueLength returns the current number of queued messages. +func (w *RuleWorker) GetQueueLength() int { + return len(w.msgChan) +} + +// GetRule returns the current rule configuration. +func (w *RuleWorker) GetRule() Rule { + return w.rule +} + +// run is the main worker loop that processes messages. +func (w *RuleWorker) run(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case workerMsg := <-w.msgChan: + w.processMessage(ctx, workerMsg) + } + } +} + +// processMessage processes a single message using the rule logic. 
// processMessage evaluates the rule logic against a single queued message.
// It honors cancellation of the worker context both before and after the
// evaluation, recording an aborted status in either case.
func (w *RuleWorker) processMessage(ctx context.Context, workerMsg WorkerMessage) {
	currentRule := w.GetRule()

	// Disabled rules are skipped silently; the queue may still hold
	// messages accepted before the rule was disabled.
	if currentRule.Status != EnabledStatus {
		return
	}

	// Abort before doing any work if the worker was cancelled while the
	// message sat in the queue.
	select {
	case <-w.ctx.Done():
		w.engine.updateRuleExecutionStatus(ctx, currentRule.ID, AbortedStatus, w.ctx.Err())
		return
	default:
	}

	runInfo := w.engine.process(ctx, currentRule, workerMsg.Message)

	// A cancellation that arrived while process was running also counts as
	// an abort, regardless of what process returned.
	if w.ctx.Err() == context.Canceled {
		w.engine.updateRuleExecutionStatus(ctx, currentRule.ID, AbortedStatus, w.ctx.Err())
		return
	}

	// Best-effort delivery of the run report: dropped when the run-info
	// channel is full so message processing is never blocked on reporting.
	select {
	case w.engine.runInfo <- runInfo:
	default:
	}
}

// WorkerCommandType enumerates the commands accepted by the worker manager.
type WorkerCommandType uint8

const (
	CmdAdd WorkerCommandType = iota
	CmdRemove
	CmdUpdate
	CmdStopAll
	CmdCount
	CmdList
	CmdAbort
	CmdGetStatus
)

// String returns the human-readable name of the command type.
func (c WorkerCommandType) String() string {
	switch c {
	case CmdAdd:
		return "add"
	case CmdRemove:
		return "remove"
	case CmdUpdate:
		return "update"
	case CmdStopAll:
		return "stop_all"
	case CmdCount:
		return "count"
	case CmdList:
		return "list"
	case CmdAbort:
		return "abort"
	case CmdGetStatus:
		return "get_status"
	default:
		return "unknown"
	}
}

// WorkerManagerCommand represents commands for worker management.
type WorkerManagerCommand struct {
	Type   WorkerCommandType
	Rule   Rule
	RuleID string
	// Response, when non-nil, receives the command's reply (count, rule-ID
	// list, status map, or completion signal depending on Type).
	Response chan interface{}
}

// WorkerManager manages all rule workers. The workers map is mutated only
// on the manager goroutine (manageWorkers); mu guards concurrent readers.
type WorkerManager struct {
	workers   map[string]*RuleWorker
	engine    *re
	g         *errgroup.Group
	ctx       context.Context
	commandCh chan WorkerManagerCommand
	errorCh   chan error
	mu        sync.RWMutex
	// running is 1 while the manager goroutine is alive (accessed atomically).
	running int32
}

// NewWorkerManager creates a new worker manager.
+func NewWorkerManager(ctx context.Context, engine *re) *WorkerManager { + g, ctx := errgroup.WithContext(ctx) + wm := &WorkerManager{ + workers: make(map[string]*RuleWorker), + engine: engine, + g: g, + ctx: ctx, + commandCh: make(chan WorkerManagerCommand, 100), + errorCh: make(chan error, 100), + running: 0, + } + + wm.g.Go(func() error { + return wm.manageWorkers(ctx) + }) + + atomic.StoreInt32(&wm.running, 1) + return wm +} + +func (wm *WorkerManager) manageWorkers(ctx context.Context) error { + defer func() { + atomic.StoreInt32(&wm.running, 0) + }() + + for { + select { + case <-ctx.Done(): + for _, worker := range wm.workers { + if err := worker.Stop(); err != nil { + select { + case wm.errorCh <- err: + default: + } + } + } + wm.workers = make(map[string]*RuleWorker) + return ctx.Err() + + case cmd := <-wm.commandCh: + wm.handleCommand(cmd) + } + } +} + +func (wm *WorkerManager) handleCommand(cmd WorkerManagerCommand) { + switch cmd.Type { + case CmdAdd: + if err := wm.addWorker(cmd.Rule); err != nil { + select { + case wm.errorCh <- err: + default: + } + } + case CmdRemove: + if err := wm.removeWorker(cmd.RuleID); err != nil { + select { + case wm.errorCh <- err: + default: + } + } + case CmdUpdate: + if err := wm.updateWorker(cmd.Rule); err != nil { + select { + case wm.errorCh <- err: + default: + } + } + case CmdAbort: + wm.abortWorker(cmd.RuleID) + case CmdStopAll: + if err := wm.stopAll(); err != nil { + select { + case wm.errorCh <- err: + default: + } + } + if cmd.Response != nil { + cmd.Response <- true + } + case CmdCount: + wm.mu.RLock() + count := len(wm.workers) + wm.mu.RUnlock() + if cmd.Response != nil { + cmd.Response <- count + } + case CmdList: + wm.mu.RLock() + ruleIDs := make([]string, 0, len(wm.workers)) + for ruleID := range wm.workers { + ruleIDs = append(ruleIDs, ruleID) + } + wm.mu.RUnlock() + if cmd.Response != nil { + cmd.Response <- ruleIDs + } + case CmdGetStatus: + wm.mu.RLock() + var status map[string]interface{} + if worker, 
exists := wm.workers[cmd.RuleID]; exists { + status = map[string]interface{}{ + "running": worker.IsRunning(), + "queue_length": worker.GetQueueLength(), + } + } + wm.mu.RUnlock() + if cmd.Response != nil { + cmd.Response <- status + } + } +} + +func (wm *WorkerManager) addWorker(rule Rule) error { + wm.mu.Lock() + defer wm.mu.Unlock() + + oldWorker, exists := wm.workers[rule.ID] + + if rule.Status != EnabledStatus { + if exists { + if err := oldWorker.Stop(); err != nil { + return err + } + } + delete(wm.workers, rule.ID) + return nil + } + + newWorker := NewRuleWorker(rule, wm.engine) + newWorker.Start(wm.ctx) + + wm.workers[rule.ID] = newWorker + + if exists { + if err := oldWorker.Stop(); err != nil { + select { + case wm.errorCh <- err: + default: + } + } + } + + return nil +} + +func (wm *WorkerManager) removeWorker(ruleID string) error { + wm.mu.Lock() + defer wm.mu.Unlock() + + if worker, ok := wm.workers[ruleID]; ok { + if err := worker.Stop(); err != nil { + return err + } + delete(wm.workers, ruleID) + } + return nil +} + +func (wm *WorkerManager) updateWorker(rule Rule) error { + wm.mu.Lock() + defer wm.mu.Unlock() + + oldWorker, exists := wm.workers[rule.ID] + + if rule.Status != EnabledStatus { + if exists { + if err := oldWorker.Stop(); err != nil { + return err + } + } + delete(wm.workers, rule.ID) + return nil + } + + newWorker := NewRuleWorker(rule, wm.engine) + newWorker.Start(wm.ctx) + + wm.workers[rule.ID] = newWorker + + if exists { + if err := oldWorker.Stop(); err != nil { + select { + case wm.errorCh <- err: + default: + } + } + } + + return nil +} + +func (wm *WorkerManager) abortWorker(ruleID string) { + wm.mu.Lock() + defer wm.mu.Unlock() + + worker, exists := wm.workers[ruleID] + if !exists { + return + } + + worker.AbortExecution(wm.ctx) + + rule := worker.GetRule() + + if err := worker.Stop(); err != nil { + select { + case wm.errorCh <- err: + default: + } + } + + delete(wm.workers, ruleID) + + if rule.Status == EnabledStatus { + 
newWorker := NewRuleWorker(rule, wm.engine) + newWorker.Start(wm.ctx) + wm.workers[ruleID] = newWorker + } +} + +func (wm *WorkerManager) sendMessage(msg *messaging.Message, rule Rule) bool { + wm.mu.RLock() + worker, ok := wm.workers[rule.ID] + wm.mu.RUnlock() + + if !ok || !worker.IsRunning() { + return false + } + + return worker.Send(WorkerMessage{ + Message: msg, + Rule: rule, + }) +} + +func (wm *WorkerManager) stopAll() error { + wm.mu.Lock() + defer wm.mu.Unlock() + + for _, worker := range wm.workers { + if err := worker.Stop(); err != nil { + return err + } + } + wm.workers = make(map[string]*RuleWorker) + return nil +} + +func (wm *WorkerManager) AddWorker(ctx context.Context, rule Rule) { + if atomic.LoadInt32(&wm.running) == 0 { + return + } + + cmd := WorkerManagerCommand{ + Type: CmdAdd, + Rule: rule, + } + + select { + case wm.commandCh <- cmd: + case <-ctx.Done(): + } +} + +func (wm *WorkerManager) RemoveWorker(ruleID string) { + if atomic.LoadInt32(&wm.running) == 0 { + return + } + + cmd := WorkerManagerCommand{ + Type: CmdRemove, + RuleID: ruleID, + } + + wm.commandCh <- cmd +} + +func (wm *WorkerManager) UpdateWorker(ctx context.Context, rule Rule) { + if atomic.LoadInt32(&wm.running) == 0 { + return + } + + cmd := WorkerManagerCommand{ + Type: CmdUpdate, + Rule: rule, + } + + select { + case wm.commandCh <- cmd: + case <-ctx.Done(): + } +} + +func (wm *WorkerManager) SendMessage(msg *messaging.Message, rule Rule) bool { + if atomic.LoadInt32(&wm.running) == 0 { + return false + } + + return wm.sendMessage(msg, rule) +} + +func (wm *WorkerManager) StopAll() error { + if !atomic.CompareAndSwapInt32(&wm.running, 1, 0) { + return nil + } + + // If context is already cancelled, manageWorkers is stopping, just wait + select { + case <-wm.ctx.Done(): + return wm.g.Wait() + default: + } + + responseCh := make(chan interface{}, 1) + cmd := WorkerManagerCommand{ + Type: CmdStopAll, + Response: responseCh, + } + + select { + case wm.commandCh <- cmd: + 
select { + case <-responseCh: + case <-wm.ctx.Done(): + case <-time.After(100 * time.Millisecond): + } + case <-time.After(100 * time.Millisecond): + } + + return wm.g.Wait() +} + +func (wm *WorkerManager) GetWorkerCount() int { + if atomic.LoadInt32(&wm.running) == 0 { + return 0 + } + + responseCh := make(chan interface{}, 1) + cmd := WorkerManagerCommand{ + Type: CmdCount, + Response: responseCh, + } + + select { + case wm.commandCh <- cmd: + if result := <-responseCh; result != nil { + if count, ok := result.(int); ok { + return count + } + } + default: + } + return 0 +} + +func (wm *WorkerManager) ListWorkers() []string { + if atomic.LoadInt32(&wm.running) == 0 { + return nil + } + + responseCh := make(chan interface{}, 1) + cmd := WorkerManagerCommand{ + Type: CmdList, + Response: responseCh, + } + + select { + case wm.commandCh <- cmd: + if result := <-responseCh; result != nil { + if list, ok := result.([]string); ok { + return list + } + } + default: + } + return nil +} + +func (wm *WorkerManager) ErrorChan() <-chan error { + return wm.errorCh +} + +func (wm *WorkerManager) RefreshWorkers(ctx context.Context, rules []Rule) { + if atomic.LoadInt32(&wm.running) == 0 { + return + } + + for _, rule := range rules { + if rule.Status == EnabledStatus { + wm.UpdateWorker(ctx, rule) + } else { + wm.RemoveWorker(rule.ID) + } + } +} + +func (wm *WorkerManager) AbortRule(ruleID string) { + if atomic.LoadInt32(&wm.running) == 0 { + return + } + + cmd := WorkerManagerCommand{ + Type: CmdAbort, + RuleID: ruleID, + } + + wm.commandCh <- cmd +} + +func (wm *WorkerManager) GetWorkerStatus(ruleID string) map[string]interface{} { + if atomic.LoadInt32(&wm.running) == 0 { + return nil + } + + responseCh := make(chan interface{}, 1) + cmd := WorkerManagerCommand{ + Type: CmdGetStatus, + RuleID: ruleID, + Response: responseCh, + } + + select { + case wm.commandCh <- cmd: + if result := <-responseCh; result != nil { + if status, ok := result.(map[string]interface{}); ok { + 
return status + } + } + default: + } + return nil +}