From b49f0c2e91d062db1faa9c7d84a80479ba48cab2 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 22 Sep 2025 14:39:45 +0300 Subject: [PATCH 01/25] initial implementation Signed-off-by: nyagamunene --- re/service.go | 55 ++++++++-- re/worker.go | 286 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 331 insertions(+), 10 deletions(-) create mode 100644 re/worker.go diff --git a/re/service.go b/re/service.go index 241fe6ea6..4261480b5 100644 --- a/re/service.go +++ b/re/service.go @@ -10,6 +10,7 @@ import ( grpcReadersV1 "github.com/absmach/magistrala/api/grpc/readers/v1" "github.com/absmach/magistrala/pkg/emailer" pkglog "github.com/absmach/magistrala/pkg/logger" + "github.com/absmach/magistrala/pkg/schedule" "github.com/absmach/magistrala/pkg/ticker" "github.com/absmach/supermq" "github.com/absmach/supermq/pkg/authn" @@ -19,19 +20,20 @@ import ( ) type re struct { - repo Repository - runInfo chan pkglog.RunInfo - idp supermq.IDProvider - rePubSub messaging.PubSub - writersPub messaging.Publisher - alarmsPub messaging.Publisher - ticker ticker.Ticker - email emailer.Emailer - readers grpcReadersV1.ReadersServiceClient + repo Repository + runInfo chan pkglog.RunInfo + idp supermq.IDProvider + rePubSub messaging.PubSub + writersPub messaging.Publisher + alarmsPub messaging.Publisher + ticker ticker.Ticker + email emailer.Emailer + readers grpcReadersV1.ReadersServiceClient + workerMgr *WorkerManager } func NewService(repo Repository, runInfo chan pkglog.RunInfo, idp supermq.IDProvider, rePubSub messaging.PubSub, writersPub, alarmsPub messaging.Publisher, tck ticker.Ticker, emailer emailer.Emailer, readers grpcReadersV1.ReadersServiceClient) Service { - return &re{ + reEngine := &re{ repo: repo, idp: idp, runInfo: runInfo, @@ -42,6 +44,8 @@ func NewService(repo Repository, runInfo chan pkglog.RunInfo, idp supermq.IDProv email: emailer, readers: readers, } + reEngine.workerMgr = NewWorkerManager(reEngine) + return reEngine } func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, error) { @@ -66,6 +70,10 @@ func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, return Rule{}, errors.Wrap(svcerr.ErrCreateEntity, err) } + if rule.Status == EnabledStatus && rule.Schedule.Recurring == schedule.None { + re.workerMgr.AddWorker(ctx, rule) + } + return rule, nil } @@ -86,6 +94,12 @@ func (re *re) UpdateRule(ctx context.Context, session authn.Session, r Rule) (Ru return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } + if rule.Schedule.Recurring == schedule.None { + re.workerMgr.UpdateWorker(ctx, rule) + } else { + re.workerMgr.RemoveWorker(rule.ID) + } + return rule, nil } @@ -108,6 +122,14 @@ func (re *re) UpdateRuleSchedule(ctx context.Context, session authn.Session, r R return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } + // Update worker based on schedule + if rule.Schedule.Recurring == schedule.None && rule.Status == EnabledStatus { + re.workerMgr.UpdateWorker(ctx, rule) + } else { + // Rule is scheduled or disabled, remove from workers + re.workerMgr.RemoveWorker(rule.ID) + } + return rule, nil } @@ -125,6 +147,9 @@ func (re *re) RemoveRule(ctx context.Context, session authn.Session, id string) return errors.Wrap(svcerr.ErrRemoveEntity, err) } + // Remove worker for the deleted rule + re.workerMgr.RemoveWorker(id) + return nil } @@ -143,6 +168,12 @@ func (re *re) EnableRule(ctx context.Context, session authn.Session, id string) if err != nil { return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } + + // Add worker for enabled rule if it's not scheduled + if rule.Schedule.Recurring == schedule.None { + re.workerMgr.AddWorker(ctx, rule) + } + return rule, nil } @@ -161,6 +192,10 @@ func (re *re) DisableRule(ctx context.Context, session authn.Session, id string) if err != nil { return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } + + // Remove worker for disabled rule + re.workerMgr.RemoveWorker(id) + return rule, nil } diff --git a/re/worker.go b/re/worker.go new file mode 100644 index 000000000..750998bb8 --- /dev/null +++ b/re/worker.go @@ -0,0 +1,286 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re + +import ( + "context" + "sync" + + "github.com/absmach/supermq/pkg/messaging" +) + +// WorkerMessage represents a message to be processed by a rule worker. +type WorkerMessage struct { + Message *messaging.Message + Rule Rule +} + +// RuleWorker manages execution of a single rule in its own goroutine. +type RuleWorker struct { + rule Rule + engine *re + msgChan chan WorkerMessage + stopChan chan struct{} + doneChan chan struct{} + running bool + mu sync.RWMutex +} + +// NewRuleWorker creates a new rule worker for the given rule. +func NewRuleWorker(rule Rule, engine *re) *RuleWorker { + return &RuleWorker{ + rule: rule, + engine: engine, + msgChan: make(chan WorkerMessage, 100), // Buffer to prevent blocking + stopChan: make(chan struct{}), + doneChan: make(chan struct{}), + running: false, + } +} + +// Start begins the worker goroutine for processing messages. +func (w *RuleWorker) Start(ctx context.Context) { + w.mu.Lock() + if w.running { + w.mu.Unlock() + return + } + w.running = true + w.mu.Unlock() + + go w.run(ctx) +} + +// Stop stops the worker goroutine and waits for it to finish. +func (w *RuleWorker) Stop() { + w.mu.Lock() + if !w.running { + w.mu.Unlock() + return + } + w.mu.Unlock() + + close(w.stopChan) + <-w.doneChan +} + +// Send sends a message to the worker for processing. +func (w *RuleWorker) Send(msg WorkerMessage) bool { + w.mu.RLock() + running := w.running + w.mu.RUnlock() + + if !running { + return false + } + + select { + case w.msgChan <- msg: + return true + default: + return false + } +} + +// IsRunning returns true if the worker is currently running. +func (w *RuleWorker) IsRunning() bool { + w.mu.RLock() + defer w.mu.RUnlock() + return w.running +} + +// UpdateRule updates the rule configuration for this worker. +func (w *RuleWorker) UpdateRule(rule Rule) { + w.mu.Lock() + w.rule = rule + w.mu.Unlock() +} + +// GetRule returns the current rule configuration. +func (w *RuleWorker) GetRule() Rule { + w.mu.RLock() + defer w.mu.RUnlock() + return w.rule +} + +// run is the main worker loop that processes messages. +func (w *RuleWorker) run(ctx context.Context) { + defer func() { + w.mu.Lock() + w.running = false + w.mu.Unlock() + close(w.doneChan) + }() + + for { + select { + case <-ctx.Done(): + return + case <-w.stopChan: + return + case workerMsg := <-w.msgChan: + w.processMessage(ctx, workerMsg) + } + } +} + +// processMessage processes a single message using the rule logic. +func (w *RuleWorker) processMessage(ctx context.Context, workerMsg WorkerMessage) { + currentRule := w.GetRule() + + if currentRule.Status != EnabledStatus { + return + } + + runInfo := w.engine.process(ctx, currentRule, workerMsg.Message) + + // Send run info to the logging channel + select { + case w.engine.runInfo <- runInfo: + default: + } +} + +// WorkerManager manages all rule workers. +type WorkerManager struct { + workers map[string]*RuleWorker + engine *re + mu sync.RWMutex +} + +// NewWorkerManager creates a new worker manager. +func NewWorkerManager(engine *re) *WorkerManager { + return &WorkerManager{ + workers: make(map[string]*RuleWorker), + engine: engine, + } +} + +// AddWorker adds a new worker for the given rule. +func (wm *WorkerManager) AddWorker(ctx context.Context, rule Rule) { + wm.mu.Lock() + defer wm.mu.Unlock() + + if existing, ok := wm.workers[rule.ID]; ok { + existing.Stop() + } + + if rule.Status != EnabledStatus { + delete(wm.workers, rule.ID) + return + } + + worker := NewRuleWorker(rule, wm.engine) + worker.Start(ctx) + wm.workers[rule.ID] = worker +} + +// RemoveWorker removes and stops the worker for the given rule ID. +func (wm *WorkerManager) RemoveWorker(ruleID string) { + wm.mu.Lock() + defer wm.mu.Unlock() + + if worker, ok := wm.workers[ruleID]; ok { + worker.Stop() + delete(wm.workers, ruleID) + } +} + +// UpdateWorker updates the rule configuration for an existing worker. +func (wm *WorkerManager) UpdateWorker(ctx context.Context, rule Rule) { + wm.mu.Lock() + defer wm.mu.Unlock() + + if rule.Status != EnabledStatus { + if worker, ok := wm.workers[rule.ID]; ok { + worker.Stop() + delete(wm.workers, rule.ID) + } + return + } + + if worker, ok := wm.workers[rule.ID]; ok { + worker.UpdateRule(rule) + } else { + worker := NewRuleWorker(rule, wm.engine) + worker.Start(ctx) + wm.workers[rule.ID] = worker + } +} + +// SendMessage sends a message to the appropriate worker for processing. +func (wm *WorkerManager) SendMessage(msg *messaging.Message, rule Rule) bool { + wm.mu.RLock() + worker, ok := wm.workers[rule.ID] + wm.mu.RUnlock() + + if !ok || !worker.IsRunning() { + return false + } + + return worker.Send(WorkerMessage{ + Message: msg, + Rule: rule, + }) +} + +// StopAll stops all workers. +func (wm *WorkerManager) StopAll() { + wm.mu.Lock() + defer wm.mu.Unlock() + + for _, worker := range wm.workers { + worker.Stop() + } + wm.workers = make(map[string]*RuleWorker) +} + +// GetWorkerCount returns the number of active workers. +func (wm *WorkerManager) GetWorkerCount() int { + wm.mu.RLock() + defer wm.mu.RUnlock() + return len(wm.workers) +} + +// ListWorkers returns a slice of rule IDs that have active workers. +func (wm *WorkerManager) ListWorkers() []string { + wm.mu.RLock() + defer wm.mu.RUnlock() + + ruleIDs := make([]string, 0, len(wm.workers)) + for ruleID := range wm.workers { + ruleIDs = append(ruleIDs, ruleID) + } + return ruleIDs +} + +// RefreshWorkers synchronizes workers with the current set of enabled rules. +func (wm *WorkerManager) RefreshWorkers(ctx context.Context, rules []Rule) { + wm.mu.Lock() + defer wm.mu.Unlock() + + currentRules := make(map[string]Rule) + for _, rule := range rules { + if rule.Status == EnabledStatus { + currentRules[rule.ID] = rule + } + } + + for ruleID, worker := range wm.workers { + if _, exists := currentRules[ruleID]; !exists { + worker.Stop() + delete(wm.workers, ruleID) + } + } + + for ruleID, rule := range currentRules { + if worker, exists := wm.workers[ruleID]; exists { + worker.UpdateRule(rule) + } else { + worker := NewRuleWorker(rule, wm.engine) + worker.Start(ctx) + wm.workers[ruleID] = worker + } + } +} From e6f81aca56a0c56fd80d070db2d7493fc4089177 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 22 Sep 2025 18:01:10 +0300 Subject: [PATCH 02/25] initial implementation Signed-off-by: nyagamunene --- re/handlers.go | 55 ++++++- re/service.go | 17 +- re/worker.go | 416 +++++++++++++++++++++++++++++++++++-------------- 3 files changed, 359 insertions(+), 129 deletions(-) diff --git a/re/handlers.go b/re/handlers.go index ae4274da6..2c093873c 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -33,6 +33,12 @@ func (re *re) Handle(msg *messaging.Message) error { if n := len(msg.Payload); n > maxPayload { return errors.New(pldExceededFmt + strconv.Itoa(n)) } + + // If WorkerManager is not initialized yet, fall back to old behavior + if re.workerMgr == nil { + return re.handleWithoutWorkers(msg) + } + // Skip filtering by message topic and fetch all non-scheduled rules instead. // It's cleaner and more efficient to match wildcards in Go, but we can // revisit this if it ever becomes a performance bottleneck. @@ -50,9 +56,38 @@ func (re *re) Handle(msg *messaging.Message) error { for _, r := range page.Rules { if matchSubject(msg.Subtopic, r.InputTopic) { - go func(ctx context.Context) { - re.runInfo <- re.process(ctx, r, msg) - }(ctx) + // Send message to the appropriate worker + if !re.workerMgr.SendMessage(msg, r) { + // Worker not found or busy, fall back to direct processing + go func(ctx context.Context, rule Rule) { + re.runInfo <- re.process(ctx, rule, msg) + }(ctx, r) + } + } + } + + return nil +} + +// handleWithoutWorkers processes messages using the legacy direct goroutine approach +func (re *re) handleWithoutWorkers(msg *messaging.Message) error { + pm := PageMeta{ + Domain: msg.Domain, + InputChannel: msg.Channel, + Status: EnabledStatus, + Scheduled: &scheduledFalse, + } + ctx := context.Background() + page, err := re.repo.ListRules(ctx, pm) + if err != nil { + return err + } + + for _, r := range page.Rules { + if matchSubject(msg.Subtopic, r.InputTopic) { + go func(ctx context.Context, rule Rule) { + re.runInfo <- re.process(ctx, rule, msg) + }(ctx, r) } } @@ -117,7 +152,21 @@ func (re *re) handleOutput(ctx context.Context, o Runnable, r Rule, msg *messagi } func (re *re) StartScheduler(ctx context.Context) error { + // Initialize WorkerManager with context + re.workerMgr = NewWorkerManager(re, ctx) + + // Load and start workers for existing enabled rules + pm := PageMeta{ + Status: EnabledStatus, + } + page, err := re.repo.ListRules(ctx, pm) + if err == nil { + re.workerMgr.RefreshWorkers(ctx, page.Rules) + } + defer re.ticker.Stop() + defer re.workerMgr.StopAll() + for { select { case <-ctx.Done(): diff --git a/re/service.go b/re/service.go index 4261480b5..604baccea 100644 --- a/re/service.go +++ b/re/service.go @@ -44,7 +44,6 @@ func NewService(repo Repository, runInfo chan pkglog.RunInfo, idp supermq.IDProv email: emailer, readers: readers, } - reEngine.workerMgr = NewWorkerManager(reEngine) return reEngine } @@ -70,7 +69,7 @@ func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, return Rule{}, errors.Wrap(svcerr.ErrCreateEntity, err) } - if rule.Status == EnabledStatus && rule.Schedule.Recurring == schedule.None { + if rule.Status == EnabledStatus && rule.Schedule.Recurring == schedule.None && re.workerMgr != nil { re.workerMgr.AddWorker(ctx, rule) } @@ -94,10 +93,12 @@ func (re *re) UpdateRule(ctx context.Context, session authn.Session, r Rule) (Ru return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } - if rule.Schedule.Recurring == schedule.None { - re.workerMgr.UpdateWorker(ctx, rule) - } else { - re.workerMgr.RemoveWorker(rule.ID) + if re.workerMgr != nil { + if rule.Schedule.Recurring == schedule.None { + re.workerMgr.UpdateWorker(ctx, rule) + } else { + re.workerMgr.RemoveWorker(rule.ID) + } } return rule, nil @@ -122,11 +123,9 @@ func (re *re) UpdateRuleSchedule(ctx context.Context, session authn.Session, r R return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } - // Update worker based on schedule if rule.Schedule.Recurring == schedule.None && rule.Status == EnabledStatus { re.workerMgr.UpdateWorker(ctx, rule) } else { - // Rule is scheduled or disabled, remove from workers re.workerMgr.RemoveWorker(rule.ID) } @@ -200,5 +199,7 @@ func (re *re) DisableRule(ctx context.Context, session authn.Session, id string) } func (re *re) Cancel() error { + // Stop all workers when the service is cancelled + re.workerMgr.StopAll() return nil } diff --git a/re/worker.go b/re/worker.go index 750998bb8..6c2b9bcd8 100644 --- a/re/worker.go +++ b/re/worker.go @@ -5,9 +5,10 @@ package re import ( "context" - "sync" + "sync/atomic" "github.com/absmach/supermq/pkg/messaging" + "golang.org/x/sync/errgroup" ) // WorkerMessage represents a message to be processed by a rule worker. @@ -18,60 +19,55 @@ type WorkerMessage struct { // RuleWorker manages execution of a single rule in its own goroutine. type RuleWorker struct { - rule Rule - engine *re - msgChan chan WorkerMessage - stopChan chan struct{} - doneChan chan struct{} - running bool - mu sync.RWMutex + rule Rule + engine *re + msgChan chan WorkerMessage + updateChan chan Rule + ctx context.Context + cancel context.CancelFunc + g *errgroup.Group + running int32 } // NewRuleWorker creates a new rule worker for the given rule. func NewRuleWorker(rule Rule, engine *re) *RuleWorker { return &RuleWorker{ - rule: rule, - engine: engine, - msgChan: make(chan WorkerMessage, 100), // Buffer to prevent blocking - stopChan: make(chan struct{}), - doneChan: make(chan struct{}), - running: false, + rule: rule, + engine: engine, + msgChan: make(chan WorkerMessage, 100), + updateChan: make(chan Rule, 1), + running: 0, // 0 = not running, 1 = running } } // Start begins the worker goroutine for processing messages. func (w *RuleWorker) Start(ctx context.Context) { - w.mu.Lock() - if w.running { - w.mu.Unlock() + if !atomic.CompareAndSwapInt32(&w.running, 0, 1) { return } - w.running = true - w.mu.Unlock() - go w.run(ctx) + w.ctx, w.cancel = context.WithCancel(ctx) + w.g, w.ctx = errgroup.WithContext(w.ctx) + + w.g.Go(func() error { + return w.run(w.ctx) + }) } // Stop stops the worker goroutine and waits for it to finish. -func (w *RuleWorker) Stop() { - w.mu.Lock() - if !w.running { - w.mu.Unlock() - return +func (w *RuleWorker) Stop() error { + if !atomic.CompareAndSwapInt32(&w.running, 1, 0) { + return nil } - w.mu.Unlock() - close(w.stopChan) - <-w.doneChan + w.cancel() + + return w.g.Wait() } // Send sends a message to the worker for processing. func (w *RuleWorker) Send(msg WorkerMessage) bool { - w.mu.RLock() - running := w.running - w.mu.RUnlock() - - if !running { + if atomic.LoadInt32(&w.running) == 0 { return false } @@ -85,40 +81,42 @@ func (w *RuleWorker) Send(msg WorkerMessage) bool { // IsRunning returns true if the worker is currently running. func (w *RuleWorker) IsRunning() bool { - w.mu.RLock() - defer w.mu.RUnlock() - return w.running + return atomic.LoadInt32(&w.running) == 1 } // UpdateRule updates the rule configuration for this worker. func (w *RuleWorker) UpdateRule(rule Rule) { - w.mu.Lock() - w.rule = rule - w.mu.Unlock() + select { + case w.updateChan <- rule: + default: + // If channel is full, just overwrite the current rule + // This ensures we always have the latest rule + select { + case <-w.updateChan: // drain the channel + default: + } + w.updateChan <- rule + } } // GetRule returns the current rule configuration. func (w *RuleWorker) GetRule() Rule { - w.mu.RLock() - defer w.mu.RUnlock() - return w.rule + return w.rule // Since rule updates happen via channels in the worker loop, this is safe } // run is the main worker loop that processes messages. -func (w *RuleWorker) run(ctx context.Context) { +func (w *RuleWorker) run(ctx context.Context) error { defer func() { - w.mu.Lock() - w.running = false - w.mu.Unlock() - close(w.doneChan) + atomic.StoreInt32(&w.running, 0) }() for { select { case <-ctx.Done(): - return - case <-w.stopChan: - return + return ctx.Err() + case rule := <-w.updateChan: + // Update the rule configuration + w.rule = rule case workerMsg := <-w.msgChan: w.processMessage(ctx, workerMsg) } @@ -128,40 +126,115 @@ func (w *RuleWorker) run(ctx context.Context) { // processMessage processes a single message using the rule logic. func (w *RuleWorker) processMessage(ctx context.Context, workerMsg WorkerMessage) { currentRule := w.GetRule() - + if currentRule.Status != EnabledStatus { return } runInfo := w.engine.process(ctx, currentRule, workerMsg.Message) - - // Send run info to the logging channel + select { case w.engine.runInfo <- runInfo: default: } } -// WorkerManager manages all rule workers. +// WorkerManagerCommand represents commands for worker management +type WorkerManagerCommand struct { + Type string // "add", "remove", "update", "send", "stop_all" + Rule Rule + RuleID string + Message *messaging.Message + Response chan interface{} // For responses (e.g., SendMessage result) +} + +// WorkerManager manages all rule workers using channels instead of mutex type WorkerManager struct { - workers map[string]*RuleWorker - engine *re - mu sync.RWMutex + workers map[string]*RuleWorker + engine *re + g *errgroup.Group + ctx context.Context + commandCh chan WorkerManagerCommand + running int32 } // NewWorkerManager creates a new worker manager. -func NewWorkerManager(engine *re) *WorkerManager { - return &WorkerManager{ - workers: make(map[string]*RuleWorker), - engine: engine, +func NewWorkerManager(engine *re, ctx context.Context) *WorkerManager { + g, ctx := errgroup.WithContext(ctx) + wm := &WorkerManager{ + workers: make(map[string]*RuleWorker), + engine: engine, + g: g, + ctx: ctx, + commandCh: make(chan WorkerManagerCommand, 100), + running: 0, } + + // Start the worker manager goroutine + wm.g.Go(func() error { + return wm.manageWorkers(ctx) + }) + + atomic.StoreInt32(&wm.running, 1) + return wm } -// AddWorker adds a new worker for the given rule. -func (wm *WorkerManager) AddWorker(ctx context.Context, rule Rule) { - wm.mu.Lock() - defer wm.mu.Unlock() +// manageWorkers is the main loop that handles all worker management operations +func (wm *WorkerManager) manageWorkers(ctx context.Context) error { + defer atomic.StoreInt32(&wm.running, 0) + + for { + select { + case <-ctx.Done(): + // Stop all workers before exiting + for _, worker := range wm.workers { + worker.Stop() + } + wm.workers = make(map[string]*RuleWorker) + return ctx.Err() + + case cmd := <-wm.commandCh: + wm.handleCommand(cmd) + } + } +} + +// handleCommand processes worker management commands +func (wm *WorkerManager) handleCommand(cmd WorkerManagerCommand) { + switch cmd.Type { + case "add": + wm.addWorkerUnsafe(cmd.Rule) + case "remove": + wm.removeWorkerUnsafe(cmd.RuleID) + case "update": + wm.updateWorkerUnsafe(cmd.Rule) + case "send": + result := wm.sendMessageUnsafe(cmd.Message, cmd.Rule) + if cmd.Response != nil { + cmd.Response <- result + } + case "stop_all": + wm.stopAllUnsafe() + if cmd.Response != nil { + cmd.Response <- true + } + case "count": + if cmd.Response != nil { + cmd.Response <- len(wm.workers) + } + case "list": + if cmd.Response != nil { + ruleIDs := make([]string, 0, len(wm.workers)) + for ruleID := range wm.workers { + ruleIDs = append(ruleIDs, ruleID) + } + cmd.Response <- ruleIDs + } + } +} +// addWorkerUnsafe adds a worker without locking (called from manager goroutine) +func (wm *WorkerManager) addWorkerUnsafe(rule Rule) { if existing, ok := wm.workers[rule.ID]; ok { existing.Stop() } @@ -172,26 +245,20 @@ func (wm *WorkerManager) AddWorker(ctx context.Context, rule Rule) { } worker := NewRuleWorker(rule, wm.engine) - worker.Start(ctx) + worker.Start(wm.ctx) wm.workers[rule.ID] = worker } -// RemoveWorker removes and stops the worker for the given rule ID. -func (wm *WorkerManager) RemoveWorker(ruleID string) { - wm.mu.Lock() - defer wm.mu.Unlock() - +// removeWorkerUnsafe removes a worker without locking (called from manager goroutine) +func (wm *WorkerManager) removeWorkerUnsafe(ruleID string) { if worker, ok := wm.workers[ruleID]; ok { worker.Stop() delete(wm.workers, ruleID) } } -// UpdateWorker updates the rule configuration for an existing worker. -func (wm *WorkerManager) UpdateWorker(ctx context.Context, rule Rule) { - wm.mu.Lock() - defer wm.mu.Unlock() - +// updateWorkerUnsafe updates a worker without locking (called from manager goroutine) +func (wm *WorkerManager) updateWorkerUnsafe(rule Rule) { if rule.Status != EnabledStatus { if worker, ok := wm.workers[rule.ID]; ok { worker.Stop() @@ -204,17 +271,14 @@ func (wm *WorkerManager) UpdateWorker(ctx context.Context, rule Rule) { worker.UpdateRule(rule) } else { worker := NewRuleWorker(rule, wm.engine) - worker.Start(ctx) + worker.Start(wm.ctx) wm.workers[rule.ID] = worker } } -// SendMessage sends a message to the appropriate worker for processing. -func (wm *WorkerManager) SendMessage(msg *messaging.Message, rule Rule) bool { - wm.mu.RLock() +// sendMessageUnsafe sends a message to a worker without locking (called from manager goroutine) +func (wm *WorkerManager) sendMessageUnsafe(msg *messaging.Message, rule Rule) bool { worker, ok := wm.workers[rule.ID] - wm.mu.RUnlock() - if !ok || !worker.IsRunning() { return false } @@ -225,62 +289,178 @@ func (wm *WorkerManager) SendMessage(msg *messaging.Message, rule Rule) bool { }) } -// StopAll stops all workers. -func (wm *WorkerManager) StopAll() { - wm.mu.Lock() - defer wm.mu.Unlock() - +// stopAllUnsafe stops all workers without locking (called from manager goroutine) +func (wm *WorkerManager) stopAllUnsafe() { for _, worker := range wm.workers { worker.Stop() } wm.workers = make(map[string]*RuleWorker) } -// GetWorkerCount returns the number of active workers. -func (wm *WorkerManager) GetWorkerCount() int { - wm.mu.RLock() - defer wm.mu.RUnlock() - return len(wm.workers) +// AddWorker adds a new worker for the given rule. +func (wm *WorkerManager) AddWorker(ctx context.Context, rule Rule) { + if atomic.LoadInt32(&wm.running) == 0 { + return + } + + cmd := WorkerManagerCommand{ + Type: "add", + Rule: rule, + } + + select { + case wm.commandCh <- cmd: + case <-ctx.Done(): + } } -// ListWorkers returns a slice of rule IDs that have active workers. -func (wm *WorkerManager) ListWorkers() []string { - wm.mu.RLock() - defer wm.mu.RUnlock() +// RemoveWorker removes and stops the worker for the given rule ID. +func (wm *WorkerManager) RemoveWorker(ruleID string) { + if atomic.LoadInt32(&wm.running) == 0 { + return + } + + cmd := WorkerManagerCommand{ + Type: "remove", + RuleID: ruleID, + } + + select { + case wm.commandCh <- cmd: + default: + // Non-blocking, if channel is full, skip + } +} - ruleIDs := make([]string, 0, len(wm.workers)) - for ruleID := range wm.workers { - ruleIDs = append(ruleIDs, ruleID) +// UpdateWorker updates the rule configuration for an existing worker. +func (wm *WorkerManager) UpdateWorker(ctx context.Context, rule Rule) { + if atomic.LoadInt32(&wm.running) == 0 { + return + } + + cmd := WorkerManagerCommand{ + Type: "update", + Rule: rule, + } + + select { + case wm.commandCh <- cmd: + case <-ctx.Done(): } - return ruleIDs } -// RefreshWorkers synchronizes workers with the current set of enabled rules. -func (wm *WorkerManager) RefreshWorkers(ctx context.Context, rules []Rule) { - wm.mu.Lock() - defer wm.mu.Unlock() +// SendMessage sends a message to the appropriate worker for processing. +func (wm *WorkerManager) SendMessage(msg *messaging.Message, rule Rule) bool { + if atomic.LoadInt32(&wm.running) == 0 { + return false + } + + responseCh := make(chan interface{}, 1) + cmd := WorkerManagerCommand{ + Type: "send", + Rule: rule, + Message: msg, + Response: responseCh, + } + + select { + case wm.commandCh <- cmd: + select { + case result := <-responseCh: + if b, ok := result.(bool); ok { + return b + } + return false + case <-wm.ctx.Done(): + return false + } + default: + return false + } +} - currentRules := make(map[string]Rule) - for _, rule := range rules { - if rule.Status == EnabledStatus { - currentRules[rule.ID] = rule +// StopAll stops all workers and waits for them to finish. +func (wm *WorkerManager) StopAll() error { + if !atomic.CompareAndSwapInt32(&wm.running, 1, 0) { + return nil + } + + responseCh := make(chan interface{}, 1) + cmd := WorkerManagerCommand{ + Type: "stop_all", + Response: responseCh, + } + + select { + case wm.commandCh <- cmd: + <-responseCh // Wait for completion + default: + // Channel full, force stop + } + + // Wait for all workers to finish + return wm.g.Wait() +}// GetWorkerCount returns the number of active workers. +func (wm *WorkerManager) GetWorkerCount() int { + if atomic.LoadInt32(&wm.running) == 0 { + return 0 + } + + responseCh := make(chan interface{}, 1) + cmd := WorkerManagerCommand{ + Type: "count", + Response: responseCh, + } + + select { + case wm.commandCh <- cmd: + if result := <-responseCh; result != nil { + if count, ok := result.(int); ok { + return count + } } + default: } + return 0 +} - for ruleID, worker := range wm.workers { - if _, exists := currentRules[ruleID]; !exists { - worker.Stop() - delete(wm.workers, ruleID) +// ListWorkers returns a slice of rule IDs that have active workers. +func (wm *WorkerManager) ListWorkers() []string { + if atomic.LoadInt32(&wm.running) == 0 { + return nil + } + + responseCh := make(chan interface{}, 1) + cmd := WorkerManagerCommand{ + Type: "list", + Response: responseCh, + } + + select { + case wm.commandCh <- cmd: + if result := <-responseCh; result != nil { + if list, ok := result.([]string); ok { + return list + } } + default: } + return nil +} - for ruleID, rule := range currentRules { - if worker, exists := wm.workers[ruleID]; exists { - worker.UpdateRule(rule) +// RefreshWorkers synchronizes workers with the current set of enabled rules. +func (wm *WorkerManager) RefreshWorkers(ctx context.Context, rules []Rule) { + if atomic.LoadInt32(&wm.running) == 0 { + return + } + + // For simplicity, let's process refresh by individual add/update/remove commands + // First get current workers, then sync + for _, rule := range rules { + if rule.Status == EnabledStatus { + wm.UpdateWorker(ctx, rule) } else { - worker := NewRuleWorker(rule, wm.engine) - worker.Start(ctx) - wm.workers[ruleID] = worker + wm.RemoveWorker(rule.ID) } } } From eb823db4eb9ff586b79e53bb27b6932e84b43e2b Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 22 Sep 2025 18:35:43 +0300 Subject: [PATCH 03/25] fix failing linter Signed-off-by: nyagamunene --- re/handlers.go | 23 +++++----- re/service.go | 48 +++++++++++---------- re/worker.go | 114 +++++++++++++++++++++++++++++++------------------ 3 files changed, 112 insertions(+), 73 deletions(-) diff --git a/re/handlers.go b/re/handlers.go index 2c093873c..bcafcc8d8 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -33,12 +33,11 @@ func (re *re) Handle(msg *messaging.Message) error { if n := len(msg.Payload); n > maxPayload { return errors.New(pldExceededFmt + strconv.Itoa(n)) } - - // If WorkerManager is not initialized yet, fall back to old behavior + if re.workerMgr == nil { return re.handleWithoutWorkers(msg) } - + // Skip filtering by message topic and fetch all non-scheduled rules instead. // It's cleaner and more efficient to match wildcards in Go, but we can // revisit this if it ever becomes a performance bottleneck. @@ -56,9 +55,7 @@ func (re *re) Handle(msg *messaging.Message) error { for _, r := range page.Rules { if matchSubject(msg.Subtopic, r.InputTopic) { - // Send message to the appropriate worker if !re.workerMgr.SendMessage(msg, r) { - // Worker not found or busy, fall back to direct processing go func(ctx context.Context, rule Rule) { re.runInfo <- re.process(ctx, rule, msg) }(ctx, r) @@ -152,10 +149,8 @@ func (re *re) handleOutput(ctx context.Context, o Runnable, r Rule, msg *messagi } func (re *re) StartScheduler(ctx context.Context) error { - // Initialize WorkerManager with context re.workerMgr = NewWorkerManager(re, ctx) - - // Load and start workers for existing enabled rules + pm := PageMeta{ Status: EnabledStatus, } @@ -165,8 +160,16 @@ func (re *re) StartScheduler(ctx context.Context) error { } defer re.ticker.Stop() - defer re.workerMgr.StopAll() - + defer func() { + if stopErr := re.workerMgr.StopAll(); stopErr != nil { + re.runInfo <- pkglog.RunInfo{ + Level: slog.LevelError, + Message: fmt.Sprintf("failed to stop worker manager: %s", stopErr), + Details: []slog.Attr{slog.Time("time", time.Now().UTC())}, + } + } + }() + for { select { case <-ctx.Done(): diff --git a/re/service.go b/re/service.go index 604baccea..89fb3aa24 100644 --- a/re/service.go +++ b/re/service.go @@ -20,16 +20,16 @@ import ( ) type re struct { - repo Repository - runInfo chan pkglog.RunInfo - idp supermq.IDProvider - rePubSub messaging.PubSub - writersPub messaging.Publisher - alarmsPub messaging.Publisher - ticker ticker.Ticker - email emailer.Emailer - readers grpcReadersV1.ReadersServiceClient - workerMgr *WorkerManager + repo Repository + runInfo chan pkglog.RunInfo + idp supermq.IDProvider + rePubSub messaging.PubSub + writersPub messaging.Publisher + alarmsPub messaging.Publisher + ticker ticker.Ticker + email emailer.Emailer + readers grpcReadersV1.ReadersServiceClient + workerMgr *WorkerManager } func NewService(repo Repository, runInfo chan pkglog.RunInfo, idp supermq.IDProvider, rePubSub messaging.PubSub, writersPub, alarmsPub messaging.Publisher, tck ticker.Ticker, emailer emailer.Emailer, readers grpcReadersV1.ReadersServiceClient) Service { @@ -123,10 +123,12 @@ func (re *re) UpdateRuleSchedule(ctx context.Context, session authn.Session, r R return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } - if rule.Schedule.Recurring == schedule.None && rule.Status == EnabledStatus { - re.workerMgr.UpdateWorker(ctx, rule) - } else { - re.workerMgr.RemoveWorker(rule.ID) + if re.workerMgr != nil { + if rule.Schedule.Recurring == schedule.None && rule.Status == EnabledStatus { + re.workerMgr.UpdateWorker(ctx, rule) + } else { + re.workerMgr.RemoveWorker(rule.ID) + } } return rule, nil @@ -146,8 +148,9 @@ func (re *re) RemoveRule(ctx context.Context, session authn.Session, id string) return errors.Wrap(svcerr.ErrRemoveEntity, err) } - // Remove worker for the deleted rule - re.workerMgr.RemoveWorker(id) + if re.workerMgr != nil { + re.workerMgr.RemoveWorker(id) + } return nil } @@ -168,8 +171,7 @@ func (re *re) EnableRule(ctx context.Context, session authn.Session, id string) return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } - // Add worker for enabled rule if it's not scheduled - if rule.Schedule.Recurring == schedule.None { + if re.workerMgr != nil && rule.Schedule.Recurring == schedule.None { re.workerMgr.AddWorker(ctx, rule) } @@ -192,14 +194,16 @@ func (re *re) DisableRule(ctx context.Context, session authn.Session, id string) return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } - // Remove worker for disabled rule - re.workerMgr.RemoveWorker(id) + if re.workerMgr != nil { + re.workerMgr.RemoveWorker(id) + } return rule, nil } func (re *re) Cancel() error { - // Stop all workers when the service is cancelled - re.workerMgr.StopAll() + if re.workerMgr != nil { + return re.workerMgr.StopAll() + } return nil } diff --git a/re/worker.go b/re/worker.go index 6c2b9bcd8..d47482e3a 100644 --- a/re/worker.go +++ b/re/worker.go @@ -89,10 +89,8 @@ func (w *RuleWorker) UpdateRule(rule Rule) { select { case w.updateChan <- rule: default: - // If channel is full, just overwrite the current rule - // This ensures we always have the latest rule select { - case <-w.updateChan: // drain the channel + case <-w.updateChan: default: } w.updateChan <- rule @@ -101,7 +99,7 @@ func (w *RuleWorker) UpdateRule(rule Rule) { // GetRule returns the current rule configuration. func (w *RuleWorker) GetRule() Rule { - return w.rule // Since rule updates happen via channels in the worker loop, this is safe + return w.rule } // run is the main worker loop that processes messages. @@ -115,7 +113,6 @@ func (w *RuleWorker) run(ctx context.Context) error { case <-ctx.Done(): return ctx.Err() case rule := <-w.updateChan: - // Update the rule configuration w.rule = rule case workerMsg := <-w.msgChan: w.processMessage(ctx, workerMsg) @@ -139,9 +136,44 @@ func (w *RuleWorker) processMessage(ctx context.Context, workerMsg WorkerMessage } } +// WorkerCommandType represents the type of worker management command +type WorkerCommandType uint8 + +const ( + CmdAdd WorkerCommandType = iota + CmdRemove + CmdUpdate + CmdSend + CmdStopAll + CmdCount + CmdList +) + +// String returns a string representation of the command type +func (c WorkerCommandType) String() string { + switch c { + case CmdAdd: + return "add" + case CmdRemove: + return "remove" + case CmdUpdate: + return "update" + case CmdSend: + return "send" + case CmdStopAll: + return "stop_all" + case CmdCount: + return "count" + case CmdList: + return "list" + default: + return "unknown" + } +} + // WorkerManagerCommand represents commands for worker management type WorkerManagerCommand struct { - Type string // "add", "remove", "update", "send", "stop_all" + Type WorkerCommandType Rule Rule RuleID string Message *messaging.Message @@ -169,12 +201,12 @@ func NewWorkerManager(engine *re, ctx context.Context) *WorkerManager { commandCh: make(chan WorkerManagerCommand, 100), running: 0, } - + // Start the worker manager goroutine wm.g.Go(func() error { return wm.manageWorkers(ctx) }) - + atomic.StoreInt32(&wm.running, 1) return wm } @@ -182,7 +214,7 @@ func NewWorkerManager(engine *re, ctx context.Context) *WorkerManager { // manageWorkers is the main loop that handles all worker management operations func (wm *WorkerManager) manageWorkers(ctx context.Context) error { defer atomic.StoreInt32(&wm.running, 0) - + for { select { case <-ctx.Done(): @@ -192,7 +224,7 @@ func (wm *WorkerManager) manageWorkers(ctx context.Context) error { } wm.workers = make(map[string]*RuleWorker) return ctx.Err() - + case cmd := <-wm.commandCh: wm.handleCommand(cmd) } @@ -202,27 +234,27 @@ func (wm *WorkerManager) manageWorkers(ctx context.Context) error { // handleCommand processes worker management commands func (wm *WorkerManager) handleCommand(cmd WorkerManagerCommand) { switch cmd.Type { - case "add": + case CmdAdd: wm.addWorkerUnsafe(cmd.Rule) - case "remove": + case CmdRemove: wm.removeWorkerUnsafe(cmd.RuleID) - case "update": + case CmdUpdate: wm.updateWorkerUnsafe(cmd.Rule) - case "send": + case CmdSend: result := wm.sendMessageUnsafe(cmd.Message, cmd.Rule) if cmd.Response != nil { cmd.Response <- result } - case "stop_all": + case CmdStopAll: wm.stopAllUnsafe() if cmd.Response != nil { cmd.Response <- true } - case "count": + case CmdCount: if cmd.Response != nil { cmd.Response <- len(wm.workers) } - case "list": + case CmdList: if cmd.Response != nil { ruleIDs := make([]string, 0, len(wm.workers)) for ruleID := range wm.workers { @@ -302,12 +334,12 @@ func (wm *WorkerManager) AddWorker(ctx context.Context, rule Rule) { if atomic.LoadInt32(&wm.running) == 0 { return } - + cmd := WorkerManagerCommand{ - Type: "add", + Type: CmdAdd, Rule: rule, } - + select { case wm.commandCh <- cmd: case <-ctx.Done(): @@ -319,12 +351,12 @@ func (wm *WorkerManager) RemoveWorker(ruleID string) { if atomic.LoadInt32(&wm.running) == 0 { return } - + cmd := WorkerManagerCommand{ - Type: "remove", + Type: CmdRemove, RuleID: ruleID, } - + select { case wm.commandCh <- cmd: default: @@ -337,12 +369,12 @@ func (wm *WorkerManager) UpdateWorker(ctx context.Context, rule Rule) { if atomic.LoadInt32(&wm.running) == 0 { return } - + cmd := WorkerManagerCommand{ - Type: "update", + Type: CmdUpdate, Rule: rule, } - + select { case wm.commandCh <- cmd: case <-ctx.Done(): @@ -354,15 +386,15 @@ func (wm *WorkerManager) SendMessage(msg *messaging.Message, rule Rule) bool { if atomic.LoadInt32(&wm.running) == 0 { return false } - + responseCh := make(chan interface{}, 1) cmd := WorkerManagerCommand{ - Type: "send", + Type: CmdSend, Rule: rule, Message: msg, Response: responseCh, } - + select { case wm.commandCh <- cmd: select { @@ -384,34 +416,34 @@ func (wm *WorkerManager) StopAll() error { if !atomic.CompareAndSwapInt32(&wm.running, 1, 0) { return nil } - + responseCh := make(chan interface{}, 1) cmd := WorkerManagerCommand{ - Type: "stop_all", + Type: CmdStopAll, Response: responseCh, } - + select { case wm.commandCh <- cmd: <-responseCh // Wait for completion default: // Channel full, force stop } - + // Wait for all workers to finish return wm.g.Wait() -}// GetWorkerCount returns the number of active workers. +} // GetWorkerCount returns the number of active workers. func (wm *WorkerManager) GetWorkerCount() int { if atomic.LoadInt32(&wm.running) == 0 { return 0 } - + responseCh := make(chan interface{}, 1) cmd := WorkerManagerCommand{ - Type: "count", + Type: CmdCount, Response: responseCh, } - + select { case wm.commandCh <- cmd: if result := <-responseCh; result != nil { @@ -429,13 +461,13 @@ func (wm *WorkerManager) ListWorkers() []string { if atomic.LoadInt32(&wm.running) == 0 { return nil } - + responseCh := make(chan interface{}, 1) cmd := WorkerManagerCommand{ - Type: "list", + Type: CmdList, Response: responseCh, } - + select { case wm.commandCh <- cmd: if result := <-responseCh; result != nil { @@ -453,7 +485,7 @@ func (wm *WorkerManager) RefreshWorkers(ctx context.Context, rules []Rule) { if atomic.LoadInt32(&wm.running) == 0 { return } - + // For simplicity, let's process refresh by individual add/update/remove commands // First get current workers, then sync for _, rule := range rules { From f2f647aabd5e08717644cd687ffb751d179b54fc Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Tue, 23 Sep 2025 15:25:03 +0300 Subject: [PATCH 04/25] update methods error handling Signed-off-by: nyagamunene --- docker/.env | 2 +- re/handlers.go | 58 ++++++++--------- re/service.go | 56 ++++++++++------- re/worker.go | 168 ++++++++++++++++++++++++++++++------------------- 4 files changed, 164 insertions(+), 120 deletions(-) diff --git a/docker/.env b/docker/.env index 693897bda..a8e6e7e1b 100644 --- a/docker/.env +++ b/docker/.env @@ -396,4 +396,4 @@ MG_RELEASE_TAG=latest SMQ_ALLOW_UNVERIFIED_USER=true # Set to yes to accept the EULA for the UI services. To view the EULA visit: https://github.com/absmach/eula -MG_UI_DOCKER_ACCEPT_EULA=no +MG_UI_DOCKER_ACCEPT_EULA=yes diff --git a/re/handlers.go b/re/handlers.go index bcafcc8d8..a8e8690a6 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -35,7 +35,7 @@ func (re *re) Handle(msg *messaging.Message) error { } if re.workerMgr == nil { - return re.handleWithoutWorkers(msg) + return errors.New("worker manager not initialized - scheduler must be started first") } // Skip filtering by message topic and fetch all non-scheduled rules instead. @@ -56,9 +56,16 @@ func (re *re) Handle(msg *messaging.Message) error { for _, r := range page.Rules { if matchSubject(msg.Subtopic, r.InputTopic) { if !re.workerMgr.SendMessage(msg, r) { - go func(ctx context.Context, rule Rule) { - re.runInfo <- re.process(ctx, rule, msg) - }(ctx, r) + re.runInfo <- pkglog.RunInfo{ + Level: slog.LevelWarn, + Message: fmt.Sprintf("failed to send message to worker for rule %s", r.ID), + Details: []slog.Attr{ + slog.String("rule_id", r.ID), + slog.String("rule_name", r.Name), + slog.String("channel", msg.Channel), + slog.Time("time", time.Now().UTC()), + }, + } } } } @@ -66,31 +73,6 @@ func (re *re) Handle(msg *messaging.Message) error { return nil } -// handleWithoutWorkers processes messages using the legacy direct goroutine approach -func (re *re) handleWithoutWorkers(msg *messaging.Message) error { - pm := PageMeta{ - Domain: msg.Domain, - InputChannel: msg.Channel, - Status: EnabledStatus, - Scheduled: &scheduledFalse, - } - ctx := context.Background() - page, err := re.repo.ListRules(ctx, pm) - if err != nil { - return err - } - - for _, r := range page.Rules { - if matchSubject(msg.Subtopic, r.InputTopic) { - go func(ctx context.Context, rule Rule) { - re.runInfo <- re.process(ctx, rule, msg) - }(ctx, r) - } - } - - return nil -} - // Match NATS subject to support wildcards. func matchSubject(published, subscribed string) bool { p := strings.Split(published, ".") @@ -151,6 +133,24 @@ func (re *re) handleOutput(ctx context.Context, o Runnable, r Rule, msg *messagi func (re *re) StartScheduler(ctx context.Context) error { re.workerMgr = NewWorkerManager(re, ctx) + // Start goroutine to monitor worker manager errors + go func() { + for { + select { + case <-ctx.Done(): + return + case err := <-re.workerMgr.ErrorChan(): + if err != nil { + re.runInfo <- pkglog.RunInfo{ + Level: slog.LevelError, + Message: fmt.Sprintf("worker management error: %s", err), + Details: []slog.Attr{slog.Time("time", time.Now().UTC())}, + } + } + } + } + }() + pm := PageMeta{ Status: EnabledStatus, } diff --git a/re/service.go b/re/service.go index 89fb3aa24..5f53e8b42 100644 --- a/re/service.go +++ b/re/service.go @@ -47,6 +47,25 @@ func NewService(repo Repository, runInfo chan pkglog.RunInfo, idp supermq.IDProv return reEngine } +func shouldCreateWorker(rule Rule) bool { + if rule.Status != EnabledStatus { + return false + } + + if rule.Schedule.Recurring == schedule.None { + return true + } + + now := time.Now().UTC() + dueTime := rule.Schedule.Time + + if dueTime.IsZero() || dueTime.Before(now) { + return true + } + + return dueTime.Sub(now) <= time.Hour +} + func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, error) { id, err := re.idp.ID() if err != nil { @@ -69,7 +88,7 @@ func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, return Rule{}, errors.Wrap(svcerr.ErrCreateEntity, err) } - if rule.Status == EnabledStatus && rule.Schedule.Recurring == schedule.None && re.workerMgr != nil { + if shouldCreateWorker(rule) { re.workerMgr.AddWorker(ctx, rule) } @@ -93,12 +112,10 @@ func (re *re) UpdateRule(ctx context.Context, session authn.Session, r Rule) (Ru return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } - if re.workerMgr != nil { - if rule.Schedule.Recurring == schedule.None { - re.workerMgr.UpdateWorker(ctx, rule) - } else { - re.workerMgr.RemoveWorker(rule.ID) - } + if shouldCreateWorker(rule) { + re.workerMgr.UpdateWorker(ctx, rule) + } else { + re.workerMgr.RemoveWorker(rule.ID) } return rule, nil @@ -123,12 +140,10 @@ func (re *re) UpdateRuleSchedule(ctx context.Context, session authn.Session, r R return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } - if re.workerMgr != nil { - if rule.Schedule.Recurring == schedule.None && rule.Status == EnabledStatus { - re.workerMgr.UpdateWorker(ctx, rule) - } else { - re.workerMgr.RemoveWorker(rule.ID) - } + if shouldCreateWorker(rule) { + re.workerMgr.UpdateWorker(ctx, rule) + } else { + re.workerMgr.RemoveWorker(rule.ID) } return rule, nil @@ -148,9 +163,7 @@ func (re *re) RemoveRule(ctx context.Context, session authn.Session, id string) return errors.Wrap(svcerr.ErrRemoveEntity, err) } - if re.workerMgr != nil { - re.workerMgr.RemoveWorker(id) - } + re.workerMgr.RemoveWorker(id) return nil } @@ -171,7 +184,7 @@ func (re *re) EnableRule(ctx context.Context, session authn.Session, id string) return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } - if re.workerMgr != nil && rule.Schedule.Recurring == schedule.None { + if shouldCreateWorker(rule) { re.workerMgr.AddWorker(ctx, rule) } @@ -194,16 +207,11 @@ func (re *re) DisableRule(ctx context.Context, session authn.Session, id string) return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } - if re.workerMgr != nil { - re.workerMgr.RemoveWorker(id) - } + re.workerMgr.RemoveWorker(id) return rule, nil } func (re *re) Cancel() error { - if re.workerMgr != nil { - return re.workerMgr.StopAll() - } - return nil + return re.workerMgr.StopAll() } diff --git a/re/worker.go b/re/worker.go index d47482e3a..0bd079eb2 100644 --- a/re/worker.go +++ b/re/worker.go @@ -5,6 +5,7 @@ package re import ( "context" + "sync" "sync/atomic" "github.com/absmach/supermq/pkg/messaging" @@ -136,7 +137,6 @@ func (w *RuleWorker) processMessage(ctx context.Context, workerMsg WorkerMessage } } -// WorkerCommandType represents the type of worker management command type WorkerCommandType uint8 const ( @@ -149,7 +149,6 @@ const ( CmdList ) -// String returns a string representation of the command type func (c WorkerCommandType) String() string { switch c { case CmdAdd: @@ -171,22 +170,24 @@ func (c WorkerCommandType) String() string { } } -// WorkerManagerCommand represents commands for worker management +// WorkerManagerCommand represents commands for worker management. type WorkerManagerCommand struct { Type WorkerCommandType Rule Rule RuleID string Message *messaging.Message - Response chan interface{} // For responses (e.g., SendMessage result) + Response chan interface{} } -// WorkerManager manages all rule workers using channels instead of mutex +// WorkerManager manages all rule workers. type WorkerManager struct { workers map[string]*RuleWorker engine *re g *errgroup.Group ctx context.Context commandCh chan WorkerManagerCommand + errorCh chan error + mu sync.RWMutex running int32 } @@ -199,6 +200,7 @@ func NewWorkerManager(engine *re, ctx context.Context) *WorkerManager { g: g, ctx: ctx, commandCh: make(chan WorkerManagerCommand, 100), + errorCh: make(chan error, 100), running: 0, } @@ -211,16 +213,19 @@ func NewWorkerManager(engine *re, ctx context.Context) *WorkerManager { return wm } -// manageWorkers is the main loop that handles all worker management operations func (wm *WorkerManager) manageWorkers(ctx context.Context) error { defer atomic.StoreInt32(&wm.running, 0) for { select { case <-ctx.Done(): - // Stop all workers before exiting for _, worker := range wm.workers { - worker.Stop() + if err := worker.Stop(); err != nil { + select { + case wm.errorCh <- err: + default: + } + } } wm.workers = make(map[string]*RuleWorker) return ctx.Err() @@ -231,86 +236,125 @@ func (wm *WorkerManager) manageWorkers(ctx context.Context) error { } } -// handleCommand processes worker management commands func (wm *WorkerManager) handleCommand(cmd WorkerManagerCommand) { switch cmd.Type { case CmdAdd: - wm.addWorkerUnsafe(cmd.Rule) + if err := wm.addWorker(cmd.Rule); err != nil { + select { + case wm.errorCh <- err: + default: + } + } case CmdRemove: - wm.removeWorkerUnsafe(cmd.RuleID) + if err := wm.removeWorker(cmd.RuleID); err != nil { + select { + case wm.errorCh <- err: + default: + } + } case CmdUpdate: - wm.updateWorkerUnsafe(cmd.Rule) + if err := wm.updateWorker(cmd.Rule); err != nil { + select { + case wm.errorCh <- err: + default: + } + } case CmdSend: - result := wm.sendMessageUnsafe(cmd.Message, cmd.Rule) + result := wm.sendMessage(cmd.Message, cmd.Rule) if cmd.Response != nil { cmd.Response <- result } case CmdStopAll: - wm.stopAllUnsafe() + if err := wm.stopAll(); err != nil { + select { + case wm.errorCh <- err: + default: + } + } if cmd.Response != nil { cmd.Response <- true } case CmdCount: + wm.mu.RLock() + count := len(wm.workers) + wm.mu.RUnlock() if cmd.Response != nil { - cmd.Response <- len(wm.workers) + cmd.Response <- count } case CmdList: + wm.mu.RLock() + ruleIDs := make([]string, 0, len(wm.workers)) + for ruleID := range wm.workers { + ruleIDs = append(ruleIDs, ruleID) + } + wm.mu.RUnlock() if cmd.Response != nil { - ruleIDs := make([]string, 0, len(wm.workers)) - for ruleID := range wm.workers { - ruleIDs = append(ruleIDs, ruleID) - } cmd.Response <- ruleIDs } } } -// addWorkerUnsafe adds a worker without locking (called from manager goroutine) -func (wm *WorkerManager) addWorkerUnsafe(rule Rule) { +func (wm *WorkerManager) addWorker(rule Rule) error { + wm.mu.Lock() + defer wm.mu.Unlock() + if existing, ok := wm.workers[rule.ID]; ok { - existing.Stop() + if err := existing.Stop(); err != nil { + return err + } } if rule.Status != EnabledStatus { delete(wm.workers, rule.ID) - return + return nil } worker := NewRuleWorker(rule, wm.engine) worker.Start(wm.ctx) wm.workers[rule.ID] = worker + return nil } -// removeWorkerUnsafe removes a worker without locking (called from manager goroutine) -func (wm *WorkerManager) removeWorkerUnsafe(ruleID string) { +func (wm *WorkerManager) removeWorker(ruleID string) error { + wm.mu.Lock() + defer wm.mu.Unlock() + if worker, ok := wm.workers[ruleID]; ok { - worker.Stop() + if err := worker.Stop(); err != nil { + return err + } delete(wm.workers, ruleID) } + return nil } -// updateWorkerUnsafe updates a worker without locking (called from manager goroutine) -func (wm *WorkerManager) updateWorkerUnsafe(rule Rule) { - if rule.Status != EnabledStatus { - if worker, ok := wm.workers[rule.ID]; ok { - worker.Stop() - delete(wm.workers, rule.ID) +func (wm *WorkerManager) updateWorker(rule Rule) error { + wm.mu.Lock() + defer wm.mu.Unlock() + + if worker, ok := wm.workers[rule.ID]; ok { + if err := worker.Stop(); err != nil { + return err } - return + delete(wm.workers, rule.ID) } - if worker, ok := wm.workers[rule.ID]; ok { - worker.UpdateRule(rule) - } else { - worker := NewRuleWorker(rule, wm.engine) - worker.Start(wm.ctx) - wm.workers[rule.ID] = worker + if rule.Status != EnabledStatus { + delete(wm.workers, rule.ID) + return nil } + + worker := NewRuleWorker(rule, wm.engine) + worker.Start(wm.ctx) + wm.workers[rule.ID] = worker + return nil } -// sendMessageUnsafe sends a message to a worker without locking (called from manager goroutine) -func (wm *WorkerManager) sendMessageUnsafe(msg *messaging.Message, rule Rule) bool { +func (wm *WorkerManager) sendMessage(msg *messaging.Message, rule Rule) bool { + wm.mu.RLock() worker, ok := wm.workers[rule.ID] + wm.mu.RUnlock() + if !ok || !worker.IsRunning() { return false } @@ -321,15 +365,19 @@ func (wm *WorkerManager) sendMessageUnsafe(msg *messaging.Message, rule Rule) bo }) } -// stopAllUnsafe stops all workers without locking (called from manager goroutine) -func (wm *WorkerManager) stopAllUnsafe() { +func (wm *WorkerManager) stopAll() error { + wm.mu.Lock() + defer wm.mu.Unlock() + for _, worker := range wm.workers { - worker.Stop() + if err := worker.Stop(); err != nil { + return err + } } wm.workers = make(map[string]*RuleWorker) + return nil } -// AddWorker adds a new worker for the given rule. func (wm *WorkerManager) AddWorker(ctx context.Context, rule Rule) { if atomic.LoadInt32(&wm.running) == 0 { return @@ -346,7 +394,6 @@ func (wm *WorkerManager) AddWorker(ctx context.Context, rule Rule) { } } -// RemoveWorker removes and stops the worker for the given rule ID. func (wm *WorkerManager) RemoveWorker(ruleID string) { if atomic.LoadInt32(&wm.running) == 0 { return @@ -357,14 +404,9 @@ func (wm *WorkerManager) RemoveWorker(ruleID string) { RuleID: ruleID, } - select { - case wm.commandCh <- cmd: - default: - // Non-blocking, if channel is full, skip - } + wm.commandCh <- cmd } -// UpdateWorker updates the rule configuration for an existing worker. func (wm *WorkerManager) UpdateWorker(ctx context.Context, rule Rule) { if atomic.LoadInt32(&wm.running) == 0 { return @@ -381,7 +423,6 @@ func (wm *WorkerManager) UpdateWorker(ctx context.Context, rule Rule) { } } -// SendMessage sends a message to the appropriate worker for processing. func (wm *WorkerManager) SendMessage(msg *messaging.Message, rule Rule) bool { if atomic.LoadInt32(&wm.running) == 0 { return false @@ -411,7 +452,6 @@ func (wm *WorkerManager) SendMessage(msg *messaging.Message, rule Rule) bool { } } -// StopAll stops all workers and waits for them to finish. func (wm *WorkerManager) StopAll() error { if !atomic.CompareAndSwapInt32(&wm.running, 1, 0) { return nil @@ -423,16 +463,12 @@ func (wm *WorkerManager) StopAll() error { Response: responseCh, } - select { - case wm.commandCh <- cmd: - <-responseCh // Wait for completion - default: - // Channel full, force stop - } + wm.commandCh <- cmd + <-responseCh - // Wait for all workers to finish return wm.g.Wait() -} // GetWorkerCount returns the number of active workers. +} + func (wm *WorkerManager) GetWorkerCount() int { if atomic.LoadInt32(&wm.running) == 0 { return 0 @@ -456,7 +492,6 @@ func (wm *WorkerManager) GetWorkerCount() int { return 0 } -// ListWorkers returns a slice of rule IDs that have active workers. func (wm *WorkerManager) ListWorkers() []string { if atomic.LoadInt32(&wm.running) == 0 { return nil @@ -480,14 +515,15 @@ func (wm *WorkerManager) ListWorkers() []string { return nil } -// RefreshWorkers synchronizes workers with the current set of enabled rules. +func (wm *WorkerManager) ErrorChan() <-chan error { + return wm.errorCh +} + func (wm *WorkerManager) RefreshWorkers(ctx context.Context, rules []Rule) { if atomic.LoadInt32(&wm.running) == 0 { return } - // For simplicity, let's process refresh by individual add/update/remove commands - // First get current workers, then sync for _, rule := range rules { if rule.Status == EnabledStatus { wm.UpdateWorker(ctx, rule) From a05bb74852fe330c82ecfe357bf316256e21a8c4 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Tue, 23 Sep 2025 15:33:42 +0300 Subject: [PATCH 05/25] fix failing linter Signed-off-by: nyagamunene --- re/service.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/re/service.go b/re/service.go index 5f53e8b42..ae70c7c75 100644 --- a/re/service.go +++ b/re/service.go @@ -58,11 +58,11 @@ func shouldCreateWorker(rule Rule) bool { now := time.Now().UTC() dueTime := rule.Schedule.Time - + if dueTime.IsZero() || dueTime.Before(now) { return true } - + return dueTime.Sub(now) <= time.Hour } From 1bfd44f2bb4d05b8dbf5ebe2a12c556a3289006a Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 24 Sep 2025 11:59:44 +0300 Subject: [PATCH 06/25] remove comments Signed-off-by: nyagamunene --- re/handlers.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/re/handlers.go b/re/handlers.go index a8e8690a6..27056e508 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -35,7 +35,7 @@ func (re *re) Handle(msg *messaging.Message) error { } if re.workerMgr == nil { - return errors.New("worker manager not initialized - scheduler must be started first") + return errors.New("worker manager not initialized") } // Skip filtering by message topic and fetch all non-scheduled rules instead. @@ -133,7 +133,6 @@ func (re *re) handleOutput(ctx context.Context, o Runnable, r Rule, msg *messagi func (re *re) StartScheduler(ctx context.Context) error { re.workerMgr = NewWorkerManager(re, ctx) - // Start goroutine to monitor worker manager errors go func() { for { select { From a257fdfee4a9c3ada8d01145b63930a837136e95 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 24 Sep 2025 14:28:28 +0300 Subject: [PATCH 07/25] fix update worker Signed-off-by: nyagamunene --- re/worker.go | 133 +++++++++++++++++++++------------------------------ 1 file changed, 55 insertions(+), 78 deletions(-) diff --git a/re/worker.go b/re/worker.go index 0bd079eb2..c75de883a 100644 --- a/re/worker.go +++ b/re/worker.go @@ -20,24 +20,22 @@ type WorkerMessage struct { // RuleWorker manages execution of a single rule in its own goroutine. type RuleWorker struct { - rule Rule - engine *re - msgChan chan WorkerMessage - updateChan chan Rule - ctx context.Context - cancel context.CancelFunc - g *errgroup.Group - running int32 + rule Rule + engine *re + msgChan chan WorkerMessage + ctx context.Context + cancel context.CancelFunc + g *errgroup.Group + running int32 } // NewRuleWorker creates a new rule worker for the given rule. func NewRuleWorker(rule Rule, engine *re) *RuleWorker { return &RuleWorker{ - rule: rule, - engine: engine, - msgChan: make(chan WorkerMessage, 100), - updateChan: make(chan Rule, 1), - running: 0, // 0 = not running, 1 = running + rule: rule, + engine: engine, + msgChan: make(chan WorkerMessage, 100), + running: 0, // 0 = not running, 1 = running } } @@ -85,19 +83,6 @@ func (w *RuleWorker) IsRunning() bool { return atomic.LoadInt32(&w.running) == 1 } -// UpdateRule updates the rule configuration for this worker. -func (w *RuleWorker) UpdateRule(rule Rule) { - select { - case w.updateChan <- rule: - default: - select { - case <-w.updateChan: - default: - } - w.updateChan <- rule - } -} - // GetRule returns the current rule configuration. func (w *RuleWorker) GetRule() Rule { return w.rule @@ -113,8 +98,6 @@ func (w *RuleWorker) run(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() - case rule := <-w.updateChan: - w.rule = rule case workerMsg := <-w.msgChan: w.processMessage(ctx, workerMsg) } @@ -143,7 +126,6 @@ const ( CmdAdd WorkerCommandType = iota CmdRemove CmdUpdate - CmdSend CmdStopAll CmdCount CmdList @@ -157,8 +139,6 @@ func (c WorkerCommandType) String() string { return "remove" case CmdUpdate: return "update" - case CmdSend: - return "send" case CmdStopAll: return "stop_all" case CmdCount: @@ -175,7 +155,6 @@ type WorkerManagerCommand struct { Type WorkerCommandType Rule Rule RuleID string - Message *messaging.Message Response chan interface{} } @@ -204,7 +183,6 @@ func NewWorkerManager(engine *re, ctx context.Context) *WorkerManager { running: 0, } - // Start the worker manager goroutine wm.g.Go(func() error { return wm.manageWorkers(ctx) }) @@ -214,7 +192,9 @@ func NewWorkerManager(engine *re, ctx context.Context) *WorkerManager { } func (wm *WorkerManager) manageWorkers(ctx context.Context) error { - defer atomic.StoreInt32(&wm.running, 0) + defer func() { + atomic.StoreInt32(&wm.running, 0) + }() for { select { @@ -259,11 +239,6 @@ func (wm *WorkerManager) handleCommand(cmd WorkerManagerCommand) { default: } } - case CmdSend: - result := wm.sendMessage(cmd.Message, cmd.Rule) - if cmd.Response != nil { - cmd.Response <- result - } case CmdStopAll: if err := wm.stopAll(); err != nil { select { @@ -298,20 +273,32 @@ func (wm *WorkerManager) addWorker(rule Rule) error { wm.mu.Lock() defer wm.mu.Unlock() - if existing, ok := wm.workers[rule.ID]; ok { - if err := existing.Stop(); err != nil { - return err - } - } + oldWorker, exists := wm.workers[rule.ID] if rule.Status != EnabledStatus { + if exists { + if err := oldWorker.Stop(); err != nil { + return err + } + } delete(wm.workers, rule.ID) return nil } - worker := NewRuleWorker(rule, wm.engine) - worker.Start(wm.ctx) - wm.workers[rule.ID] = worker + newWorker := NewRuleWorker(rule, wm.engine) + newWorker.Start(wm.ctx) + + wm.workers[rule.ID] = newWorker + + if exists { + if err := oldWorker.Stop(); err != nil { + select { + case wm.errorCh <- err: + default: + } + } + } + return nil } @@ -332,21 +319,32 @@ func (wm *WorkerManager) updateWorker(rule Rule) error { wm.mu.Lock() defer wm.mu.Unlock() - if worker, ok := wm.workers[rule.ID]; ok { - if err := worker.Stop(); err != nil { - return err - } - delete(wm.workers, rule.ID) - } + oldWorker, exists := wm.workers[rule.ID] if rule.Status != EnabledStatus { + if exists { + if err := oldWorker.Stop(); err != nil { + return err + } + } delete(wm.workers, rule.ID) return nil } - worker := NewRuleWorker(rule, wm.engine) - worker.Start(wm.ctx) - wm.workers[rule.ID] = worker + newWorker := NewRuleWorker(rule, wm.engine) + newWorker.Start(wm.ctx) + + wm.workers[rule.ID] = newWorker + + if exists { + if err := oldWorker.Stop(); err != nil { + select { + case wm.errorCh <- err: + default: + } + } + } + return nil } @@ -428,28 +426,7 @@ func (wm *WorkerManager) SendMessage(msg *messaging.Message, rule Rule) bool { return false } - responseCh := make(chan interface{}, 1) - cmd := WorkerManagerCommand{ - Type: CmdSend, - Rule: rule, - Message: msg, - Response: responseCh, - } - - select { - case wm.commandCh <- cmd: - select { - case result := <-responseCh: - if b, ok := result.(bool); ok { - return b - } - return false - case <-wm.ctx.Done(): - return false - } - default: - return false - } + return wm.sendMessage(msg, rule) } func (wm *WorkerManager) StopAll() error { From 3b6d9a4518536e1902b2e0c49cc37ac46d4e2c73 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Thu, 25 Sep 2025 12:40:51 +0300 Subject: [PATCH 08/25] fix service tests Signed-off-by: nyagamunene --- re/handlers.go | 46 ++++++++++++----- re/service_test.go | 124 ++++++++++++++++++++++++++++++++++++++------- re/worker.go | 27 ++++------ 3 files changed, 147 insertions(+), 50 deletions(-) diff --git a/re/handlers.go b/re/handlers.go index 27056e508..96a584c38 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -131,7 +131,7 @@ func (re *re) handleOutput(ctx context.Context, o Runnable, r Rule, msg *messagi } func (re *re) StartScheduler(ctx context.Context) error { - re.workerMgr = NewWorkerManager(re, ctx) + re.workerMgr = NewWorkerManager(ctx, re) go func() { for { @@ -193,21 +193,39 @@ func (re *re) StartScheduler(ctx context.Context) error { } for _, r := range page.Rules { - go func(rule Rule) { - if _, err := re.repo.UpdateRuleDue(ctx, rule.ID, rule.Schedule.NextDue()); err != nil { - re.runInfo <- pkglog.RunInfo{Level: slog.LevelError, Message: fmt.Sprintf("failed to update rule: %s", err), Details: []slog.Attr{slog.Time("time", time.Now().UTC())}} - return + msg := &messaging.Message{ + Domain: r.DomainID, + Channel: r.InputChannel, + Subtopic: r.InputTopic, + Protocol: protocol, + Created: due.Unix(), + } + + if !re.workerMgr.SendMessage(msg, r) { + re.runInfo <- pkglog.RunInfo{ + Level: slog.LevelWarn, + Message: fmt.Sprintf("failed to send scheduled message to worker for rule %s", r.ID), + Details: []slog.Attr{ + slog.String("rule_id", r.ID), + slog.String("rule_name", r.Name), + slog.String("channel", msg.Channel), + slog.Time("scheduled_time", due), + }, } - - msg := &messaging.Message{ - Domain: rule.DomainID, - Channel: rule.InputChannel, - Subtopic: rule.InputTopic, - Protocol: protocol, - Created: due.Unix(), + } + + go func(ruleID string, nextDue time.Time) { + if _, err := re.repo.UpdateRuleDue(ctx, ruleID, nextDue); err != nil { + re.runInfo <- pkglog.RunInfo{ + Level: slog.LevelError, + Message: fmt.Sprintf("failed to update rule due time: %s", err), + Details: []slog.Attr{ + slog.String("rule_id", ruleID), + slog.Time("time", time.Now().UTC()), + }, + } } - re.runInfo <- re.process(ctx, rule, msg) - }(r) + }(r.ID, r.Schedule.NextDue()) } // Reset due, it will reset in the page meta as well. due = time.Now().UTC() diff --git a/re/service_test.go b/re/service_test.go index b5fb8c3de..454dfd28a 100644 --- a/re/service_test.go +++ b/re/service_test.go @@ -54,11 +54,50 @@ func newService(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.R pubsub := pubsubmocks.NewPubSub(t) readersSvc := new(readmocks.ReadersServiceClient) e := new(emocks.Emailer) - return re.NewService(repo, runInfo, idProvider, pubsub, pubsub, pubsub, mockTicker, e, readersSvc), repo, pubsub, mockTicker + + svc := re.NewService(repo, runInfo, idProvider, pubsub, pubsub, pubsub, mockTicker, e, readersSvc) + + repocall1 := repo.On("ListRules", mock.Anything, mock.Anything).Return(re.Page{Rules: []re.Rule{}}, nil) + defer repocall1.Unset() + + tickCh := make(chan time.Time) + mockTicker.On("Tick").Return((<-chan time.Time)(tickCh)) + mockTicker.On("Stop").Return() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(func() { + cancel() + time.Sleep(10 * time.Millisecond) + }) + + go func() { + _ = svc.StartScheduler(ctx) + }() + + time.Sleep(50 * time.Millisecond) + + return svc, repo, pubsub, mockTicker +} + +func newServiceForSchedulerTest(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.Repository, *pubsubmocks.PubSub, *tmocks.Ticker) { + repo := new(mocks.Repository) + mockTicker := new(tmocks.Ticker) + idProvider := uuid.NewMock() + pubsub := pubsubmocks.NewPubSub(t) + readersSvc := new(readmocks.ReadersServiceClient) + e := new(emocks.Emailer) + + svc := re.NewService(repo, runInfo, idProvider, pubsub, pubsub, pubsub, mockTicker, e, readersSvc) + + tickCh := make(chan time.Time) + mockTicker.On("Tick").Return((<-chan time.Time)(tickCh)) + mockTicker.On("Stop").Return() + + return svc, repo, pubsub, mockTicker } func TestAddRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) ruleName := namegen.Generate() now := time.Now().Add(time.Hour) cases := []struct { @@ -133,7 +172,7 @@ func TestAddRule(t *testing.T) { } func TestViewRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) now := time.Now().Add(time.Hour) cases := []struct { @@ -191,7 +230,7 @@ func TestViewRule(t *testing.T) { } func TestUpdateRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) newName := namegen.Generate() now := time.Now().Add(time.Hour) @@ -276,7 +315,7 @@ func TestUpdateRule(t *testing.T) { } func TestUpdateRuleTags(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) cases := []struct { desc string @@ -332,7 +371,7 @@ func TestUpdateRuleTags(t *testing.T) { } func TestListRules(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) numRules := 50 now := time.Now().Add(time.Hour) var rules []re.Rule @@ -418,7 +457,10 @@ func TestListRules(t *testing.T) { DomainID: domainID, }, pageMeta: re.PageMeta{}, - err: svcerr.ErrViewEntity, + res: re.Page{ + Rules: []re.Rule{}, + }, + err: svcerr.ErrViewEntity, }, } @@ -437,7 +479,7 @@ func TestListRules(t *testing.T) { } func TestRemoveRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) cases := []struct { desc string @@ -477,7 +519,7 @@ func TestRemoveRule(t *testing.T) { } func TestEnableRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) now := time.Now() @@ -536,7 +578,7 @@ func TestEnableRule(t *testing.T) { } func TestDisableRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) now := time.Now() @@ -595,7 +637,7 @@ func TestDisableRule(t *testing.T) { } func TestHandle(t *testing.T) { - svc, repo, pubmocks, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, pubmocks, _ := newService(t, make(chan pkglog.RunInfo, 100)) now := time.Now() scheduled := false @@ -619,7 +661,7 @@ func TestHandle(t *testing.T) { listErr: nil, }, { - desc: "consume message with rules", + desc: "consume message with enabled rules", message: &messaging.Message{ Channel: inputChannel, Created: now.Unix(), @@ -646,8 +688,41 @@ func TestHandle(t *testing.T) { }, listErr: nil, }, + { + desc: "consume message with disabled rules", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.DisabledStatus, + Logic: re.Script{ + Type: re.ScriptType(0), + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, } + // go func() { + // for range runInfo { + // } + // }() + for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { var err error @@ -657,23 +732,36 @@ func TestHandle(t *testing.T) { err = tc.listErr } }) - repoCall1 := pubmocks.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(tc.publishErr) + + enabledRulesCount := 0 + for _, rule := range tc.page.Rules { + if rule.Status == re.EnabledStatus { + enabledRulesCount++ + } + } + + if enabledRulesCount > 0 { + repoCall1 := pubmocks.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(tc.publishErr).Times(enabledRulesCount) + defer repoCall1.Unset() + } err = svc.Handle(tc.message) - assert.Nil(t, err) - assert.True(t, errors.Contains(err, tc.listErr), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.listErr, err)) + assert.NoError(t, err, "Handle should not return errors with worker architecture") + + if tc.listErr != nil { + assert.True(t, errors.Contains(err, tc.listErr), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.listErr, err)) + } repoCall.Unset() - repoCall1.Unset() }) } } func TestStartScheduler(t *testing.T) { now := time.Now().Truncate(time.Minute) - ri := make(chan pkglog.RunInfo) - svc, repo, _, ticker := newService(t, ri) + ri := make(chan pkglog.RunInfo, 100) + svc, repo, _, ticker := newServiceForSchedulerTest(t, ri) ctxCases := []struct { desc string diff --git a/re/worker.go b/re/worker.go index c75de883a..6a1c0938d 100644 --- a/re/worker.go +++ b/re/worker.go @@ -24,8 +24,6 @@ type RuleWorker struct { engine *re msgChan chan WorkerMessage ctx context.Context - cancel context.CancelFunc - g *errgroup.Group running int32 } @@ -45,12 +43,11 @@ func (w *RuleWorker) Start(ctx context.Context) { return } - w.ctx, w.cancel = context.WithCancel(ctx) - w.g, w.ctx = errgroup.WithContext(w.ctx) - - w.g.Go(func() error { - return w.run(w.ctx) - }) + w.ctx = ctx + go func() { + defer atomic.StoreInt32(&w.running, 0) + w.run(w.ctx) + }() } // Stop stops the worker goroutine and waits for it to finish. @@ -59,9 +56,7 @@ func (w *RuleWorker) Stop() error { return nil } - w.cancel() - - return w.g.Wait() + return nil } // Send sends a message to the worker for processing. @@ -89,15 +84,11 @@ func (w *RuleWorker) GetRule() Rule { } // run is the main worker loop that processes messages. -func (w *RuleWorker) run(ctx context.Context) error { - defer func() { - atomic.StoreInt32(&w.running, 0) - }() - +func (w *RuleWorker) run(ctx context.Context) { for { select { case <-ctx.Done(): - return ctx.Err() + return case workerMsg := <-w.msgChan: w.processMessage(ctx, workerMsg) } @@ -171,7 +162,7 @@ type WorkerManager struct { } // NewWorkerManager creates a new worker manager. -func NewWorkerManager(engine *re, ctx context.Context) *WorkerManager { +func NewWorkerManager(ctx context.Context, engine *re) *WorkerManager { g, ctx := errgroup.WithContext(ctx) wm := &WorkerManager{ workers: make(map[string]*RuleWorker), From c7c2b4cb67358490df28c9782e1cb18f52907b39 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Thu, 25 Sep 2025 12:55:35 +0300 Subject: [PATCH 09/25] fix failing linter Signed-off-by: nyagamunene --- re/handlers.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/re/handlers.go b/re/handlers.go index 96a584c38..008dde846 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -200,7 +200,7 @@ func (re *re) StartScheduler(ctx context.Context) error { Protocol: protocol, Created: due.Unix(), } - + if !re.workerMgr.SendMessage(msg, r) { re.runInfo <- pkglog.RunInfo{ Level: slog.LevelWarn, @@ -213,7 +213,7 @@ func (re *re) StartScheduler(ctx context.Context) error { }, } } - + go func(ruleID string, nextDue time.Time) { if _, err := re.repo.UpdateRuleDue(ctx, ruleID, nextDue); err != nil { re.runInfo <- pkglog.RunInfo{ From 06ef9d92f66bd9715acd675ec0e108108a299de1 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Thu, 25 Sep 2025 18:30:42 +0300 Subject: [PATCH 10/25] fix startscheduler Signed-off-by: nyagamunene --- re/handlers.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/re/handlers.go b/re/handlers.go index 008dde846..95d14c0a1 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -133,12 +133,14 @@ func (re *re) handleOutput(ctx context.Context, o Runnable, r Rule, msg *messagi func (re *re) StartScheduler(ctx context.Context) error { re.workerMgr = NewWorkerManager(ctx, re) + workerMgr := re.workerMgr + go func() { for { select { case <-ctx.Done(): return - case err := <-re.workerMgr.ErrorChan(): + case err := <-workerMgr.ErrorChan(): if err != nil { re.runInfo <- pkglog.RunInfo{ Level: slog.LevelError, From 4ea24048d87c867c292d45eb3660ed0238f1ffc7 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Thu, 25 Sep 2025 21:51:48 +0300 Subject: [PATCH 11/25] fix ui env variable Signed-off-by: nyagamunene --- docker/.env | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/.env b/docker/.env index a8e6e7e1b..693897bda 100644 --- a/docker/.env +++ b/docker/.env @@ -396,4 +396,4 @@ MG_RELEASE_TAG=latest SMQ_ALLOW_UNVERIFIED_USER=true # Set to yes to accept the EULA for the UI services. To view the EULA visit: https://github.com/absmach/eula -MG_UI_DOCKER_ACCEPT_EULA=yes +MG_UI_DOCKER_ACCEPT_EULA=no From ecf0846467112c3a5a1f586c857cc66dbf114177 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Thu, 25 Sep 2025 22:00:28 +0300 Subject: [PATCH 12/25] remove commented code Signed-off-by: nyagamunene --- re/service_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/re/service_test.go b/re/service_test.go index 454dfd28a..03c4c098f 100644 --- a/re/service_test.go +++ b/re/service_test.go @@ -718,11 +718,6 @@ func TestHandle(t *testing.T) { }, } - // go func() { - // for range runInfo { - // } - // }() - for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { var err error From 3ff2113d322a5edbb14358c8daca225300da6cf8 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 29 Sep 2025 16:18:38 +0300 Subject: [PATCH 13/25] initial implementation Signed-off-by: nyagamunene --- re/api/transport.go | 30 ++++-- re/execution_status.go | 101 +++++++++++++++++ re/execution_status_test.go | 208 ++++++++++++++++++++++++++++++++++++ re/handlers.go | 35 +++++- re/postgres/init.go | 15 +++ re/postgres/repository.go | 52 +++++++-- re/postgres/rule.go | 110 +++++++++++-------- re/rule.go | 48 ++++++--- re/service.go | 33 ++++++ 9 files changed, 552 insertions(+), 80 deletions(-) create mode 100644 re/execution_status.go create mode 100644 re/execution_status_test.go diff --git a/re/api/transport.go b/re/api/transport.go index 54f0a41f9..ae26e0ffb 100644 --- a/re/api/transport.go +++ b/re/api/transport.go @@ -23,8 +23,9 @@ import ( ) const ( - ruleIdKey = "ruleID" - inputChannelKey = "input_channel" + ruleIdKey = "ruleID" + inputChannelKey = "input_channel" + lastRunStatusKey = "last_run_status" ) // MakeHandler creates an HTTP handler for the service endpoints. @@ -198,6 +199,10 @@ func decodeListRulesRequest(_ context.Context, r *http.Request) (any, error) { if err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } + lrs, err := apiutil.ReadStringQuery(r, lastRunStatusKey, re.NeverRun) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } dir, err := apiutil.ReadStringQuery(r, api.DirKey, "desc") if err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) @@ -210,6 +215,10 @@ func decodeListRulesRequest(_ context.Context, r *http.Request) (any, error) { if err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } + lrst, err := re.ToExecutionStatus(lrs) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } tag, err := apiutil.ReadStringQuery(r, api.TagKey, "") if err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) @@ -217,14 +226,15 @@ func decodeListRulesRequest(_ context.Context, r *http.Request) (any, error) { return listRulesReq{ PageMeta: re.PageMeta{ - Offset: offset, - Limit: limit, - Name: name, - InputChannel: ic, - Status: st, - Dir: dir, - Order: order, - Tag: tag, + Offset: offset, + Limit: limit, + Name: name, + InputChannel: ic, + Status: st, + LastRunStatus: lrst, + Dir: dir, + Order: order, + Tag: tag, }, }, nil } diff --git a/re/execution_status.go b/re/execution_status.go new file mode 100644 index 000000000..aefc32be8 --- /dev/null +++ b/re/execution_status.go @@ -0,0 +1,101 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re + +import ( + "encoding/json" + "strings" + + svcerr "github.com/absmach/supermq/pkg/errors/service" +) + +// ExecutionStatus represents the last run status of a rule execution. +type ExecutionStatus uint8 + +// Possible execution status values. +const ( + // NeverRunStatus represents a rule that has never been executed. + NeverRunStatus ExecutionStatus = iota + // SuccessStatus represents a successful rule execution. + SuccessStatus + // FailureStatus represents a failed rule execution. + FailureStatus + // PartialSuccessStatus represents a rule execution with partial success. + PartialSuccessStatus + // QueuedStatus represents a rule that is queued for execution. + QueuedStatus + // InProgressStatus represents a rule that is currently being executed. + InProgressStatus + // AbortedStatus represents a rule execution that was aborted. + AbortedStatus + // UnknownExecutionStatus represents an unknown execution status. + UnknownExecutionStatus +) + +// String representation of the possible execution status values. +const ( + NeverRun = "never_run" + Success = "success" + Failure = "failure" + PartialSuccess = "partial_success" + Queued = "queued" + InProgress = "in_progress" + Aborted = "aborted" + UnknownExecution = "unknown" +) + +func (es ExecutionStatus) String() string { + switch es { + case NeverRunStatus: + return NeverRun + case SuccessStatus: + return Success + case FailureStatus: + return Failure + case PartialSuccessStatus: + return PartialSuccess + case QueuedStatus: + return Queued + case InProgressStatus: + return InProgress + case AbortedStatus: + return Aborted + default: + return UnknownExecution + } +} + +// ToExecutionStatus converts string value to a valid execution status. +func ToExecutionStatus(status string) (ExecutionStatus, error) { + switch status { + case NeverRun: + return NeverRunStatus, nil + case Success: + return SuccessStatus, nil + case Failure: + return FailureStatus, nil + case PartialSuccess: + return PartialSuccessStatus, nil + case Queued: + return QueuedStatus, nil + case InProgress: + return InProgressStatus, nil + case Aborted: + return AbortedStatus, nil + case "", UnknownExecution: + return UnknownExecutionStatus, nil + } + return UnknownExecutionStatus, svcerr.ErrInvalidStatus +} + +func (es ExecutionStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(es.String()) +} + +func (es *ExecutionStatus) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), "\"") + val, err := ToExecutionStatus(str) + *es = val + return err +} diff --git a/re/execution_status_test.go b/re/execution_status_test.go new file mode 100644 index 000000000..69e27aecc --- /dev/null +++ b/re/execution_status_test.go @@ -0,0 +1,208 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExecutionStatusString(t *testing.T) { + cases := []struct { + desc string + status ExecutionStatus + want string + }{ + { + desc: "Success status", + status: SuccessStatus, + want: Success, + }, + { + desc: "Failure status", + status: FailureStatus, + want: Failure, + }, + { + desc: "Aborted status", + status: AbortedStatus, + want: Aborted, + }, + { + desc: "Queued status", + status: QueuedStatus, + want: Queued, + }, + { + desc: "In Progress status", + status: InProgressStatus, + want: InProgress, + }, + { + desc: "Partial Success status", + status: PartialSuccessStatus, + want: PartialSuccess, + }, + { + desc: "Never Run status", + status: NeverRunStatus, + want: NeverRun, + }, + { + desc: "Unknown status", + status: ExecutionStatus(99), + want: UnknownExec, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + got := tc.status.String() + assert.Equal(t, tc.want, got) + }) + } +} + +func TestToExecutionStatus(t *testing.T) { + cases := []struct { + desc string + status string + want ExecutionStatus + wantErr bool + }{ + { + desc: "Success status", + status: Success, + want: SuccessStatus, + }, + { + desc: "Failure status", + status: Failure, + want: FailureStatus, + }, + { + desc: "Aborted status", + status: Aborted, + want: AbortedStatus, + }, + { + desc: "Queued status", + status: Queued, + want: QueuedStatus, + }, + { + desc: "In Progress status", + status: InProgress, + want: InProgressStatus, + }, + { + desc: "Partial Success status", + status: PartialSuccess, + want: PartialSuccessStatus, + }, + { + desc: "Never Run status", + status: NeverRun, + want: NeverRunStatus, + }, + { + desc: "Empty string defaults to Never Run", + status: "", + want: NeverRunStatus, + }, + { + desc: "Invalid status", + status: "invalid", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + got, err := ToExecutionStatus(tc.status) + if tc.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestExecutionStatusMarshalJSON(t *testing.T) { + cases := []struct { + desc string + status ExecutionStatus + want string + }{ + { + desc: "Success status", + status: SuccessStatus, + want: `"success"`, + }, + { + desc: "Failure status", + status: FailureStatus, + want: `"failure"`, + }, + { + desc: "Never Run status", + status: NeverRunStatus, + want: `"never_run"`, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + got, err := tc.status.MarshalJSON() + assert.NoError(t, err) + assert.Equal(t, tc.want, string(got)) + }) + } +} + +func TestExecutionStatusUnmarshalJSON(t *testing.T) { + cases := []struct { + desc string + data string + want ExecutionStatus + wantErr bool + }{ + { + desc: "Success status", + data: `"success"`, + want: SuccessStatus, + }, + { + desc: "Failure status", + data: `"failure"`, + want: FailureStatus, + }, + { + desc: "Never Run status", + data: `"never_run"`, + want: NeverRunStatus, + }, + { + desc: "Invalid status", + data: `"invalid"`, + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + var status ExecutionStatus + err := status.UnmarshalJSON([]byte(tc.data)) + if tc.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tc.want, status) + }) + } +} \ No newline at end of file diff --git a/re/handlers.go b/re/handlers.go index 95d14c0a1..d018f09fb 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -100,12 +100,43 @@ func (re *re) process(ctx context.Context, r Rule, msg *messaging.Message) pkglo slog.String("rule_name", r.Name), slog.Time("exec_time", time.Now().UTC()), } + + // Set rule status to in progress + re.updateRuleExecutionStatus(ctx, r.ID, InProgressStatus, "") + + var result pkglog.RunInfo switch r.Logic.Type { case GoType: - return re.processGo(ctx, details, r, msg) + result = re.processGo(ctx, details, r, msg) + default: + result = re.processLua(ctx, details, r, msg) + } + + // Update execution status based on result + var execStatus ExecutionStatus + var errorMsg string + switch result.Level { + case slog.LevelInfo: + execStatus = SuccessStatus + case slog.LevelWarn: + // Check if it's a partial success case + if strings.Contains(result.Message, "logic returned false") || strings.Contains(result.Message, "no outputs") { + execStatus = SuccessStatus + } else { + execStatus = PartialSuccessStatus + errorMsg = result.Message + } + case slog.LevelError: + execStatus = FailureStatus + errorMsg = result.Message default: - return re.processLua(ctx, details, r, msg) + execStatus = FailureStatus + errorMsg = result.Message } + + re.updateRuleExecutionStatus(ctx, r.ID, execStatus, errorMsg) + + return result } func (re *re) handleOutput(ctx context.Context, o Runnable, r Rule, msg *messaging.Message, val any) error { diff --git a/re/postgres/init.go b/re/postgres/init.go index 99eb347ff..549b02c78 100644 --- a/re/postgres/init.go +++ b/re/postgres/init.go @@ -50,6 +50,21 @@ func Migration() *migrate.MemoryMigrationSource { `ALTER TABLE rules DROP COLUMN tags;`, }, }, + { + Id: "rules_03", + Up: []string{ + `ALTER TABLE rules ADD COLUMN last_run_status SMALLINT NOT NULL DEFAULT 6 CHECK (last_run_status >= 0);`, // 6 = NeverRunStatus + `ALTER TABLE rules ADD COLUMN last_run_time TIMESTAMP;`, + `ALTER TABLE rules ADD COLUMN last_run_error_message TEXT;`, + `ALTER TABLE rules ADD COLUMN execution_count BIGINT NOT NULL DEFAULT 0;`, + }, + Down: []string{ + `ALTER TABLE rules DROP COLUMN last_run_status;`, + `ALTER TABLE rules DROP COLUMN last_run_time;`, + `ALTER TABLE rules DROP COLUMN last_run_error_message;`, + `ALTER TABLE rules DROP COLUMN execution_count;`, + }, + }, }, } } diff --git a/re/postgres/repository.go b/re/postgres/repository.go index 3596a9f7d..242a656d3 100644 --- a/re/postgres/repository.go +++ b/re/postgres/repository.go @@ -28,11 +28,14 @@ func NewRepository(db postgres.Database) re.Repository { func (repo *PostgresRepository) AddRule(ctx context.Context, r re.Rule) (re.Rule, error) { q := ` INSERT INTO rules (id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status) + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count) VALUES (:id, :name, :domain_id, :tags, :metadata, :input_channel, :input_topic, :logic_type, :logic_value, - :outputs, :start_datetime, :time, :recurring, :recurring_period, :created_at, :created_by, :updated_at, :updated_by, :status) + :outputs, :start_datetime, :time, :recurring, :recurring_period, :created_at, :created_by, :updated_at, :updated_by, :status, + :last_run_status, :last_run_time, :last_run_error_message, :execution_count) RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count; ` dbr, err := ruleToDb(r) if err != nil { @@ -62,7 +65,8 @@ func (repo *PostgresRepository) AddRule(ctx context.Context, r re.Rule) (re.Rule func (repo *PostgresRepository) ViewRule(ctx context.Context, id string) (re.Rule, error) { q := ` SELECT id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, outputs, - start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status + start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count FROM rules WHERE id = $1; ` @@ -87,11 +91,31 @@ func (repo *PostgresRepository) UpdateRuleStatus(ctx context.Context, r re.Rule) SET status = :status, updated_at = :updated_at, updated_by = :updated_by WHERE id = :id RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status;` + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count;` return repo.update(ctx, r, q) } +func (repo *PostgresRepository) UpdateRuleExecutionStatus(ctx context.Context, r re.Rule) error { + q := `UPDATE rules + SET last_run_status = :last_run_status, last_run_time = :last_run_time, + last_run_error_message = :last_run_error_message, execution_count = :execution_count + WHERE id = :id;` + + dbr, err := ruleToDb(r) + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + _, err = repo.DB.NamedExecContext(ctx, q, dbr) + if err != nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + + return nil +} + func (repo *PostgresRepository) UpdateRule(ctx context.Context, r re.Rule) (re.Rule, error) { var query []string var upq string @@ -119,7 +143,8 @@ func (repo *PostgresRepository) UpdateRule(ctx context.Context, r re.Rule) (re.R UPDATE rules SET %s updated_at = :updated_at, updated_by = :updated_by WHERE id = :id RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count; `, upq) return repo.update(ctx, r, q) @@ -129,7 +154,8 @@ func (repo *PostgresRepository) UpdateRuleTags(ctx context.Context, r re.Rule) ( q := `UPDATE rules SET tags = :tags, updated_at = :updated_at, updated_by = :updated_by WHERE id = :id AND status = :status RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status;` + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count;` r.Status = re.EnabledStatus return repo.update(ctx, r, q) @@ -141,7 +167,8 @@ func (repo *PostgresRepository) UpdateRuleSchedule(ctx context.Context, r re.Rul SET start_datetime = :start_datetime, time = :time, recurring = :recurring, recurring_period = :recurring_period, updated_at = :updated_at, updated_by = :updated_by WHERE id = :id RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count; ` return repo.update(ctx, r, q) } @@ -223,7 +250,8 @@ func (repo *PostgresRepository) ListRules(ctx context.Context, pm re.PageMeta) ( q := fmt.Sprintf(` SELECT id, name, domain_id, tags, input_channel, input_topic, logic_type, logic_value, outputs, - start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status + start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count FROM rules r %s %s %s; `, pq, orderClause, pgData) rows, err := repo.DB.NamedQueryContext(ctx, q, pm) @@ -266,7 +294,8 @@ func (repo *PostgresRepository) UpdateRuleDue(ctx context.Context, id string, du UPDATE rules SET time = :time, updated_at = :updated_at WHERE id = :id RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, - outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, + last_run_status, last_run_time, last_run_error_message, execution_count; ` dbr := dbRule{ ID: id, @@ -304,6 +333,9 @@ func pageRulesQuery(pm re.PageMeta) string { if pm.Status != re.AllStatus { query = append(query, "r.status = :status") } + if pm.LastRunStatus != re.NeverRunStatus { + query = append(query, "r.last_run_status = :last_run_status") + } if pm.Domain != "" { query = append(query, "r.domain_id = :domain_id") } diff --git a/re/postgres/rule.go b/re/postgres/rule.go index 3027e3d66..4551269d8 100644 --- a/re/postgres/rule.go +++ b/re/postgres/rule.go @@ -16,25 +16,29 @@ import ( // dbRule represents the database structure for a Rule. type dbRule struct { - ID string `db:"id"` - Name string `db:"name"` - DomainID string `db:"domain_id"` - Tags pgtype.TextArray `db:"tags,omitempty"` - Metadata []byte `db:"metadata,omitempty"` - InputChannel string `db:"input_channel"` - InputTopic sql.NullString `db:"input_topic"` - LogicType re.ScriptType `db:"logic_type"` - LogicValue string `db:"logic_value"` - Outputs []byte `db:"outputs"` - StartDateTime sql.NullTime `db:"start_datetime"` - Time sql.NullTime `db:"time"` - Recurring schedule.Recurring `db:"recurring"` - RecurringPeriod uint `db:"recurring_period"` - Status re.Status `db:"status"` - CreatedAt time.Time `db:"created_at"` - CreatedBy string `db:"created_by"` - UpdatedAt time.Time `db:"updated_at"` - UpdatedBy string `db:"updated_by"` + ID string `db:"id"` + Name string `db:"name"` + DomainID string `db:"domain_id"` + Tags pgtype.TextArray `db:"tags,omitempty"` + Metadata []byte `db:"metadata,omitempty"` + InputChannel string `db:"input_channel"` + InputTopic sql.NullString `db:"input_topic"` + LogicType re.ScriptType `db:"logic_type"` + LogicValue string `db:"logic_value"` + Outputs []byte `db:"outputs"` + StartDateTime sql.NullTime `db:"start_datetime"` + Time sql.NullTime `db:"time"` + Recurring schedule.Recurring `db:"recurring"` + RecurringPeriod uint `db:"recurring_period"` + Status re.Status `db:"status"` + LastRunStatus re.ExecutionStatus `db:"last_run_status"` + LastRunTime sql.NullTime `db:"last_run_time"` + LastRunErrorMessage sql.NullString `db:"last_run_error_message"` + ExecutionCount uint64 `db:"execution_count"` + CreatedAt time.Time `db:"created_at"` + CreatedBy string `db:"created_by"` + UpdatedAt time.Time `db:"updated_at"` + UpdatedBy string `db:"updated_by"` } func ruleToDb(r re.Rule) (dbRule, error) { @@ -55,6 +59,13 @@ func ruleToDb(r re.Rule) (dbRule, error) { if !r.Schedule.Time.IsZero() { t.Valid = true } + + lastRunTime := sql.NullTime{} + if r.LastRunTime != nil && !r.LastRunTime.IsZero() { + lastRunTime.Time = *r.LastRunTime + lastRunTime.Valid = true + } + var tags pgtype.TextArray if err := tags.Set(r.Tags); err != nil { return dbRule{}, err @@ -66,25 +77,29 @@ func ruleToDb(r re.Rule) (dbRule, error) { } return dbRule{ - ID: r.ID, - Name: r.Name, - DomainID: r.DomainID, - Tags: tags, - Metadata: metadata, - InputChannel: r.InputChannel, - InputTopic: toNullString(r.InputTopic), - LogicType: r.Logic.Type, - LogicValue: r.Logic.Value, - Outputs: outputs, - StartDateTime: start, - Time: t, - Recurring: r.Schedule.Recurring, - RecurringPeriod: r.Schedule.RecurringPeriod, - Status: r.Status, - CreatedAt: r.CreatedAt, - CreatedBy: r.CreatedBy, - UpdatedAt: r.UpdatedAt, - UpdatedBy: r.UpdatedBy, + ID: r.ID, + Name: r.Name, + DomainID: r.DomainID, + Tags: tags, + Metadata: metadata, + InputChannel: r.InputChannel, + InputTopic: toNullString(r.InputTopic), + LogicType: r.Logic.Type, + LogicValue: r.Logic.Value, + Outputs: outputs, + StartDateTime: start, + Time: t, + Recurring: r.Schedule.Recurring, + RecurringPeriod: r.Schedule.RecurringPeriod, + Status: r.Status, + LastRunStatus: r.LastRunStatus, + LastRunTime: lastRunTime, + LastRunErrorMessage: toNullString(r.LastRunErrorMessage), + ExecutionCount: r.ExecutionCount, + CreatedAt: r.CreatedAt, + CreatedBy: r.CreatedBy, + UpdatedAt: r.UpdatedAt, + UpdatedBy: r.UpdatedBy, }, nil } @@ -108,6 +123,11 @@ func dbToRule(dto dbRule) (re.Rule, error) { } } + var lastRunTime *time.Time + if dto.LastRunTime.Valid { + lastRunTime = &dto.LastRunTime.Time + } + return re.Rule{ ID: dto.ID, Name: dto.Name, @@ -127,11 +147,15 @@ func dbToRule(dto dbRule) (re.Rule, error) { Recurring: dto.Recurring, RecurringPeriod: dto.RecurringPeriod, }, - Status: dto.Status, - CreatedAt: dto.CreatedAt, - CreatedBy: dto.CreatedBy, - UpdatedAt: dto.UpdatedAt, - UpdatedBy: dto.UpdatedBy, + Status: dto.Status, + LastRunStatus: dto.LastRunStatus, + LastRunTime: lastRunTime, + LastRunErrorMessage: fromNullString(dto.LastRunErrorMessage), + ExecutionCount: dto.ExecutionCount, + CreatedAt: dto.CreatedAt, + CreatedBy: dto.CreatedBy, + UpdatedAt: dto.UpdatedAt, + UpdatedBy: dto.UpdatedBy, }, nil } diff --git a/re/rule.go b/re/rule.go index 695dd28fc..a9e658d36 100644 --- a/re/rule.go +++ b/re/rule.go @@ -65,21 +65,28 @@ type Rule struct { Outputs Outputs `json:"outputs,omitempty"` Schedule schedule.Schedule `json:"schedule,omitempty"` Status Status `json:"status"` - CreatedAt time.Time `json:"created_at"` - CreatedBy string `json:"created_by"` - UpdatedAt time.Time `json:"updated_at"` - UpdatedBy string `json:"updated_by"` + // Last execution tracking + LastRunStatus ExecutionStatus `json:"last_run_status"` + LastRunTime *time.Time `json:"last_run_time,omitempty"` + LastRunErrorMessage string `json:"last_run_error_message,omitempty"` + ExecutionCount uint64 `json:"execution_count"` + CreatedAt time.Time `json:"created_at"` + CreatedBy string `json:"created_by"` + UpdatedAt time.Time `json:"updated_at"` + UpdatedBy string `json:"updated_by"` } // EventEncode converts a Rule struct to map[string]any at event producer. func (r Rule) EventEncode() (map[string]any, error) { m := map[string]any{ - "id": r.ID, - "name": r.Name, - "created_at": r.CreatedAt.Format(time.RFC3339Nano), - "created_by": r.CreatedBy, - "schedule": r.Schedule.EventEncode(), - "status": r.Status.String(), + "id": r.ID, + "name": r.Name, + "created_at": r.CreatedAt.Format(time.RFC3339Nano), + "created_by": r.CreatedBy, + "schedule": r.Schedule.EventEncode(), + "status": r.Status.String(), + "last_run_status": r.LastRunStatus.String(), + "execution_count": r.ExecutionCount, } if r.Name != "" { @@ -98,6 +105,14 @@ func (r Rule) EventEncode() (map[string]any, error) { m["updated_by"] = r.UpdatedBy } + if r.LastRunTime != nil { + m["last_run_time"] = r.LastRunTime.Format(time.RFC3339Nano) + } + + if r.LastRunErrorMessage != "" { + m["last_run_error_message"] = r.LastRunErrorMessage + } + if len(r.Metadata) > 0 { m["metadata"] = r.Metadata } @@ -175,6 +190,7 @@ type PageMeta struct { Scheduled *bool `json:"scheduled,omitempty"` OutputChannel string `json:"output_channel,omitempty" db:"output_channel"` Status Status `json:"status,omitempty" db:"status"` + LastRunStatus ExecutionStatus `json:"last_run_status,omitempty" db:"last_run_status"` Domain string `json:"domain_id,omitempty" db:"domain_id"` Tag string `json:"tag,omitempty"` ScheduledBefore *time.Time `json:"scheduled_before,omitempty" db:"scheduled_before"` // Filter rules scheduled before this time @@ -185,11 +201,12 @@ type PageMeta struct { // EventEncode converts a PageMeta struct to map[string]any. func (pm PageMeta) EventEncode() map[string]any { m := map[string]any{ - "total": pm.Total, - "offset": pm.Offset, - "limit": pm.Limit, - "status": pm.Status.String(), - "domain_id": pm.Domain, + "total": pm.Total, + "offset": pm.Offset, + "limit": pm.Limit, + "status": pm.Status.String(), + "last_run_status": pm.LastRunStatus.String(), + "domain_id": pm.Domain, } if pm.Dir != "" { @@ -256,6 +273,7 @@ type Repository interface { UpdateRuleSchedule(ctx context.Context, r Rule) (Rule, error) RemoveRule(ctx context.Context, id string) error UpdateRuleStatus(ctx context.Context, r Rule) (Rule, error) + UpdateRuleExecutionStatus(ctx context.Context, r Rule) error ListRules(ctx context.Context, pm PageMeta) (Page, error) UpdateRuleDue(ctx context.Context, id string, due time.Time) (Rule, error) } diff --git a/re/service.go b/re/service.go index ae70c7c75..934516c8f 100644 --- a/re/service.go +++ b/re/service.go @@ -5,6 +5,8 @@ package re import ( "context" + "fmt" + "log/slog" "time" grpcReadersV1 "github.com/absmach/magistrala/api/grpc/readers/v1" @@ -77,6 +79,8 @@ func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, r.CreatedBy = session.UserID r.DomainID = session.DomainID r.Status = EnabledStatus + r.LastRunStatus = NeverRunStatus + r.ExecutionCount = 0 if !r.Schedule.StartDateTime.IsZero() { r.Schedule.StartDateTime = now @@ -215,3 +219,32 @@ func (re *re) DisableRule(ctx context.Context, session authn.Session, id string) func (re *re) Cancel() error { return re.workerMgr.StopAll() } + +// updateRuleExecutionStatus updates the execution status of a rule +func (re *re) updateRuleExecutionStatus(ctx context.Context, ruleID string, status ExecutionStatus, errorMessage string) { + now := time.Now().UTC() + rule := Rule{ + ID: ruleID, + LastRunStatus: status, + LastRunTime: &now, + LastRunErrorMessage: errorMessage, + } + + if status == SuccessStatus || status == PartialSuccessStatus { + currentRule, err := re.repo.ViewRule(ctx, ruleID) + if err == nil { + rule.ExecutionCount = currentRule.ExecutionCount + 1 + } + } + + if err := re.repo.UpdateRuleExecutionStatus(ctx, rule); err != nil { + re.runInfo <- pkglog.RunInfo{ + Level: slog.LevelWarn, + Message: fmt.Sprintf("failed to update rule execution status: %s", err), + Details: []slog.Attr{ + slog.String("rule_id", ruleID), + slog.String("status", status.String()), + }, + } + } +} From 7d7288ef027226d8ce23d65a70fa07afec133131 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 29 Sep 2025 16:47:48 +0300 Subject: [PATCH 14/25] remove comments Signed-off-by: nyagamunene --- re/handlers.go | 3 --- re/service.go | 1 - 2 files changed, 4 deletions(-) diff --git a/re/handlers.go b/re/handlers.go index d018f09fb..607a57d78 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -101,7 +101,6 @@ func (re *re) process(ctx context.Context, r Rule, msg *messaging.Message) pkglo slog.Time("exec_time", time.Now().UTC()), } - // Set rule status to in progress re.updateRuleExecutionStatus(ctx, r.ID, InProgressStatus, "") var result pkglog.RunInfo @@ -112,14 +111,12 @@ func (re *re) process(ctx context.Context, r Rule, msg *messaging.Message) pkglo result = re.processLua(ctx, details, r, msg) } - // Update execution status based on result var execStatus ExecutionStatus var errorMsg string switch result.Level { case slog.LevelInfo: execStatus = SuccessStatus case slog.LevelWarn: - // Check if it's a partial success case if strings.Contains(result.Message, "logic returned false") || strings.Contains(result.Message, "no outputs") { execStatus = SuccessStatus } else { diff --git a/re/service.go b/re/service.go index 934516c8f..aed90a108 100644 --- a/re/service.go +++ b/re/service.go @@ -220,7 +220,6 @@ func (re *re) Cancel() error { return re.workerMgr.StopAll() } -// updateRuleExecutionStatus updates the execution status of a rule func (re *re) updateRuleExecutionStatus(ctx context.Context, ruleID string, status ExecutionStatus, errorMessage string) { now := time.Now().UTC() rule := Rule{ From 759cbcbbb090ac73c14f5f8e42873082d98719d2 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Tue, 30 Sep 2025 19:23:19 +0300 Subject: [PATCH 15/25] remove processing field Signed-off-by: nyagamunene --- docker/.env | 2 +- re/api/endpoints.go | 19 ++++ re/api/requests.go | 12 +++ re/api/responses.go | 15 ++++ re/api/transport.go | 13 +++ re/events/events.go | 14 +++ re/events/streams.go | 16 ++++ re/golang.go | 6 ++ re/handlers.go | 30 ++++++- re/lua.go | 6 ++ re/middleware/authorization.go | 24 ++++++ re/middleware/logging.go | 19 +++- re/rule.go | 85 +++++++++--------- re/service.go | 19 +++- re/worker.go | 153 ++++++++++++++++++++++++++++++--- 15 files changed, 374 insertions(+), 59 deletions(-) diff --git a/docker/.env b/docker/.env index 693897bda..a8e6e7e1b 100644 --- a/docker/.env +++ b/docker/.env @@ -396,4 +396,4 @@ MG_RELEASE_TAG=latest SMQ_ALLOW_UNVERIFIED_USER=true # Set to yes to accept the EULA for the UI services. To view the EULA visit: https://github.com/absmach/eula -MG_UI_DOCKER_ACCEPT_EULA=no +MG_UI_DOCKER_ACCEPT_EULA=yes diff --git a/re/api/endpoints.go b/re/api/endpoints.go index fa8c3b356..f0773bf67 100644 --- a/re/api/endpoints.go +++ b/re/api/endpoints.go @@ -203,3 +203,22 @@ func disableRuleEndpoint(s re.Service) endpoint.Endpoint { return updateRuleStatusRes{Rule: rule}, err } } + +func abortRuleExecutionEndpoint(s re.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(abortRuleExecutionReq) + if err := req.validate(); err != nil { + return abortRuleExecutionRes{}, err + } + err := s.AbortRuleExecution(ctx, session, req.id) + if err != nil { + return abortRuleExecutionRes{}, err + } + return abortRuleExecutionRes{}, nil + } +} diff --git a/re/api/requests.go b/re/api/requests.go index cdd7a9c6a..b52b9fec3 100644 --- a/re/api/requests.go +++ b/re/api/requests.go @@ -134,3 +134,15 @@ func (req deleteRuleReq) validate() error { return nil } + +type abortRuleExecutionReq struct { + id string +} + +func (req abortRuleExecutionReq) validate() error { + if req.id == "" { + return apiutil.ErrMissingID + } + + return nil +} diff --git a/re/api/responses.go b/re/api/responses.go index d13753dc1..75951ad5b 100644 --- a/re/api/responses.go +++ b/re/api/responses.go @@ -18,6 +18,7 @@ var ( _ supermq.Response = (*rulesPageRes)(nil) _ supermq.Response = (*updateRuleRes)(nil) _ supermq.Response = (*deleteRuleRes)(nil) + _ supermq.Response = (*abortRuleExecutionRes)(nil) ) type pageRes struct { @@ -136,3 +137,17 @@ func (res deleteRuleRes) Headers() map[string]string { func (res deleteRuleRes) Empty() bool { return true } + +type abortRuleExecutionRes struct{} + +func (res abortRuleExecutionRes) Code() int { + return http.StatusAccepted +} + +func (res abortRuleExecutionRes) Headers() map[string]string { + return map[string]string{} +} + +func (res abortRuleExecutionRes) Empty() bool { + return true +} diff --git a/re/api/transport.go b/re/api/transport.go index ae26e0ffb..5a797276b 100644 --- a/re/api/transport.go +++ b/re/api/transport.go @@ -100,6 +100,13 @@ func MakeHandler(svc re.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, l api.EncodeResponse, opts..., ), "disable_rule").ServeHTTP) + + r.Post("/abort", otelhttp.NewHandler(kithttp.NewServer( + abortRuleExecutionEndpoint(svc), + decodeAbortRuleExecutionRequest, + api.EncodeResponse, + opts..., + ), "abort_rule_execution").ServeHTTP) }) }) }) @@ -244,3 +251,9 @@ func decodeDeleteRuleRequest(_ context.Context, r *http.Request) (any, error) { return deleteRuleReq{id: id}, nil } + +func decodeAbortRuleExecutionRequest(_ context.Context, r *http.Request) (any, error) { + id := chi.URLParam(r, ruleIdKey) + + return abortRuleExecutionReq{id: id}, nil +} diff --git a/re/events/events.go b/re/events/events.go index ffc60228d..9e3fd8dff 100644 --- a/re/events/events.go +++ b/re/events/events.go @@ -21,6 +21,7 @@ const ( ruleUpdateSchedule = rulePrefix + "update_schedule" ruleEnable = rulePrefix + "enable" ruleDisable = rulePrefix + "disable" + ruleAbort = rulePrefix + "abort" ruleRemove = rulePrefix + "remove" ) @@ -33,6 +34,7 @@ var ( _ events.Event = (*updateRuleScheduleEvent)(nil) _ events.Event = (*enableRuleEvent)(nil) _ events.Event = (*disableRuleEvent)(nil) + _ events.Event = (*abortRuleExecutionEvent)(nil) _ events.Event = (*removeRuleEvent)(nil) ) @@ -187,3 +189,15 @@ func (rre removeRuleEvent) Encode() (map[string]any, error) { val["operation"] = ruleRemove return val, nil } + +type abortRuleExecutionEvent struct { + id string + baseRuleEvent +} + +func (aree abortRuleExecutionEvent) Encode() (map[string]any, error) { + val := aree.baseRuleEvent.Encode() + val["id"] = aree.id + val["operation"] = ruleAbort + return val, nil +} diff --git a/re/events/streams.go b/re/events/streams.go index dfeece402..b253c05d0 100644 --- a/re/events/streams.go +++ b/re/events/streams.go @@ -24,6 +24,7 @@ const ( UpdateScheduleStream = supermqPrefix + ruleUpdateSchedule EnableStream = supermqPrefix + ruleEnable DisableStream = supermqPrefix + ruleDisable + AbortStream = supermqPrefix + ruleAbort RemoveStream = supermqPrefix + ruleRemove ) @@ -183,6 +184,21 @@ func (es *eventStore) DisableRule(ctx context.Context, session authn.Session, id return rule, nil } +func (es *eventStore) AbortRuleExecution(ctx context.Context, session authn.Session, id string) error { + err := es.svc.AbortRuleExecution(ctx, session, id) + if err != nil { + return err + } + event := abortRuleExecutionEvent{ + id: id, + baseRuleEvent: newBaseRuleEvent(session, middleware.GetReqID(ctx)), + } + if err := es.Publish(ctx, AbortStream, event); err != nil { + return err + } + return nil +} + func (es *eventStore) StartScheduler(ctx context.Context) error { return es.svc.StartScheduler(ctx) } diff --git a/re/golang.go b/re/golang.go index a7a831eeb..081500c28 100644 --- a/re/golang.go +++ b/re/golang.go @@ -31,6 +31,12 @@ type message struct { } func (re *re) processGo(ctx context.Context, details []slog.Attr, r Rule, msg *messaging.Message) pkglog.RunInfo { + select { + case <-ctx.Done(): + return pkglog.RunInfo{Level: slog.LevelError, Details: details, Message: "rule execution was cancelled"} + default: + } + i := golang.New(golang.Options{}) if err := i.Use(stdlib.Symbols); err != nil { return pkglog.RunInfo{Level: slog.LevelError, Details: details, Message: err.Error()} diff --git a/re/handlers.go b/re/handlers.go index 607a57d78..f74c2b487 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -55,6 +55,12 @@ func (re *re) Handle(msg *messaging.Message) error { for _, r := range page.Rules { if matchSubject(msg.Subtopic, r.InputTopic) { + if workerStatus := re.workerMgr.GetWorkerStatus(r.ID); workerStatus != nil { + if processing, ok := workerStatus["processing"].(bool); ok && processing { + re.updateRuleExecutionStatus(ctx, r.ID, QueuedStatus, nil) + } + } + if !re.workerMgr.SendMessage(msg, r) { re.runInfo <- pkglog.RunInfo{ Level: slog.LevelWarn, @@ -101,7 +107,7 @@ func (re *re) process(ctx context.Context, r Rule, msg *messaging.Message) pkglo slog.Time("exec_time", time.Now().UTC()), } - re.updateRuleExecutionStatus(ctx, r.ID, InProgressStatus, "") + re.updateRuleExecutionStatus(ctx, r.ID, InProgressStatus, nil) var result pkglog.RunInfo switch r.Logic.Type { @@ -131,12 +137,24 @@ func (re *re) process(ctx context.Context, r Rule, msg *messaging.Message) pkglo errorMsg = result.Message } - re.updateRuleExecutionStatus(ctx, r.ID, execStatus, errorMsg) + var execError error + if errorMsg != "" { + execError = fmt.Errorf("%s", errorMsg) + } + + re.updateRuleExecutionStatus(ctx, r.ID, execStatus, execError) return result } func (re *re) handleOutput(ctx context.Context, o Runnable, r Rule, msg *messaging.Message, val any) error { + // Check if context is cancelled before handling output + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + switch o := o.(type) { case *outputs.Alarm: o.AlarmsPub = re.alarmsPub @@ -231,6 +249,14 @@ func (re *re) StartScheduler(ctx context.Context) error { Created: due.Unix(), } + // Check worker status for scheduled rules too + if workerStatus := re.workerMgr.GetWorkerStatus(r.ID); workerStatus != nil { + if processing, ok := workerStatus["processing"].(bool); ok && processing { + // Worker is busy, scheduled message will be queued + re.updateRuleExecutionStatus(ctx, r.ID, QueuedStatus, nil) + } + } + if !re.workerMgr.SendMessage(msg, r) { re.runInfo <- pkglog.RunInfo{ Level: slog.LevelWarn, diff --git a/re/lua.go b/re/lua.go index 21e26592f..332125c1c 100644 --- a/re/lua.go +++ b/re/lua.go @@ -32,6 +32,12 @@ import ( const payloadKey = "payload" func (re *re) processLua(ctx context.Context, details []slog.Attr, r Rule, msg *messaging.Message) pkglog.RunInfo { + select { + case <-ctx.Done(): + return pkglog.RunInfo{Level: slog.LevelError, Details: details, Message: "rule execution was cancelled"} + default: + } + l := lua.NewState() defer l.Close() preload(l) diff --git a/re/middleware/authorization.go b/re/middleware/authorization.go index 3dfe56fe6..a513042e7 100644 --- a/re/middleware/authorization.go +++ b/re/middleware/authorization.go @@ -258,6 +258,30 @@ func (am *authorizationMiddleware) DisableRule(ctx context.Context, session auth return am.svc.DisableRule(ctx, session, id) } +func (am *authorizationMiddleware) AbortRuleExecution(ctx context.Context, session authn.Session, id string) error { + if err := am.authorize(ctx, smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Object: session.DomainID, + ObjectType: policies.DomainType, + Permission: policies.MembershipPermission, + }); err != nil { + return errors.Wrap(errDomainUpdateRules, err) + } + + params := map[string]any{ + "entity_id": id, + } + + if err := am.callOut(ctx, session, re.OpAbortRuleExecution, params); err != nil { + return err + } + + return am.svc.AbortRuleExecution(ctx, session, id) +} + func (am *authorizationMiddleware) StartScheduler(ctx context.Context) error { return am.svc.StartScheduler(ctx) } diff --git a/re/middleware/logging.go b/re/middleware/logging.go index 948e528ed..6f5c7f1bf 100644 --- a/re/middleware/logging.go +++ b/re/middleware/logging.go @@ -200,6 +200,23 @@ func (lm *loggingMiddleware) DisableRule(ctx context.Context, session authn.Sess return lm.svc.DisableRule(ctx, session, id) } +func (lm *loggingMiddleware) AbortRuleExecution(ctx context.Context, session authn.Session, id string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.String("rule_id", id), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Abort rule execution failed", args...) + return + } + lm.logger.Info("Abort rule execution completed successfully", args...) + }(time.Now()) + return lm.svc.AbortRuleExecution(ctx, session, id) +} + func (lm *loggingMiddleware) StartScheduler(ctx context.Context) (err error) { defer func(begin time.Time) { args := []any{ @@ -238,5 +255,5 @@ func (lm *loggingMiddleware) Handle(msg *messaging.Message) (err error) { } func (lm *loggingMiddleware) Cancel() error { - return lm.Cancel() + return lm.svc.Cancel() } diff --git a/re/rule.go b/re/rule.go index a9e658d36..8286f7a7f 100644 --- a/re/rule.go +++ b/re/rule.go @@ -21,15 +21,16 @@ const ( ) const ( - OpAddRule = "OpAddRule" - OpViewRule = "OpViewRule" - OpUpdateRule = "OpUpdateRule" - OpUpdateRuleTags = "OpUpdateRuleTags" - OpUpdateRuleSchedule = "OpUpdateRuleSchedule" - OpListRules = "OpListRules" - OpRemoveRule = "OpRemoveRule" - OpEnableRule = "OpEnableRule" - OpDisableRule = "OpDisableRule" + OpAddRule = "OpAddRule" + OpViewRule = "OpViewRule" + OpUpdateRule = "OpUpdateRule" + OpUpdateRuleTags = "OpUpdateRuleTags" + OpUpdateRuleSchedule = "OpUpdateRuleSchedule" + OpListRules = "OpListRules" + OpRemoveRule = "OpRemoveRule" + OpEnableRule = "OpEnableRule" + OpDisableRule = "OpDisableRule" + OpAbortRuleExecution = "OpAbortRuleExecution" ) type ( @@ -54,39 +55,38 @@ var outputRegistry = map[outputs.OutputType]func() Runnable{ } type Rule struct { - ID string `json:"id"` - Name string `json:"name"` - DomainID string `json:"domain"` - Metadata Metadata `json:"metadata,omitempty"` - Tags []string `json:"tags,omitempty"` - InputChannel string `json:"input_channel"` - InputTopic string `json:"input_topic"` - Logic Script `json:"logic"` - Outputs Outputs `json:"outputs,omitempty"` - Schedule schedule.Schedule `json:"schedule,omitempty"` - Status Status `json:"status"` - // Last execution tracking - LastRunStatus ExecutionStatus `json:"last_run_status"` - LastRunTime *time.Time `json:"last_run_time,omitempty"` - LastRunErrorMessage string `json:"last_run_error_message,omitempty"` - ExecutionCount uint64 `json:"execution_count"` - CreatedAt time.Time `json:"created_at"` - CreatedBy string `json:"created_by"` - UpdatedAt time.Time `json:"updated_at"` - UpdatedBy string `json:"updated_by"` + ID string `json:"id"` + Name string `json:"name"` + DomainID string `json:"domain"` + Metadata Metadata `json:"metadata,omitempty"` + Tags []string `json:"tags,omitempty"` + InputChannel string `json:"input_channel"` + InputTopic string `json:"input_topic"` + Logic Script `json:"logic"` + Outputs Outputs `json:"outputs,omitempty"` + Schedule schedule.Schedule `json:"schedule,omitempty"` + Status Status `json:"status"` + LastRunStatus ExecutionStatus `json:"last_run_status"` + LastRunTime *time.Time `json:"last_run_time,omitempty"` + LastRunErrorMessage string `json:"last_run_error_message,omitempty"` + ExecutionCount uint64 `json:"execution_count"` + CreatedAt time.Time `json:"created_at"` + CreatedBy string `json:"created_by"` + UpdatedAt time.Time `json:"updated_at"` + UpdatedBy string `json:"updated_by"` } // EventEncode converts a Rule struct to map[string]any at event producer. func (r Rule) EventEncode() (map[string]any, error) { m := map[string]any{ - "id": r.ID, - "name": r.Name, - "created_at": r.CreatedAt.Format(time.RFC3339Nano), - "created_by": r.CreatedBy, - "schedule": r.Schedule.EventEncode(), - "status": r.Status.String(), - "last_run_status": r.LastRunStatus.String(), - "execution_count": r.ExecutionCount, + "id": r.ID, + "name": r.Name, + "created_at": r.CreatedAt.Format(time.RFC3339Nano), + "created_by": r.CreatedBy, + "schedule": r.Schedule.EventEncode(), + "status": r.Status.String(), + "last_run_status": r.LastRunStatus.String(), + "execution_count": r.ExecutionCount, } if r.Name != "" { @@ -201,12 +201,12 @@ type PageMeta struct { // EventEncode converts a PageMeta struct to map[string]any. func (pm PageMeta) EventEncode() map[string]any { m := map[string]any{ - "total": pm.Total, - "offset": pm.Offset, - "limit": pm.Limit, - "status": pm.Status.String(), + "total": pm.Total, + "offset": pm.Offset, + "limit": pm.Limit, + "status": pm.Status.String(), "last_run_status": pm.LastRunStatus.String(), - "domain_id": pm.Domain, + "domain_id": pm.Domain, } if pm.Dir != "" { @@ -261,6 +261,7 @@ type Service interface { RemoveRule(ctx context.Context, session authn.Session, id string) error EnableRule(ctx context.Context, session authn.Session, id string) (Rule, error) DisableRule(ctx context.Context, session authn.Session, id string) (Rule, error) + AbortRuleExecution(ctx context.Context, session authn.Session, id string) error StartScheduler(ctx context.Context) error } diff --git a/re/service.go b/re/service.go index aed90a108..dec2d1622 100644 --- a/re/service.go +++ b/re/service.go @@ -220,13 +220,28 @@ func (re *re) Cancel() error { return re.workerMgr.StopAll() } -func (re *re) updateRuleExecutionStatus(ctx context.Context, ruleID string, status ExecutionStatus, errorMessage string) { +func (re *re) AbortRuleExecution(ctx context.Context, session authn.Session, id string) error { + if _, err := re.repo.ViewRule(ctx, id); err != nil { + return errors.Wrap(svcerr.ErrViewEntity, err) + } + + if re.workerMgr != nil { + re.workerMgr.AbortRule(id) + } + + return nil +} + +func (re *re) updateRuleExecutionStatus(ctx context.Context, ruleID string, status ExecutionStatus, err error) { now := time.Now().UTC() rule := Rule{ ID: ruleID, LastRunStatus: status, LastRunTime: &now, - LastRunErrorMessage: errorMessage, + } + + if err != nil { + rule.LastRunErrorMessage = err.Error() } if status == SuccessStatus || status == PartialSuccessStatus { diff --git a/re/worker.go b/re/worker.go index 6a1c0938d..1d1a0a22f 100644 --- a/re/worker.go +++ b/re/worker.go @@ -5,6 +5,7 @@ package re import ( "context" + "fmt" "sync" "sync/atomic" @@ -20,20 +21,23 @@ type WorkerMessage struct { // RuleWorker manages execution of a single rule in its own goroutine. type RuleWorker struct { - rule Rule - engine *re - msgChan chan WorkerMessage - ctx context.Context - running int32 + rule Rule + engine *re + msgChan chan WorkerMessage + ctx context.Context + cancel context.CancelFunc + running int32 + maxQueueSize int } // NewRuleWorker creates a new rule worker for the given rule. func NewRuleWorker(rule Rule, engine *re) *RuleWorker { return &RuleWorker{ - rule: rule, - engine: engine, - msgChan: make(chan WorkerMessage, 100), - running: 0, // 0 = not running, 1 = running + rule: rule, + engine: engine, + msgChan: make(chan WorkerMessage, 100), + running: 0, // 0 = not running, 1 = running + maxQueueSize: 100, } } @@ -43,7 +47,7 @@ func (w *RuleWorker) Start(ctx context.Context) { return } - w.ctx = ctx + w.ctx, w.cancel = context.WithCancel(ctx) go func() { defer atomic.StoreInt32(&w.running, 0) w.run(w.ctx) @@ -56,15 +60,39 @@ func (w *RuleWorker) Stop() error { return nil } + if w.cancel != nil { + w.cancel() + } + return nil } +// AbortExecution aborts the current execution if the worker is running. +func (w *RuleWorker) AbortExecution(ctx context.Context) { + if atomic.LoadInt32(&w.running) == 1 { + if w.cancel != nil { + w.cancel() + } + + w.engine.updateRuleExecutionStatus(ctx, w.rule.ID, AbortedStatus, fmt.Errorf("rule execution manually aborted")) + } +} + // Send sends a message to the worker for processing. func (w *RuleWorker) Send(msg WorkerMessage) bool { if atomic.LoadInt32(&w.running) == 0 { return false } + queueLen := len(w.msgChan) + if queueLen >= w.maxQueueSize { + return false + } + + if queueLen > 0 { + w.engine.updateRuleExecutionStatus(context.Background(), msg.Rule.ID, QueuedStatus, nil) + } + select { case w.msgChan <- msg: return true @@ -78,6 +106,11 @@ func (w *RuleWorker) IsRunning() bool { return atomic.LoadInt32(&w.running) == 1 } +// GetQueueLength returns the current number of queued messages. +func (w *RuleWorker) GetQueueLength() int { + return len(w.msgChan) +} + // GetRule returns the current rule configuration. func (w *RuleWorker) GetRule() Rule { return w.rule @@ -103,7 +136,21 @@ func (w *RuleWorker) processMessage(ctx context.Context, workerMsg WorkerMessage return } - runInfo := w.engine.process(ctx, currentRule, workerMsg.Message) + w.engine.updateRuleExecutionStatus(ctx, currentRule.ID, InProgressStatus, nil) + + select { + case <-w.ctx.Done(): + w.engine.updateRuleExecutionStatus(ctx, currentRule.ID, AbortedStatus, w.ctx.Err()) + return + default: + } + + runInfo := w.engine.process(w.ctx, currentRule, workerMsg.Message) + + if w.ctx.Err() == context.Canceled { + w.engine.updateRuleExecutionStatus(ctx, currentRule.ID, AbortedStatus, w.ctx.Err()) + return + } select { case w.engine.runInfo <- runInfo: @@ -120,6 +167,8 @@ const ( CmdStopAll CmdCount CmdList + CmdAbort + CmdGetStatus ) func (c WorkerCommandType) String() string { @@ -136,6 +185,10 @@ func (c WorkerCommandType) String() string { return "count" case CmdList: return "list" + case CmdAbort: + return "abort" + case CmdGetStatus: + return "get_status" default: return "unknown" } @@ -230,6 +283,8 @@ func (wm *WorkerManager) handleCommand(cmd WorkerManagerCommand) { default: } } + case CmdAbort: + wm.abortWorker(cmd.RuleID) case CmdStopAll: if err := wm.stopAll(); err != nil { select { @@ -257,6 +312,19 @@ func (wm *WorkerManager) handleCommand(cmd WorkerManagerCommand) { if cmd.Response != nil { cmd.Response <- ruleIDs } + case CmdGetStatus: + wm.mu.RLock() + var status map[string]interface{} + if worker, exists := wm.workers[cmd.RuleID]; exists { + status = map[string]interface{}{ + "running": worker.IsRunning(), + "queue_length": worker.GetQueueLength(), + } + } + wm.mu.RUnlock() + if cmd.Response != nil { + cmd.Response <- status + } } } @@ -339,6 +407,32 @@ func (wm *WorkerManager) updateWorker(rule Rule) error { return nil } +func (wm *WorkerManager) abortWorker(ruleID string) { + wm.mu.Lock() + defer wm.mu.Unlock() + + worker, exists := wm.workers[ruleID] + if !exists { + return + } + + worker.AbortExecution(wm.ctx) + + rule := worker.GetRule() + if rule.Status == EnabledStatus { + newWorker := NewRuleWorker(rule, wm.engine) + newWorker.Start(wm.ctx) + wm.workers[ruleID] = newWorker + } + + if err := worker.Stop(); err != nil { + select { + case wm.errorCh <- err: + default: + } + } +} + func (wm *WorkerManager) sendMessage(msg *messaging.Message, rule Rule) bool { wm.mu.RLock() worker, ok := wm.workers[rule.ID] @@ -500,3 +594,40 @@ func (wm *WorkerManager) RefreshWorkers(ctx context.Context, rules []Rule) { } } } + +func (wm *WorkerManager) AbortRule(ruleID string) { + if atomic.LoadInt32(&wm.running) == 0 { + return + } + + cmd := WorkerManagerCommand{ + Type: CmdAbort, + RuleID: ruleID, + } + + wm.commandCh <- cmd +} + +func (wm *WorkerManager) GetWorkerStatus(ruleID string) map[string]interface{} { + if atomic.LoadInt32(&wm.running) == 0 { + return nil + } + + responseCh := make(chan interface{}, 1) + cmd := WorkerManagerCommand{ + Type: CmdGetStatus, + RuleID: ruleID, + Response: responseCh, + } + + select { + case wm.commandCh <- cmd: + if result := <-responseCh; result != nil { + if status, ok := result.(map[string]interface{}); ok { + return status + } + } + default: + } + return nil +} From 5dfcb1bd21b3fcdb3ffe62083abf452d041f79d2 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 1 Oct 2025 01:43:50 +0300 Subject: [PATCH 16/25] update tests Signed-off-by: nyagamunene --- re/execution_status_test.go | 167 +++++++++++++++++++++++++----------- re/mocks/repository.go | 57 ++++++++++++ re/mocks/service.go | 63 ++++++++++++++ re/service_test.go | 44 ++++++++++ 4 files changed, 281 insertions(+), 50 deletions(-) diff --git a/re/execution_status_test.go b/re/execution_status_test.go index 69e27aecc..883c8c3b9 100644 --- a/re/execution_status_test.go +++ b/re/execution_status_test.go @@ -4,9 +4,9 @@ package re import ( - "testing" +"testing" - "github.com/stretchr/testify/assert" +"github.com/stretchr/testify/assert" ) func TestExecutionStatusString(t *testing.T) { @@ -15,6 +15,11 @@ func TestExecutionStatusString(t *testing.T) { status ExecutionStatus want string }{ + { + desc: "Never Run status", + status: NeverRunStatus, + want: NeverRun, + }, { desc: "Success status", status: SuccessStatus, @@ -26,9 +31,9 @@ func TestExecutionStatusString(t *testing.T) { want: Failure, }, { - desc: "Aborted status", - status: AbortedStatus, - want: Aborted, + desc: "Partial Success status", + status: PartialSuccessStatus, + want: PartialSuccess, }, { desc: "Queued status", @@ -41,25 +46,25 @@ func TestExecutionStatusString(t *testing.T) { want: InProgress, }, { - desc: "Partial Success status", - status: PartialSuccessStatus, - want: PartialSuccess, + desc: "Aborted status", + status: AbortedStatus, + want: Aborted, }, { - desc: "Never Run status", - status: NeverRunStatus, - want: NeverRun, + desc: "Unknown status", + status: UnknownExecutionStatus, + want: UnknownExecution, }, { - desc: "Unknown status", + desc: "Invalid status (out of range)", status: ExecutionStatus(99), - want: UnknownExec, + want: UnknownExecution, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - got := tc.status.String() +got := tc.status.String() assert.Equal(t, tc.want, got) }) } @@ -72,6 +77,11 @@ func TestToExecutionStatus(t *testing.T) { want ExecutionStatus wantErr bool }{ + { + desc: "Never Run status", + status: NeverRun, + want: NeverRunStatus, + }, { desc: "Success status", status: Success, @@ -83,9 +93,9 @@ func TestToExecutionStatus(t *testing.T) { want: FailureStatus, }, { - desc: "Aborted status", - status: Aborted, - want: AbortedStatus, + desc: "Partial Success status", + status: PartialSuccess, + want: PartialSuccessStatus, }, { desc: "Queued status", @@ -98,37 +108,38 @@ func TestToExecutionStatus(t *testing.T) { want: InProgressStatus, }, { - desc: "Partial Success status", - status: PartialSuccess, - want: PartialSuccessStatus, + desc: "Aborted status", + status: Aborted, + want: AbortedStatus, }, { - desc: "Never Run status", - status: NeverRun, - want: NeverRunStatus, + desc: "Unknown status string", + status: UnknownExecution, + want: UnknownExecutionStatus, }, { - desc: "Empty string defaults to Never Run", + desc: "Empty string defaults to Unknown", status: "", - want: NeverRunStatus, + want: UnknownExecutionStatus, }, { desc: "Invalid status", status: "invalid", + want: UnknownExecutionStatus, wantErr: true, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - got, err := ToExecutionStatus(tc.status) - if tc.wantErr { - assert.Error(t, err) - return - } - assert.NoError(t, err) - assert.Equal(t, tc.want, got) - }) +got, err := ToExecutionStatus(tc.status) +if tc.wantErr { +assert.Error(t, err) +} else { +assert.NoError(t, err) +} +assert.Equal(t, tc.want, got) +}) } } @@ -138,6 +149,11 @@ func TestExecutionStatusMarshalJSON(t *testing.T) { status ExecutionStatus want string }{ + { + desc: "Never Run status", + status: NeverRunStatus, + want: `"never_run"`, + }, { desc: "Success status", status: SuccessStatus, @@ -149,15 +165,35 @@ func TestExecutionStatusMarshalJSON(t *testing.T) { want: `"failure"`, }, { - desc: "Never Run status", - status: NeverRunStatus, - want: `"never_run"`, + desc: "Partial Success status", + status: PartialSuccessStatus, + want: `"partial_success"`, + }, + { + desc: "Queued status", + status: QueuedStatus, + want: `"queued"`, + }, + { + desc: "In Progress status", + status: InProgressStatus, + want: `"in_progress"`, + }, + { + desc: "Aborted status", + status: AbortedStatus, + want: `"aborted"`, + }, + { + desc: "Unknown status", + status: UnknownExecutionStatus, + want: `"unknown"`, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - got, err := tc.status.MarshalJSON() +got, err := tc.status.MarshalJSON() assert.NoError(t, err) assert.Equal(t, tc.want, string(got)) }) @@ -171,6 +207,11 @@ func TestExecutionStatusUnmarshalJSON(t *testing.T) { want ExecutionStatus wantErr bool }{ + { + desc: "Never Run status", + data: `"never_run"`, + want: NeverRunStatus, + }, { desc: "Success status", data: `"success"`, @@ -182,27 +223,53 @@ func TestExecutionStatusUnmarshalJSON(t *testing.T) { want: FailureStatus, }, { - desc: "Never Run status", - data: `"never_run"`, - want: NeverRunStatus, + desc: "Partial Success status", + data: `"partial_success"`, + want: PartialSuccessStatus, + }, + { + desc: "Queued status", + data: `"queued"`, + want: QueuedStatus, + }, + { + desc: "In Progress status", + data: `"in_progress"`, + want: InProgressStatus, + }, + { + desc: "Aborted status", + data: `"aborted"`, + want: AbortedStatus, + }, + { + desc: "Unknown status string", + data: `"unknown"`, + want: UnknownExecutionStatus, + }, + { + desc: "Empty string", + data: `""`, + want: UnknownExecutionStatus, }, { desc: "Invalid status", data: `"invalid"`, + want: UnknownExecutionStatus, wantErr: true, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - var status ExecutionStatus - err := status.UnmarshalJSON([]byte(tc.data)) - if tc.wantErr { - assert.Error(t, err) - return - } - assert.NoError(t, err) - assert.Equal(t, tc.want, status) - }) +var status ExecutionStatus +err := status.UnmarshalJSON([]byte(tc.data)) +if tc.wantErr { +assert.Error(t, err) +} else { +assert.NoError(t, err) +} +assert.Equal(t, tc.want, status) +}) } -} \ No newline at end of file +} diff --git a/re/mocks/repository.go b/re/mocks/repository.go index a6458c5db..c0138fe09 100644 --- a/re/mocks/repository.go +++ b/re/mocks/repository.go @@ -369,6 +369,63 @@ func (_c *Repository_UpdateRuleDue_Call) RunAndReturn(run func(ctx context.Conte return _c } +// UpdateRuleExecutionStatus provides a mock function for the type Repository +func (_mock *Repository) UpdateRuleExecutionStatus(ctx context.Context, r re.Rule) error { + ret := _mock.Called(ctx, r) + + if len(ret) == 0 { + panic("no return value specified for UpdateRuleExecutionStatus") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, re.Rule) error); ok { + r0 = returnFunc(ctx, r) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_UpdateRuleExecutionStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRuleExecutionStatus' +type Repository_UpdateRuleExecutionStatus_Call struct { + *mock.Call +} + +// UpdateRuleExecutionStatus is a helper method to define mock.On call +// - ctx context.Context +// - r re.Rule +func (_e *Repository_Expecter) UpdateRuleExecutionStatus(ctx interface{}, r interface{}) *Repository_UpdateRuleExecutionStatus_Call { + return &Repository_UpdateRuleExecutionStatus_Call{Call: _e.mock.On("UpdateRuleExecutionStatus", ctx, r)} +} + +func (_c *Repository_UpdateRuleExecutionStatus_Call) Run(run func(ctx context.Context, r re.Rule)) *Repository_UpdateRuleExecutionStatus_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 re.Rule + if args[1] != nil { + arg1 = args[1].(re.Rule) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateRuleExecutionStatus_Call) Return(err error) *Repository_UpdateRuleExecutionStatus_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_UpdateRuleExecutionStatus_Call) RunAndReturn(run func(ctx context.Context, r re.Rule) error) *Repository_UpdateRuleExecutionStatus_Call { + _c.Call.Return(run) + return _c +} + // UpdateRuleSchedule provides a mock function for the type Repository func (_mock *Repository) UpdateRuleSchedule(ctx context.Context, r re.Rule) (re.Rule, error) { ret := _mock.Called(ctx, r) diff --git a/re/mocks/service.go b/re/mocks/service.go index 5a3d500f8..d6427c398 100644 --- a/re/mocks/service.go +++ b/re/mocks/service.go @@ -43,6 +43,69 @@ func (_m *Service) EXPECT() *Service_Expecter { return &Service_Expecter{mock: &_m.Mock} } +// AbortRuleExecution provides a mock function for the type Service +func (_mock *Service) AbortRuleExecution(ctx context.Context, session authn.Session, id string) error { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for AbortRuleExecution") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_AbortRuleExecution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AbortRuleExecution' +type Service_AbortRuleExecution_Call struct { + *mock.Call +} + +// AbortRuleExecution is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) AbortRuleExecution(ctx interface{}, session interface{}, id interface{}) *Service_AbortRuleExecution_Call { + return &Service_AbortRuleExecution_Call{Call: _e.mock.On("AbortRuleExecution", ctx, session, id)} +} + +func (_c *Service_AbortRuleExecution_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_AbortRuleExecution_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_AbortRuleExecution_Call) Return(err error) *Service_AbortRuleExecution_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_AbortRuleExecution_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) error) *Service_AbortRuleExecution_Call { + _c.Call.Return(run) + return _c +} + // AddRule provides a mock function for the type Service func (_mock *Service) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { ret := _mock.Called(ctx, session, r) diff --git a/re/service_test.go b/re/service_test.go index 03c4c098f..e4f8e55a6 100644 --- a/re/service_test.go +++ b/re/service_test.go @@ -636,6 +636,50 @@ func TestDisableRule(t *testing.T) { } } +func TestAbortRuleExecution(t *testing.T) { + cases := []struct { + desc string + session authn.Session + id string + rule re.Rule + repoErr error + err error + }{ + { + desc: "abort rule execution successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: ruleID, + rule: re.Rule{ID: ruleID}, + repoErr: nil, + err: nil, + }, + { + desc: "abort rule execution with non-existent rule", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: "non-existent-rule", + rule: re.Rule{}, + repoErr: svcerr.ErrNotFound, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo, 100)) + repo.On("ViewRule", mock.Anything, tc.id).Return(tc.rule, tc.repoErr) + err := svc.AbortRuleExecution(context.Background(), tc.session, tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repo.AssertExpectations(t) + }) + } +} + func TestHandle(t *testing.T) { svc, repo, pubmocks, _ := newService(t, make(chan pkglog.RunInfo, 100)) now := time.Now() From a43ec22659edbc6f3843fd7d490f82e29e6704fa Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 1 Oct 2025 14:44:01 +0300 Subject: [PATCH 17/25] remove comments Signed-off-by: nyagamunene --- re/execution_status_test.go | 42 ++++++++++++++++++------------------- re/handlers.go | 4 +--- re/postgres/rule.go | 4 ++-- re/rule.go | 20 +++++++++--------- re/service.go | 6 +++--- 5 files changed, 37 insertions(+), 39 deletions(-) diff --git a/re/execution_status_test.go b/re/execution_status_test.go index 883c8c3b9..d4ad404b3 100644 --- a/re/execution_status_test.go +++ b/re/execution_status_test.go @@ -4,9 +4,9 @@ package re import ( -"testing" + "testing" -"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" ) func TestExecutionStatusString(t *testing.T) { @@ -64,7 +64,7 @@ func TestExecutionStatusString(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { -got := tc.status.String() + got := tc.status.String() assert.Equal(t, tc.want, got) }) } @@ -132,14 +132,14 @@ func TestToExecutionStatus(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { -got, err := ToExecutionStatus(tc.status) -if tc.wantErr { -assert.Error(t, err) -} else { -assert.NoError(t, err) -} -assert.Equal(t, tc.want, got) -}) + got, err := ToExecutionStatus(tc.status) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.want, got) + }) } } @@ -193,7 +193,7 @@ func TestExecutionStatusMarshalJSON(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { -got, err := tc.status.MarshalJSON() + got, err := tc.status.MarshalJSON() assert.NoError(t, err) assert.Equal(t, tc.want, string(got)) }) @@ -262,14 +262,14 @@ func TestExecutionStatusUnmarshalJSON(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { -var status ExecutionStatus -err := status.UnmarshalJSON([]byte(tc.data)) -if tc.wantErr { -assert.Error(t, err) -} else { -assert.NoError(t, err) -} -assert.Equal(t, tc.want, status) -}) + var status ExecutionStatus + err := status.UnmarshalJSON([]byte(tc.data)) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.want, status) + }) } } diff --git a/re/handlers.go b/re/handlers.go index f74c2b487..1cf10c036 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -139,7 +139,7 @@ func (re *re) process(ctx context.Context, r Rule, msg *messaging.Message) pkglo var execError error if errorMsg != "" { - execError = fmt.Errorf("%s", errorMsg) + execError = errors.New(errorMsg) } re.updateRuleExecutionStatus(ctx, r.ID, execStatus, execError) @@ -249,10 +249,8 @@ func (re *re) StartScheduler(ctx context.Context) error { Created: due.Unix(), } - // Check worker status for scheduled rules too if workerStatus := re.workerMgr.GetWorkerStatus(r.ID); workerStatus != nil { if processing, ok := workerStatus["processing"].(bool); ok && processing { - // Worker is busy, scheduled message will be queued re.updateRuleExecutionStatus(ctx, r.ID, QueuedStatus, nil) } } diff --git a/re/postgres/rule.go b/re/postgres/rule.go index 4551269d8..8655a772e 100644 --- a/re/postgres/rule.go +++ b/re/postgres/rule.go @@ -59,13 +59,13 @@ func ruleToDb(r re.Rule) (dbRule, error) { if !r.Schedule.Time.IsZero() { t.Valid = true } - + lastRunTime := sql.NullTime{} if r.LastRunTime != nil && !r.LastRunTime.IsZero() { lastRunTime.Time = *r.LastRunTime lastRunTime.Valid = true } - + var tags pgtype.TextArray if err := tags.Set(r.Tags); err != nil { return dbRule{}, err diff --git a/re/rule.go b/re/rule.go index 8286f7a7f..8c3ca71f6 100644 --- a/re/rule.go +++ b/re/rule.go @@ -21,16 +21,16 @@ const ( ) const ( - OpAddRule = "OpAddRule" - OpViewRule = "OpViewRule" - OpUpdateRule = "OpUpdateRule" - OpUpdateRuleTags = "OpUpdateRuleTags" - OpUpdateRuleSchedule = "OpUpdateRuleSchedule" - OpListRules = "OpListRules" - OpRemoveRule = "OpRemoveRule" - OpEnableRule = "OpEnableRule" - OpDisableRule = "OpDisableRule" - OpAbortRuleExecution = "OpAbortRuleExecution" + OpAddRule = "OpAddRule" + OpViewRule = "OpViewRule" + OpUpdateRule = "OpUpdateRule" + OpUpdateRuleTags = "OpUpdateRuleTags" + OpUpdateRuleSchedule = "OpUpdateRuleSchedule" + OpListRules = "OpListRules" + OpRemoveRule = "OpRemoveRule" + OpEnableRule = "OpEnableRule" + OpDisableRule = "OpDisableRule" + OpAbortRuleExecution = "OpAbortRuleExecution" ) type ( diff --git a/re/service.go b/re/service.go index dec2d1622..d8f80a9cf 100644 --- a/re/service.go +++ b/re/service.go @@ -235,9 +235,9 @@ func (re *re) AbortRuleExecution(ctx context.Context, session authn.Session, id func (re *re) updateRuleExecutionStatus(ctx context.Context, ruleID string, status ExecutionStatus, err error) { now := time.Now().UTC() rule := Rule{ - ID: ruleID, - LastRunStatus: status, - LastRunTime: &now, + ID: ruleID, + LastRunStatus: status, + LastRunTime: &now, } if err != nil { From 34b5f077c4ac7a1df0e7e3ccf827f294431326a9 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 15 Oct 2025 12:27:28 +0300 Subject: [PATCH 18/25] fix error wrap Signed-off-by: nyagamunene --- re/api/endpoints.go | 19 ++++++++ re/api/requests.go | 12 +++++ re/api/responses.go | 17 +++++++ re/api/transport.go | 13 ++++++ re/events/streams.go | 4 ++ re/middleware/authorization.go | 16 +++++++ re/middleware/logging.go | 21 +++++++++ re/mocks/service.go | 72 +++++++++++++++++++++++++++++ re/postgres/repository.go | 11 ++++- re/rule.go | 9 +++- re/service.go | 84 ++++++++++++++++++++++++++++++++-- re/worker.go | 14 +++--- 12 files changed, 280 insertions(+), 12 deletions(-) diff --git a/re/api/endpoints.go b/re/api/endpoints.go index f0773bf67..7c87b94b6 100644 --- a/re/api/endpoints.go +++ b/re/api/endpoints.go @@ -222,3 +222,22 @@ func abortRuleExecutionEndpoint(s re.Service) endpoint.Endpoint { return abortRuleExecutionRes{}, nil } } + +func getRuleExecutionStatusEndpoint(s re.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(getRuleExecutionStatusReq) + if err := req.validate(); err != nil { + return getRuleExecutionStatusRes{}, err + } + status, err := s.GetRuleExecutionStatus(ctx, session, req.id) + if err != nil { + return getRuleExecutionStatusRes{}, err + } + return getRuleExecutionStatusRes{RuleExecutionStatus: status}, nil + } +} diff --git a/re/api/requests.go b/re/api/requests.go index b52b9fec3..9da67ca9e 100644 --- a/re/api/requests.go +++ b/re/api/requests.go @@ -146,3 +146,15 @@ func (req abortRuleExecutionReq) validate() error { return nil } + +type getRuleExecutionStatusReq struct { + id string +} + +func (req getRuleExecutionStatusReq) validate() error { + if req.id == "" { + return apiutil.ErrMissingID + } + + return nil +} diff --git a/re/api/responses.go b/re/api/responses.go index 75951ad5b..7f6d53009 100644 --- a/re/api/responses.go +++ b/re/api/responses.go @@ -19,6 +19,7 @@ var ( _ supermq.Response = (*updateRuleRes)(nil) _ supermq.Response = (*deleteRuleRes)(nil) _ supermq.Response = (*abortRuleExecutionRes)(nil) + _ supermq.Response = (*getRuleExecutionStatusRes)(nil) ) type pageRes struct { @@ -151,3 +152,19 @@ func (res abortRuleExecutionRes) Headers() map[string]string { func (res abortRuleExecutionRes) Empty() bool { return true } + +type getRuleExecutionStatusRes struct { + re.RuleExecutionStatus `json:",inline"` +} + +func (res getRuleExecutionStatusRes) Code() int { + return http.StatusOK +} + +func (res getRuleExecutionStatusRes) Headers() map[string]string { + return map[string]string{} +} + +func (res getRuleExecutionStatusRes) Empty() bool { + return false +} diff --git a/re/api/transport.go b/re/api/transport.go index 5a797276b..afafe9547 100644 --- a/re/api/transport.go +++ b/re/api/transport.go @@ -107,6 +107,13 @@ func MakeHandler(svc re.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, l api.EncodeResponse, opts..., ), "abort_rule_execution").ServeHTTP) + + r.Get("/execution-status", otelhttp.NewHandler(kithttp.NewServer( + getRuleExecutionStatusEndpoint(svc), + decodeGetRuleExecutionStatusRequest, + api.EncodeResponse, + opts..., + ), "get_rule_execution_status").ServeHTTP) }) }) }) @@ -257,3 +264,9 @@ func decodeAbortRuleExecutionRequest(_ context.Context, r *http.Request) (any, e return abortRuleExecutionReq{id: id}, nil } + +func decodeGetRuleExecutionStatusRequest(_ context.Context, r *http.Request) (any, error) { + id := chi.URLParam(r, ruleIdKey) + + return getRuleExecutionStatusReq{id: id}, nil +} diff --git a/re/events/streams.go b/re/events/streams.go index b253c05d0..c3f15148a 100644 --- a/re/events/streams.go +++ b/re/events/streams.go @@ -199,6 +199,10 @@ func (es *eventStore) AbortRuleExecution(ctx context.Context, session authn.Sess return nil } +func (es *eventStore) GetRuleExecutionStatus(ctx context.Context, session authn.Session, id string) (re.RuleExecutionStatus, error) { + return es.svc.GetRuleExecutionStatus(ctx, session, id) +} + func (es *eventStore) StartScheduler(ctx context.Context) error { return es.svc.StartScheduler(ctx) } diff --git a/re/middleware/authorization.go b/re/middleware/authorization.go index a513042e7..8cffc4068 100644 --- a/re/middleware/authorization.go +++ b/re/middleware/authorization.go @@ -282,6 +282,22 @@ func (am *authorizationMiddleware) AbortRuleExecution(ctx context.Context, sessi return am.svc.AbortRuleExecution(ctx, session, id) } +func (am *authorizationMiddleware) GetRuleExecutionStatus(ctx context.Context, session authn.Session, id string) (re.RuleExecutionStatus, error) { + if err := am.authorize(ctx, smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Object: session.DomainID, + ObjectType: policies.DomainType, + Permission: policies.MembershipPermission, + }); err != nil { + return re.RuleExecutionStatus{}, errors.Wrap(errDomainViewRules, err) + } + + return am.svc.GetRuleExecutionStatus(ctx, session, id) +} + func (am *authorizationMiddleware) StartScheduler(ctx context.Context) error { return am.svc.StartScheduler(ctx) } diff --git a/re/middleware/logging.go b/re/middleware/logging.go index 6f5c7f1bf..305a86f5b 100644 --- a/re/middleware/logging.go +++ b/re/middleware/logging.go @@ -217,6 +217,27 @@ func (lm *loggingMiddleware) AbortRuleExecution(ctx context.Context, session aut return lm.svc.AbortRuleExecution(ctx, session, id) } +func (lm *loggingMiddleware) GetRuleExecutionStatus(ctx context.Context, session authn.Session, id string) (res re.RuleExecutionStatus, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("rule", + slog.String("id", id), + slog.Bool("worker_running", res.WorkerRunning), + slog.Int("queue_length", res.QueueLength), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Get rule execution status failed", args...) + return + } + lm.logger.Info("Get rule execution status completed successfully", args...) + }(time.Now()) + return lm.svc.GetRuleExecutionStatus(ctx, session, id) +} + func (lm *loggingMiddleware) StartScheduler(ctx context.Context) (err error) { defer func(begin time.Time) { args := []any{ diff --git a/re/mocks/service.go b/re/mocks/service.go index d6427c398..2b26f7c4e 100644 --- a/re/mocks/service.go +++ b/re/mocks/service.go @@ -366,6 +366,78 @@ func (_c *Service_EnableRule_Call) RunAndReturn(run func(ctx context.Context, se return _c } +// GetRuleExecutionStatus provides a mock function for the type Service +func (_mock *Service) GetRuleExecutionStatus(ctx context.Context, session authn.Session, id string) (re.RuleExecutionStatus, error) { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for GetRuleExecutionStatus") + } + + var r0 re.RuleExecutionStatus + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) (re.RuleExecutionStatus, error)); ok { + return returnFunc(ctx, session, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) re.RuleExecutionStatus); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Get(0).(re.RuleExecutionStatus) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = returnFunc(ctx, session, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_GetRuleExecutionStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRuleExecutionStatus' +type Service_GetRuleExecutionStatus_Call struct { + *mock.Call +} + +// GetRuleExecutionStatus is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) GetRuleExecutionStatus(ctx interface{}, session interface{}, id interface{}) *Service_GetRuleExecutionStatus_Call { + return &Service_GetRuleExecutionStatus_Call{Call: _e.mock.On("GetRuleExecutionStatus", ctx, session, id)} +} + +func (_c *Service_GetRuleExecutionStatus_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_GetRuleExecutionStatus_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_GetRuleExecutionStatus_Call) Return(ruleExecutionStatus re.RuleExecutionStatus, err error) *Service_GetRuleExecutionStatus_Call { + _c.Call.Return(ruleExecutionStatus, err) + return _c +} + +func (_c *Service_GetRuleExecutionStatus_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) (re.RuleExecutionStatus, error)) *Service_GetRuleExecutionStatus_Call { + _c.Call.Return(run) + return _c +} + // Handle provides a mock function for the type Service func (_mock *Service) Handle(msg *messaging.Message) error { ret := _mock.Called(msg) diff --git a/re/postgres/repository.go b/re/postgres/repository.go index 242a656d3..1190e52ac 100644 --- a/re/postgres/repository.go +++ b/re/postgres/repository.go @@ -108,11 +108,20 @@ func (repo *PostgresRepository) UpdateRuleExecutionStatus(ctx context.Context, r return errors.Wrap(repoerr.ErrUpdateEntity, err) } - _, err = repo.DB.NamedExecContext(ctx, q, dbr) + // Debug: Log what we're about to write to the database + fmt.Printf("[DEBUG-DB] UpdateRuleExecutionStatus: rule_id=%s, execution_count_in=%d, db_execution_count=%d, last_run_status=%s\n", + r.ID, r.ExecutionCount, dbr.ExecutionCount, r.LastRunStatus.String()) + + result, err := repo.DB.NamedExecContext(ctx, q, dbr) if err != nil { + fmt.Printf("[ERROR-DB] UpdateRuleExecutionStatus failed: rule_id=%s, error=%v\n", r.ID, err) return postgres.HandleError(repoerr.ErrUpdateEntity, err) } + rowsAffected, _ := result.RowsAffected() + fmt.Printf("[DEBUG-DB] UpdateRuleExecutionStatus completed: rule_id=%s, rows_affected=%d, execution_count=%d\n", + r.ID, rowsAffected, r.ExecutionCount) + return nil } diff --git a/re/rule.go b/re/rule.go index 8c3ca71f6..ac7cd6780 100644 --- a/re/rule.go +++ b/re/rule.go @@ -250,7 +250,13 @@ type Page struct { Rules []Rule `json:"rules"` } -type Service interface { +type RuleExecutionStatus struct { + Rule Rule `json:"rule"` + WorkerRunning bool `json:"worker_running"` + QueueLength int `json:"queue_length"` +} + +type Service interface{ messaging.MessageHandler AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, error) ViewRule(ctx context.Context, session authn.Session, id string) (Rule, error) @@ -262,6 +268,7 @@ type Service interface { EnableRule(ctx context.Context, session authn.Session, id string) (Rule, error) DisableRule(ctx context.Context, session authn.Session, id string) (Rule, error) AbortRuleExecution(ctx context.Context, session authn.Session, id string) error + GetRuleExecutionStatus(ctx context.Context, session authn.Session, id string) (RuleExecutionStatus, error) StartScheduler(ctx context.Context) error } diff --git a/re/service.go b/re/service.go index d8f80a9cf..e349c5751 100644 --- a/re/service.go +++ b/re/service.go @@ -221,17 +221,66 @@ func (re *re) Cancel() error { } func (re *re) AbortRuleExecution(ctx context.Context, session authn.Session, id string) error { - if _, err := re.repo.ViewRule(ctx, id); err != nil { + rule, err := re.repo.ViewRule(ctx, id) + if err != nil { return errors.Wrap(svcerr.ErrViewEntity, err) } + if rule.LastRunStatus != InProgressStatus && rule.LastRunStatus != QueuedStatus { + return errors.Wrap(errors.New(fmt.Sprintf("cannot abort rule with status '%s': rule must be in 'in_progress' or 'queued' status", + rule.LastRunStatus.String())), svcerr.ErrMalformedEntity) + } + if re.workerMgr != nil { + // Also check if worker actually exists and is running + workerStatus := re.workerMgr.GetWorkerStatus(id) + if workerStatus == nil { + return errors.Wrap(errors.New("no active worker found for this rule"), svcerr.ErrNotFound) + } + + running, ok := workerStatus["running"].(bool) + if !ok || !running { + return errors.Wrap(errors.New("cannot abort: worker is not currently running"), svcerr.ErrMalformedEntity) + } + + queueLen, _ := workerStatus["queue_length"].(int) + if queueLen == 0 && rule.LastRunStatus != InProgressStatus { + return errors.Wrap(errors.New("cannot abort: no execution in progress"), svcerr.ErrMalformedEntity) + } + re.workerMgr.AbortRule(id) } return nil } +func (re *re) GetRuleExecutionStatus(ctx context.Context, session authn.Session, id string) (RuleExecutionStatus, error) { + rule, err := re.repo.ViewRule(ctx, id) + if err != nil { + return RuleExecutionStatus{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + + status := RuleExecutionStatus{ + Rule: rule, + WorkerRunning: false, + QueueLength: 0, + } + + if re.workerMgr != nil { + workerStatus := re.workerMgr.GetWorkerStatus(id) + if workerStatus != nil { + if running, ok := workerStatus["running"].(bool); ok { + status.WorkerRunning = running + } + if queueLen, ok := workerStatus["queue_length"].(int); ok { + status.QueueLength = queueLen + } + } + } + + return status, nil +} + func (re *re) updateRuleExecutionStatus(ctx context.Context, ruleID string, status ExecutionStatus, err error) { now := time.Now().UTC() rule := Rule{ @@ -244,14 +293,39 @@ func (re *re) updateRuleExecutionStatus(ctx context.Context, ruleID string, stat rule.LastRunErrorMessage = err.Error() } - if status == SuccessStatus || status == PartialSuccessStatus { - currentRule, err := re.repo.ViewRule(ctx, ruleID) - if err == nil { + // Debug: Log the status being set + fmt.Printf("[DEBUG] updateRuleExecutionStatus: rule_id=%s, status=%s, has_error=%v\n", ruleID, status.String(), err != nil) + + // Always fetch the current rule to get the current execution count + currentRule, viewErr := re.repo.ViewRule(ctx, ruleID) + if viewErr != nil { + fmt.Printf("[WARN] Failed to retrieve current rule: rule_id=%s, error=%v\n", ruleID, viewErr) + // If we can't fetch the rule, set count based on status + switch status { + case SuccessStatus, PartialSuccessStatus, FailureStatus: + rule.ExecutionCount = 1 + default: + rule.ExecutionCount = 0 + } + } else { + // Start with current count + rule.ExecutionCount = currentRule.ExecutionCount + + // Add 1 if completed successfully, add 0 otherwise (preserve count) + switch status { + case SuccessStatus, PartialSuccessStatus, FailureStatus: rule.ExecutionCount = currentRule.ExecutionCount + 1 + fmt.Printf("[DEBUG] Incremented execution count: rule_id=%s, old_count=%d, new_count=%d\n", ruleID, currentRule.ExecutionCount, rule.ExecutionCount) + default: + fmt.Printf("[DEBUG] Preserving execution count: rule_id=%s, status=%s, count=%d\n", ruleID, status.String(), rule.ExecutionCount) } } + fmt.Printf("[DEBUG] About to update rule execution status in database: rule_id=%s, status=%s, execution_count=%d\n", ruleID, status.String(), rule.ExecutionCount) + if err := re.repo.UpdateRuleExecutionStatus(ctx, rule); err != nil { + fmt.Printf("[ERROR] Failed to update rule execution status in database: rule_id=%s, error=%v\n", ruleID, err) + re.runInfo <- pkglog.RunInfo{ Level: slog.LevelWarn, Message: fmt.Sprintf("failed to update rule execution status: %s", err), @@ -260,5 +334,7 @@ func (re *re) updateRuleExecutionStatus(ctx context.Context, ruleID string, stat slog.String("status", status.String()), }, } + } else { + fmt.Printf("[DEBUG] Successfully updated rule execution status in database: rule_id=%s, status=%s, execution_count=%d\n", ruleID, status.String(), rule.ExecutionCount) } } diff --git a/re/worker.go b/re/worker.go index 1d1a0a22f..f9a8cf1ce 100644 --- a/re/worker.go +++ b/re/worker.go @@ -136,7 +136,6 @@ func (w *RuleWorker) processMessage(ctx context.Context, workerMsg WorkerMessage return } - w.engine.updateRuleExecutionStatus(ctx, currentRule.ID, InProgressStatus, nil) select { case <-w.ctx.Done(): @@ -419,11 +418,6 @@ func (wm *WorkerManager) abortWorker(ruleID string) { worker.AbortExecution(wm.ctx) rule := worker.GetRule() - if rule.Status == EnabledStatus { - newWorker := NewRuleWorker(rule, wm.engine) - newWorker.Start(wm.ctx) - wm.workers[ruleID] = newWorker - } if err := worker.Stop(); err != nil { select { @@ -431,6 +425,14 @@ func (wm *WorkerManager) abortWorker(ruleID string) { default: } } + + delete(wm.workers, ruleID) + + if rule.Status == EnabledStatus { + newWorker := NewRuleWorker(rule, wm.engine) + newWorker.Start(wm.ctx) + wm.workers[ruleID] = newWorker + } } func (wm *WorkerManager) sendMessage(msg *messaging.Message, rule Rule) bool { From 7370ed18badcd6ff4bc955ca617d346b91fe7d38 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 15 Oct 2025 12:43:33 +0300 Subject: [PATCH 19/25] remove debug logs Signed-off-by: nyagamunene --- re/postgres/repository.go | 11 +---------- re/rule.go | 2 +- re/service.go | 18 ------------------ re/worker.go | 1 - 4 files changed, 2 insertions(+), 30 deletions(-) diff --git a/re/postgres/repository.go b/re/postgres/repository.go index 1190e52ac..242a656d3 100644 --- a/re/postgres/repository.go +++ b/re/postgres/repository.go @@ -108,20 +108,11 @@ func (repo *PostgresRepository) UpdateRuleExecutionStatus(ctx context.Context, r return errors.Wrap(repoerr.ErrUpdateEntity, err) } - // Debug: Log what we're about to write to the database - fmt.Printf("[DEBUG-DB] UpdateRuleExecutionStatus: rule_id=%s, execution_count_in=%d, db_execution_count=%d, last_run_status=%s\n", - r.ID, r.ExecutionCount, dbr.ExecutionCount, r.LastRunStatus.String()) - - result, err := repo.DB.NamedExecContext(ctx, q, dbr) + _, err = repo.DB.NamedExecContext(ctx, q, dbr) if err != nil { - fmt.Printf("[ERROR-DB] UpdateRuleExecutionStatus failed: rule_id=%s, error=%v\n", r.ID, err) return postgres.HandleError(repoerr.ErrUpdateEntity, err) } - rowsAffected, _ := result.RowsAffected() - fmt.Printf("[DEBUG-DB] UpdateRuleExecutionStatus completed: rule_id=%s, rows_affected=%d, execution_count=%d\n", - r.ID, rowsAffected, r.ExecutionCount) - return nil } diff --git a/re/rule.go b/re/rule.go index ac7cd6780..38b948e32 100644 --- a/re/rule.go +++ b/re/rule.go @@ -256,7 +256,7 @@ type RuleExecutionStatus struct { QueueLength int `json:"queue_length"` } -type Service interface{ +type Service interface { messaging.MessageHandler AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, error) ViewRule(ctx context.Context, session authn.Session, id string) (Rule, error) diff --git a/re/service.go b/re/service.go index e349c5751..68ce5e81e 100644 --- a/re/service.go +++ b/re/service.go @@ -232,7 +232,6 @@ func (re *re) AbortRuleExecution(ctx context.Context, session authn.Session, id } if re.workerMgr != nil { - // Also check if worker actually exists and is running workerStatus := re.workerMgr.GetWorkerStatus(id) if workerStatus == nil { return errors.Wrap(errors.New("no active worker found for this rule"), svcerr.ErrNotFound) @@ -293,14 +292,8 @@ func (re *re) updateRuleExecutionStatus(ctx context.Context, ruleID string, stat rule.LastRunErrorMessage = err.Error() } - // Debug: Log the status being set - fmt.Printf("[DEBUG] updateRuleExecutionStatus: rule_id=%s, status=%s, has_error=%v\n", ruleID, status.String(), err != nil) - - // Always fetch the current rule to get the current execution count currentRule, viewErr := re.repo.ViewRule(ctx, ruleID) if viewErr != nil { - fmt.Printf("[WARN] Failed to retrieve current rule: rule_id=%s, error=%v\n", ruleID, viewErr) - // If we can't fetch the rule, set count based on status switch status { case SuccessStatus, PartialSuccessStatus, FailureStatus: rule.ExecutionCount = 1 @@ -308,24 +301,15 @@ func (re *re) updateRuleExecutionStatus(ctx context.Context, ruleID string, stat rule.ExecutionCount = 0 } } else { - // Start with current count rule.ExecutionCount = currentRule.ExecutionCount - // Add 1 if completed successfully, add 0 otherwise (preserve count) switch status { case SuccessStatus, PartialSuccessStatus, FailureStatus: rule.ExecutionCount = currentRule.ExecutionCount + 1 - fmt.Printf("[DEBUG] Incremented execution count: rule_id=%s, old_count=%d, new_count=%d\n", ruleID, currentRule.ExecutionCount, rule.ExecutionCount) - default: - fmt.Printf("[DEBUG] Preserving execution count: rule_id=%s, status=%s, count=%d\n", ruleID, status.String(), rule.ExecutionCount) } } - fmt.Printf("[DEBUG] About to update rule execution status in database: rule_id=%s, status=%s, execution_count=%d\n", ruleID, status.String(), rule.ExecutionCount) - if err := re.repo.UpdateRuleExecutionStatus(ctx, rule); err != nil { - fmt.Printf("[ERROR] Failed to update rule execution status in database: rule_id=%s, error=%v\n", ruleID, err) - re.runInfo <- pkglog.RunInfo{ Level: slog.LevelWarn, Message: fmt.Sprintf("failed to update rule execution status: %s", err), @@ -334,7 +318,5 @@ func (re *re) updateRuleExecutionStatus(ctx context.Context, ruleID string, stat slog.String("status", status.String()), }, } - } else { - fmt.Printf("[DEBUG] Successfully updated rule execution status in database: rule_id=%s, status=%s, execution_count=%d\n", ruleID, status.String(), rule.ExecutionCount) } } diff --git a/re/worker.go b/re/worker.go index f9a8cf1ce..246a65eef 100644 --- a/re/worker.go +++ b/re/worker.go @@ -136,7 +136,6 @@ func (w *RuleWorker) processMessage(ctx context.Context, workerMsg WorkerMessage return } - select { case <-w.ctx.Done(): w.engine.updateRuleExecutionStatus(ctx, currentRule.ID, AbortedStatus, w.ctx.Err()) From 04e7b42f678118b07d45d6c5d0e6ef89853b7277 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 15 Oct 2025 12:46:51 +0300 Subject: [PATCH 20/25] revert env variable Signed-off-by: nyagamunene --- docker/.env | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/.env b/docker/.env index a8e6e7e1b..693897bda 100644 --- a/docker/.env +++ b/docker/.env @@ -396,4 +396,4 @@ MG_RELEASE_TAG=latest SMQ_ALLOW_UNVERIFIED_USER=true # Set to yes to accept the EULA for the UI services. To view the EULA visit: https://github.com/absmach/eula -MG_UI_DOCKER_ACCEPT_EULA=yes +MG_UI_DOCKER_ACCEPT_EULA=no From 79a772ea2667d0e4c981e0c6056a32a73a666a17 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 15 Oct 2025 13:15:50 +0300 Subject: [PATCH 21/25] fix failing linter Signed-off-by: nyagamunene --- re/handlers.go | 2 +- re/worker.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/re/handlers.go b/re/handlers.go index 1cf10c036..ffbd2a6ca 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -254,7 +254,7 @@ func (re *re) StartScheduler(ctx context.Context) error { re.updateRuleExecutionStatus(ctx, r.ID, QueuedStatus, nil) } } - + //nolint:contextcheck if !re.workerMgr.SendMessage(msg, r) { re.runInfo <- pkglog.RunInfo{ Level: slog.LevelWarn, diff --git a/re/worker.go b/re/worker.go index 246a65eef..044cd268a 100644 --- a/re/worker.go +++ b/re/worker.go @@ -143,7 +143,7 @@ func (w *RuleWorker) processMessage(ctx context.Context, workerMsg WorkerMessage default: } - runInfo := w.engine.process(w.ctx, currentRule, workerMsg.Message) + runInfo := w.engine.process(ctx, currentRule, workerMsg.Message) if w.ctx.Err() == context.Canceled { w.engine.updateRuleExecutionStatus(ctx, currentRule.ID, AbortedStatus, w.ctx.Err()) From bfdcb309c79b8580b04b0a66c22eb511b64c4d2b Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 15 Oct 2025 13:29:58 +0300 Subject: [PATCH 22/25] update tests Signed-off-by: nyagamunene --- re/service_test.go | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/re/service_test.go b/re/service_test.go index e4f8e55a6..1bc3b30e7 100644 --- a/re/service_test.go +++ b/re/service_test.go @@ -646,15 +646,32 @@ func TestAbortRuleExecution(t *testing.T) { err error }{ { - desc: "abort rule execution successfully", + desc: "abort rule execution with wrong status", session: authn.Session{ UserID: userID, DomainID: domainID, }, - id: ruleID, - rule: re.Rule{ID: ruleID}, + id: ruleID, + rule: re.Rule{ + ID: ruleID, + LastRunStatus: re.SuccessStatus, + }, + repoErr: nil, + err: svcerr.ErrMalformedEntity, + }, + { + desc: "abort rule execution with never_run status", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: ruleID, + rule: re.Rule{ + ID: ruleID, + LastRunStatus: re.NeverRunStatus, + }, repoErr: nil, - err: nil, + err: svcerr.ErrMalformedEntity, }, { desc: "abort rule execution with non-existent rule", From e767ff6b17c8f3adb9da64997b1193fc2a45e7b9 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 15 Oct 2025 13:55:14 +0300 Subject: [PATCH 23/25] add tests Signed-off-by: nyagamunene --- re/service_test.go | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/re/service_test.go b/re/service_test.go index 1bc3b30e7..7afb0ec4c 100644 --- a/re/service_test.go +++ b/re/service_test.go @@ -65,15 +65,30 @@ func newService(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.R mockTicker.On("Stop").Return() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - t.Cleanup(func() { - cancel() - time.Sleep(10 * time.Millisecond) - }) + schedulerDone := make(chan struct{}) go func() { _ = svc.StartScheduler(ctx) + close(schedulerDone) }() + t.Cleanup(func() { + // Cancel context first to stop scheduler + cancel() + + // Wait for scheduler to finish gracefully + select { + case <-schedulerDone: + // Scheduler stopped successfully + case <-time.After(2 * time.Second): + t.Log("Warning: scheduler did not stop within timeout") + } + + // Give goroutines time to cleanup + time.Sleep(50 * time.Millisecond) + }) + + // Wait for scheduler to initialize time.Sleep(50 * time.Millisecond) return svc, repo, pubsub, mockTicker From 9ca935b62a91b31548d30614dde0c30350a22e0c Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 15 Oct 2025 14:12:23 +0300 Subject: [PATCH 24/25] fix tests Signed-off-by: nyagamunene --- re/service_test.go | 10 ++-------- re/worker.go | 8 ++++++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/re/service_test.go b/re/service_test.go index 7afb0ec4c..8412e0119 100644 --- a/re/service_test.go +++ b/re/service_test.go @@ -73,22 +73,16 @@ func newService(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.R }() t.Cleanup(func() { - // Cancel context first to stop scheduler cancel() - // Wait for scheduler to finish gracefully select { case <-schedulerDone: - // Scheduler stopped successfully - case <-time.After(2 * time.Second): + case <-time.After(5 * time.Second): t.Log("Warning: scheduler did not stop within timeout") } - // Give goroutines time to cleanup - time.Sleep(50 * time.Millisecond) + time.Sleep(100 * time.Millisecond) }) - - // Wait for scheduler to initialize time.Sleep(50 * time.Millisecond) return svc, repo, pubsub, mockTicker diff --git a/re/worker.go b/re/worker.go index 044cd268a..efc530743 100644 --- a/re/worker.go +++ b/re/worker.go @@ -8,6 +8,7 @@ import ( "fmt" "sync" "sync/atomic" + "time" "github.com/absmach/supermq/pkg/messaging" "golang.org/x/sync/errgroup" @@ -526,8 +527,11 @@ func (wm *WorkerManager) StopAll() error { Response: responseCh, } - wm.commandCh <- cmd - <-responseCh + select { + case wm.commandCh <- cmd: + <-responseCh + case <-time.After(100 * time.Millisecond): + } return wm.g.Wait() } From a016b23a7c019f82dca5e6a62ceede57c4ae14a7 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 15 Oct 2025 14:23:17 +0300 Subject: [PATCH 25/25] fix test Signed-off-by: nyagamunene --- re/service_test.go | 9 +++------ re/worker.go | 13 ++++++++++++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/re/service_test.go b/re/service_test.go index 8412e0119..c8af8ca3d 100644 --- a/re/service_test.go +++ b/re/service_test.go @@ -64,7 +64,7 @@ func newService(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.R mockTicker.On("Tick").Return((<-chan time.Time)(tickCh)) mockTicker.On("Stop").Return() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) schedulerDone := make(chan struct{}) go func() { @@ -75,11 +75,8 @@ func newService(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.R t.Cleanup(func() { cancel() - select { - case <-schedulerDone: - case <-time.After(5 * time.Second): - t.Log("Warning: scheduler did not stop within timeout") - } + // Wait for scheduler to stop (can take time for worker cleanup) + <-schedulerDone time.Sleep(100 * time.Millisecond) }) diff --git a/re/worker.go b/re/worker.go index efc530743..5fce93426 100644 --- a/re/worker.go +++ b/re/worker.go @@ -521,6 +521,13 @@ func (wm *WorkerManager) StopAll() error { return nil } + // If context is already cancelled, manageWorkers is stopping, just wait + select { + case <-wm.ctx.Done(): + return wm.g.Wait() + default: + } + responseCh := make(chan interface{}, 1) cmd := WorkerManagerCommand{ Type: CmdStopAll, @@ -529,7 +536,11 @@ func (wm *WorkerManager) StopAll() error { select { case wm.commandCh <- cmd: - <-responseCh + select { + case <-responseCh: + case <-wm.ctx.Done(): + case <-time.After(100 * time.Millisecond): + } case <-time.After(100 * time.Millisecond): }