diff --git a/changelog/fragments/1751913009-Cancel-policy-dispatches-when-a-new-revision-arrives.yaml b/changelog/fragments/1751913009-Cancel-policy-dispatches-when-a-new-revision-arrives.yaml new file mode 100644 index 0000000000..9b685b15e3 --- /dev/null +++ b/changelog/fragments/1751913009-Cancel-policy-dispatches-when-a-new-revision-arrives.yaml @@ -0,0 +1,35 @@ +# Kind can be one of: +# - breaking-change: a change to previously-documented behavior +# - deprecation: functionality that is being removed in a later release +# - bug-fix: fixes a problem in a previous version +# - enhancement: extends functionality but does not break or fix existing behavior +# - feature: new functionality +# - known-issue: problems that we are aware of in a given version +# - security: impacts on the security of a product or a user’s deployment. +# - upgrade: important information for someone upgrading from a prior version +# - other: does not fit into any of the other categories +kind: enhancement + +# Change summary; a 80ish characters long description of the change. +summary: Cancel policy dispatches when a new revision arrives + +# Long description; in case the summary is not enough to describe the change +# this field accommodate a description without length limits. +# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment. +description: | + The policy monitor may cancel dispatches to the pending queue when + new output is received. This allows the cancellation of sending + revision N to agents when N+1 is received. + +# Affected component; usually one of "elastic-agent", "fleet-server", "filebeat", "metricbeat", "auditbeat", "all", etc. +component: fleet-server + +# PR URL; optional; the PR number that added the changeset. +# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added. 
+# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number. +# Please provide it if you are adding a fragment for a different PR. +#pr: https://github.com/owner/repo/1234 + +# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of). +# If not present is automatically filled by the tooling with the issue linked to the PR number. +issue: https://github.com/elastic/fleet-server/issues/3254 diff --git a/internal/pkg/policy/monitor.go b/internal/pkg/policy/monitor.go index a5347a0771..4616003b2d 100644 --- a/internal/pkg/policy/monitor.go +++ b/internal/pkg/policy/monitor.go @@ -88,7 +88,8 @@ type monitorT struct { policiesIndex string limit *rate.Limiter - startCh chan struct{} + startCh chan struct{} + dispatchCh chan struct{} } // NewMonitor creates the policy monitor for subscribing agents. @@ -135,14 +136,17 @@ func (m *monitorT) Run(ctx context.Context) error { close(m.startCh) - var iCtx context.Context + // use a cancellable context so we can stop dispatching changes if a new hit is received. + // the cancel func is manually called before return, or after policies have been dispatched. 
+ iCtx, iCancel := context.WithCancel(ctx) var trans *apm.Transaction LOOP: for { m.log.Trace().Msg("policy monitor loop start") - iCtx = ctx select { case <-m.kickCh: + cancelOnce(iCtx, iCancel) + iCtx, iCancel = context.WithCancel(ctx) m.log.Trace().Msg("policy monitor kicked") if m.bulker.HasTracer() { trans = m.bulker.StartTransaction("initial policies", "policy_monitor") @@ -151,20 +155,31 @@ LOOP: if err := m.loadPolicies(iCtx); err != nil { endTrans(trans) + cancelOnce(iCtx, iCancel) return err } - m.dispatchPending(iCtx) - endTrans(trans) + go func(ctx context.Context, cancel context.CancelFunc, trans *apm.Transaction) { + m.dispatchPending(ctx) + endTrans(trans) + cancelOnce(ctx, cancel) + }(iCtx, iCancel, trans) case <-m.deployCh: + cancelOnce(iCtx, iCancel) + iCtx, iCancel = context.WithCancel(ctx) m.log.Trace().Msg("policy monitor deploy ch") if m.bulker.HasTracer() { trans = m.bulker.StartTransaction("forced policies", "policy_monitor") iCtx = apm.ContextWithTransaction(ctx, trans) } - m.dispatchPending(iCtx) - endTrans(trans) + go func(ctx context.Context, cancel context.CancelFunc, trans *apm.Transaction) { + m.dispatchPending(ctx) + endTrans(trans) + cancelOnce(ctx, cancel) + }(iCtx, iCancel, trans) case hits := <-s.Output(): // TODO would be nice to attach transaction IDs to hits, but would likely need a bigger refactor. 
+ cancelOnce(iCtx, iCancel) + iCtx, iCancel = context.WithCancel(ctx) m.log.Trace().Int("hits", len(hits)).Msg("policy monitor hits from sub") if m.bulker.HasTracer() { trans = m.bulker.StartTransaction("output policies", "policy_monitor") @@ -173,18 +188,33 @@ LOOP: if err := m.processHits(iCtx, hits); err != nil { endTrans(trans) + cancelOnce(iCtx, iCancel) return err } - m.dispatchPending(iCtx) - endTrans(trans) + go func(ctx context.Context, cancel context.CancelFunc, trans *apm.Transaction) { + m.dispatchPending(ctx) + endTrans(trans) + cancelOnce(ctx, cancel) + }(iCtx, iCancel, trans) case <-ctx.Done(): break LOOP } } + iCancel() return nil } +// cancelOnce calls cancel if the context is not done. +func cancelOnce(ctx context.Context, cancel context.CancelFunc) { + select { + case <-ctx.Done(): + return + default: + cancel() + } +} + func unmarshalHits(hits []es.HitT) ([]model.Policy, error) { policies := make([]model.Policy, len(hits)) for i, hit := range hits { @@ -224,6 +254,14 @@ func (m *monitorT) waitStart(ctx context.Context) error { // dispatchPending will dispatch all pending policy changes to the subscriptions in the queue. // dispatches are rate limited by the monitor's limiter. func (m *monitorT) dispatchPending(ctx context.Context) { + // dispatchCh is used in tests to be able to control when a dispatch execution proceeds + if m.dispatchCh != nil { + select { + case <-m.dispatchCh: + case <-ctx.Done(): + return + } + } span, ctx := apm.StartSpan(ctx, "dispatch pending", "dispatch") defer span.End() m.mut.Lock() @@ -243,7 +281,10 @@ func (m *monitorT) dispatchPending(ctx context.Context) { // If too many (checkin) responses are written concurrently memory usage may explode due to allocating gzip writers. 
err := m.limit.Wait(ctx) if err != nil { - m.log.Warn().Err(err).Msg("Policy limit error") + m.pendingQ.pushFront(s) // context cancelled before sub is handled, put it back + if !errors.Is(err, context.Canceled) { + m.log.Warn().Err(err).Msg("Policy limit error") + } return } // Lookup the latest policy for this subscription @@ -257,6 +298,7 @@ func (m *monitorT) dispatchPending(ctx context.Context) { select { case <-ctx.Done(): + m.pendingQ.pushFront(s) // context cancelled before sub is handled, put it back m.log.Debug().Err(ctx.Err()).Msg("context termination detected in policy dispatch") return case s.ch <- &policy.pp: diff --git a/internal/pkg/policy/monitor_test.go b/internal/pkg/policy/monitor_test.go index 869420da97..24cf5f0260 100644 --- a/internal/pkg/policy/monitor_test.go +++ b/internal/pkg/policy/monitor_test.go @@ -226,7 +226,6 @@ func TestMonitor_SamePolicy(t *testing.T) { } func TestMonitor_NewPolicyExists(t *testing.T) { - tests := []struct { name string delay time.Duration @@ -442,3 +441,111 @@ LOOP: ms.AssertExpectations(t) mm.AssertExpectations(t) } + +func Test_Monitor_cancel_pending(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + ctx = testlog.SetLogger(t).WithContext(ctx) + + chHitT := make(chan []es.HitT, 2) + defer close(chHitT) + ms := mmock.NewMockSubscription() + ms.On("Output").Return((<-chan []es.HitT)(chHitT)) + mm := mmock.NewMockMonitor() + mm.On("Subscribe").Return(ms).Once() + mm.On("Unsubscribe", mock.Anything).Return().Once() + bulker := ftesting.NewMockBulk() + + monitor := NewMonitor(bulker, mm, config.ServerLimits{}) + pm := monitor.(*monitorT) + pm.policyF = func(ctx context.Context, bulker bulk.Bulk, opt ...dl.Option) ([]model.Policy, error) { + return []model.Policy{}, nil + } + pm.dispatchCh = make(chan struct{}, 1) + + agentId := uuid.Must(uuid.NewV4()).String() + policyId := uuid.Must(uuid.NewV4()).String() + + rId := xid.New().String() + policy := model.Policy{ + ESDocument: 
model.ESDocument{ + Id: rId, + Version: 1, + SeqNo: 1, + }, + PolicyID: policyId, + Data: policyDataDefault, + RevisionIdx: 1, + } + policyData, err := json.Marshal(&policy) + require.NoError(t, err) + policy2 := model.Policy{ + ESDocument: model.ESDocument{ + Id: rId, + Version: 1, + SeqNo: 1, + }, + PolicyID: policyId, + Data: policyDataDefault, + RevisionIdx: 2, + } + policyData2, err := json.Marshal(&policy2) + require.NoError(t, err) + + // Send both revisions to monitor as separate hits + chHitT <- []es.HitT{{ + ID: rId, + SeqNo: 1, + Version: 1, + Source: policyData, + }} + chHitT <- []es.HitT{{ + ID: rId, + SeqNo: 2, + Version: 1, + Source: policyData2, + }} + + // start monitor + var merr error + var mwg sync.WaitGroup + mwg.Add(1) + go func() { + defer mwg.Done() + merr = monitor.Run(ctx) + }() + err = monitor.(*monitorT).waitStart(ctx) + require.NoError(t, err) + + // subscribe with revision 0 + s, err := monitor.Subscribe(agentId, policyId, 0) + defer monitor.Unsubscribe(s) + require.NoError(t, err) + + // This sleep allows the main run to call dispatch + // but dispatch will not proceed until there is a signal from the dispatchCh + time.Sleep(100 * time.Millisecond) + pm.dispatchCh <- struct{}{} + + tm := time.NewTimer(time.Second) + policies := make([]*ParsedPolicy, 0, 2) +LOOP: + for { + select { + case p := <-s.Output(): + policies = append(policies, p) + case <-tm.C: + break LOOP + } + } + + cancel() + mwg.Wait() + if merr != nil && merr != context.Canceled { + t.Fatal(merr) + } + require.Len(t, policies, 1, "expected to recieve one revision") + require.Equal(t, policies[0].Policy.RevisionIdx, int64(2)) + ms.AssertExpectations(t) + mm.AssertExpectations(t) +}