diff --git a/changelog/fragments/1757633911-Add-agent_policy_id-and-policy_revision_idx-to-checkin-requests.yaml b/changelog/fragments/1757633911-Add-agent_policy_id-and-policy_revision_idx-to-checkin-requests.yaml new file mode 100644 index 0000000000..8ae51d7a39 --- /dev/null +++ b/changelog/fragments/1757633911-Add-agent_policy_id-and-policy_revision_idx-to-checkin-requests.yaml @@ -0,0 +1,40 @@ +# Kind can be one of: +# - breaking-change: a change to previously-documented behavior +# - deprecation: functionality that is being removed in a later release +# - bug-fix: fixes a problem in a previous version +# - enhancement: extends functionality but does not break or fix existing behavior +# - feature: new functionality +# - known-issue: problems that we are aware of in a given version +# - security: impacts on the security of a product or a user’s deployment. +# - upgrade: important information for someone upgrading from a prior version +# - other: does not fit into any of the other categories +kind: feature + +# Change summary; a 80ish characters long description of the change. +summary: Add agent_policy_id and policy_revision_idx to checkin requests + +# Long description; in case the summary is not enough to describe the change +# this field accommodate a description without length limits. +# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment. +description: | + Add the agent_policy_id and policy_revision_idx attributes to checkin + request bodies so an agent is able to inform fleet-server of its exact + policy. These details will replace the need for an ack on + policy_change actions, and will be used to determine when to send a + policy change when there is a new revision available, or when the + agent is reassigned to a different policy. Add a server setting under + feature_flags.ignore_checkin_policy_id that disables this behavour and + restores the previous approach. + +# Affected component; usually one of "elastic-agent", "fleet-server", "filebeat", "metricbeat", "auditbeat", "all", etc. +component: fleet-server + +# PR URL; optional; the PR number that added the changeset. +# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added. +# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number. +# Please provide it if you are adding a fragment for a different PR. +pr: https://github.com/elastic/fleet-server/pull/5501 + +# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of). +# If not present is automatically filled by the tooling with the issue linked to the PR number. +issue: https://github.com/elastic/elastic-agent/issues/6446 diff --git a/fleet-server.reference.yml b/fleet-server.reference.yml index e0a3669bc7..a5c7c01f5d 100644 --- a/fleet-server.reference.yml +++ b/fleet-server.reference.yml @@ -270,6 +270,13 @@ fleet: # upstream_url: "https://artifacts.elastic.co/GPG-KEY-elastic-agent" # # By default dir is the directory containing the fleet-server executable (following symlinks) joined with elastic-agent-upgrade-keys # dir: ./elastic-agent-upgrade-keys +# +# # Toggles to enable new behaviour or restore old behaviour. +# feature_flags: +# // ignore agent_policy_id and policy_revision_idx attributes that may be present in the checkin request bodies. +# // POLICY_CHANGE actions need an explicit ack if this is set. +# ignore_checkin_policy_id: false +# # # monitor options are advanced configuration and should not be adjusted is most cases # monitor: # fetch_size: 1000 # The number of documents that each monitor may fetch at once diff --git a/internal/pkg/api/handleAck.go b/internal/pkg/api/handleAck.go index 5ab9f70043..a456062670 100644 --- a/internal/pkg/api/handleAck.go +++ b/internal/pkg/api/handleAck.go @@ -442,19 +442,28 @@ func (ack *AckT) updateAPIKey(ctx context.Context, agentID string, apiKeyID, permissionHash string, toRetireAPIKeyIDs []model.ToRetireAPIKeyIdsItems, outputName string) error { - bulk := ack.bulk + return updateAPIKey(ctx, zlog, ack.bulk, agentID, apiKeyID, permissionHash, toRetireAPIKeyIDs, outputName) +} + +func updateAPIKey(ctx context.Context, + zlog zerolog.Logger, + bulk bulk.Bulk, + agentID string, + apiKeyID, permissionHash string, + toRetireAPIKeyIDs []model.ToRetireAPIKeyIdsItems, outputName string) error { // use output bulker if exists + outBulk := bulk if outputName != "" { - outputBulk := ack.bulk.GetBulker(outputName) + outputBulk := bulk.GetBulker(outputName) if outputBulk != nil { zlog.Debug().Str(ecs.PolicyOutputName, outputName).Msg("Using output bulker in updateAPIKey") - bulk = outputBulk + outBulk = outputBulk } } if apiKeyID != "" { - res, err := bulk.APIKeyRead(ctx, apiKeyID, true) + res, err := outBulk.APIKeyRead(ctx, apiKeyID, true) if err != nil { - if isAgentActive(ctx, zlog, ack.bulk, agentID) { + if isAgentActive(ctx, zlog, outBulk, agentID) { zlog.Warn(). Err(err). Str(LogAPIKeyID, apiKeyID). @@ -480,7 +489,7 @@ func (ack *AckT) updateAPIKey(ctx context.Context, Str(LogAPIKeyID, apiKeyID). Msg("Failed to cleanup roles") } else if removedRolesCount > 0 { - if err := bulk.APIKeyUpdate(ctx, apiKeyID, permissionHash, clean); err != nil { + if err := outBulk.APIKeyUpdate(ctx, apiKeyID, permissionHash, clean); err != nil { zlog.Error().Err(err).RawJSON("roles", clean).Str(LogAPIKeyID, apiKeyID).Str(ecs.PolicyOutputName, outputName).Msg("Failed to update API Key") } else { zlog.Debug(). @@ -493,7 +502,7 @@ func (ack *AckT) updateAPIKey(ctx context.Context, } } } - ack.invalidateAPIKeys(ctx, zlog, toRetireAPIKeyIDs, apiKeyID) + invalidateAPIKeys(ctx, zlog, bulk, toRetireAPIKeyIDs, apiKeyID) } return nil diff --git a/internal/pkg/api/handleCheckin.go b/internal/pkg/api/handleCheckin.go index 4ad39e9273..24d8df38bc 100644 --- a/internal/pkg/api/handleCheckin.go +++ b/internal/pkg/api/handleCheckin.go @@ -279,16 +279,34 @@ func (ct *CheckinT) ProcessRequest(zlog zerolog.Logger, w http.ResponseWriter, r return fmt.Errorf("failed to update upgrade_details: %w", err) } + initialOpts := []checkin.Option{ + checkin.WithStatus(string(req.Status)), + checkin.WithMessage(req.Message), + checkin.WithMeta(rawMeta), + checkin.WithComponents(rawComponents), + checkin.WithSeqNo(seqno), + checkin.WithVer(ver), + checkin.WithUnhealthyReason(unhealthyReason), + checkin.WithDeleteAudit(agent.AuditUnenrolledReason != "" || agent.UnenrolledAt != ""), + } + + revID, opts, err := ct.processPolicyDetails(r.Context(), zlog, agent, req) + if err != nil { + return fmt.Errorf("failed to update policy details: %w", err) + } + if len(opts) > 0 { + initialOpts = append(initialOpts, opts...) + } + // Subscribe to actions dispatcher aSub := ct.ad.Subscribe(zlog, agent.Id, seqno) defer ct.ad.Unsubscribe(zlog, aSub) actCh := aSub.Ch() - // use revision_idx=0 if the agent has a single output where no API key is defined - // This will force the policy monitor to emit a new policy to regerate API keys - revID := agent.PolicyRevisionIdx for _, output := range agent.Outputs { if output.APIKey == "" { + // use revision_idx=0 if the agent has a single output where no API key is defined + // This will force the policy monitor to emit a new policy to regerate API keys revID = 0 break } @@ -328,7 +346,7 @@ func (ct *CheckinT) ProcessRequest(zlog zerolog.Logger, w http.ResponseWriter, r // Initial update on checkin, and any user fields that might have changed // Run a script to remove audit_unenrolled_* and unenrolled_at attributes if one is set on checkin. // 8.16.x releases would incorrectly set unenrolled_at - err = ct.bc.CheckIn(agent.Id, checkin.WithStatus(string(req.Status)), checkin.WithMessage(req.Message), checkin.WithMeta(rawMeta), checkin.WithComponents(rawComponents), checkin.WithSeqNo(seqno), checkin.WithVer(ver), checkin.WithUnhealthyReason(unhealthyReason), checkin.WithDeleteAudit(agent.AuditUnenrolledReason != "" || agent.UnenrolledAt != "")) + err = ct.bc.CheckIn(agent.Id, initialOpts...) if err != nil { zlog.Error().Err(err).Str(ecs.AgentID, agent.Id).Msg("checkin failed") } @@ -1124,3 +1142,55 @@ func calcPollDuration(zlog zerolog.Logger, pollDuration, setupDuration, jitterDu return pollDuration, jitter } + +// processPolicyDetails handles the agent_policy_id and revision_idx included in the checkin request. +// The API keys will be managed if the agent reports a new policy id from its last checkin, or if the revision is different than what the last checkin reported. +// It returns the revision idx that should be used when subscribing for new POLICY_CHANGE actons and optional args to use when doing the non-tick checkin. +func (ct *CheckinT) processPolicyDetails(ctx context.Context, zlog zerolog.Logger, agent *model.Agent, req *CheckinRequest) (int64, []checkin.Option, error) { + // no details specified or attributes are ignored by config + if ct.cfg.Features.IgnoreCheckinPolicyID || req == nil || req.PolicyRevisionIdx == nil || req.AgentPolicyId == nil { + return agent.PolicyRevisionIdx, nil, nil + } + policyID := *req.AgentPolicyId + revisionIDX := *req.PolicyRevisionIdx + + span, ctx := apm.StartSpan(ctx, "Process policy details", "process") + span.Context.SetLabel("agent_id", agent.Agent.ID) + span.Context.SetLabel(dl.FieldAgentPolicyID, policyID) + span.Context.SetLabel(dl.FieldPolicyRevisionIdx, revisionIDX) + defer span.End() + + // update agent doc if policy id or revision idx does not match + var opts []checkin.Option + if policyID != agent.AgentPolicyID || revisionIDX != agent.PolicyRevisionIdx { + opts = []checkin.Option{ + checkin.WithAgentPolicyID(policyID), + checkin.WithPolicyRevisionIDX(revisionIDX), + } + } + // Policy reassign, subscribe to policy with revision 0 + if policyID != agent.PolicyID { + zlog.Debug().Str(dl.FieldAgentPolicyID, policyID).Str("new_policy_id", agent.PolicyID).Msg("Policy ID mismatch detected, reassigning agent.") + return 0, opts, nil + } + + // Check if the checkin revision_idx is greater than the latest available + latestRev := ct.pm.LatestRev(ctx, agent.PolicyID) + if latestRev != 0 && revisionIDX > latestRev { + revisionIDX = 0 // set return val to 0 so the agent gets latest available revision. + } + + // Update API keys if the policy has changed, or if the revision differs. + if policyID != agent.AgentPolicyID || revisionIDX != agent.PolicyRevisionIdx { + for outputName, output := range agent.Outputs { + if output.Type != policy.OutputTypeElasticsearch { + continue + } + if err := updateAPIKey(ctx, zlog, ct.bulker, agent.Id, output.APIKeyID, output.PermissionsHash, output.ToRetireAPIKeyIds, outputName); err != nil { + // Only returns ErrUpdatingInactiveAgent + return 0, nil, err + } + } + } + return revisionIDX, opts, nil +} diff --git a/internal/pkg/api/handleCheckin_test.go b/internal/pkg/api/handleCheckin_test.go index fe824b09d0..07b775f78d 100644 --- a/internal/pkg/api/handleCheckin_test.go +++ b/internal/pkg/api/handleCheckin_test.go @@ -25,7 +25,6 @@ import ( "github.com/elastic/fleet-server/v7/internal/pkg/dl" "github.com/elastic/fleet-server/v7/internal/pkg/es" "github.com/elastic/fleet-server/v7/internal/pkg/model" - mockmonitor "github.com/elastic/fleet-server/v7/internal/pkg/monitor/mock" "github.com/elastic/fleet-server/v7/internal/pkg/policy" "github.com/elastic/fleet-server/v7/internal/pkg/sqn" ftesting "github.com/elastic/fleet-server/v7/internal/pkg/testing" @@ -39,6 +38,30 @@ import ( "github.com/stretchr/testify/require" ) +type mockPolicyMonitor struct { + mock.Mock +} + +func (m *mockPolicyMonitor) Run(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *mockPolicyMonitor) Subscribe(agentID, policyID string, revIDX int64) (policy.Subscription, error) { + args := m.Called(agentID, policyID, revIDX) + return args.Get(0).(policy.Subscription), args.Error(1) +} + +func (m *mockPolicyMonitor) Unsubscribe(sub policy.Subscription) error { + args := m.Called(sub) + return args.Error(0) +} + +func (m *mockPolicyMonitor) LatestRev(ctx context.Context, id string) int64 { + args := m.Called(ctx, id) + return args.Get(0).(int64) +} + func TestConvertActionData(t *testing.T) { tests := []struct { name string @@ -339,14 +362,13 @@ func TestResolveSeqNo(t *testing.T) { cfg := &config.Server{} c, _ := cache.New(config.Cache{NumCounters: 100, MaxCost: 100000}) bc := checkin.NewBulk(nil) - bulker := ftesting.NewMockBulk() - pim := mockmonitor.NewMockMonitor() - pm := policy.NewMonitor(bulker, pim, config.ServerLimits{PolicyLimit: config.Limit{Interval: 5 * time.Millisecond, Burst: 1}}) + pm := &mockPolicyMonitor{} ct, err := NewCheckinT(verCon, cfg, c, bc, pm, nil, nil, nil) assert.NoError(t, err) resp, _ := ct.resolveSeqNo(ctx, logger, tc.req, tc.agent) assert.Equal(t, tc.resp, resp) + pm.AssertExpectations(t) }) } @@ -1118,3 +1140,208 @@ func TestValidateCheckinRequest(t *testing.T) { }) } } + +func TestProcessPolicyDetails(t *testing.T) { + policyID := "policy-id" + revIDX2 := int64(2) + tests := []struct { + name string + agent *model.Agent + req *CheckinRequest + getPolicyMonitor func() *mockPolicyMonitor + revIDX int64 + returnsOpts bool + err error + }{{ + name: "request has no policy details", + agent: &model.Agent{ + PolicyRevisionIdx: 1, + }, + req: &CheckinRequest{}, + getPolicyMonitor: func() *mockPolicyMonitor { + return &mockPolicyMonitor{} + }, + revIDX: 1, + returnsOpts: false, + err: nil, + }, { + name: "policy reassign detected", + agent: &model.Agent{ + Agent: &model.AgentMetadata{ + ID: "agent-id", + }, + PolicyID: "new-policy-id", + AgentPolicyID: policyID, + PolicyRevisionIdx: 2, + }, + req: &CheckinRequest{ + AgentPolicyId: &policyID, + PolicyRevisionIdx: &revIDX2, + }, + getPolicyMonitor: func() *mockPolicyMonitor { + return &mockPolicyMonitor{} + }, + revIDX: 0, + returnsOpts: false, + err: nil, + }, { + name: "revision updated", + agent: &model.Agent{ + Agent: &model.AgentMetadata{ + ID: "agent-id", + }, + PolicyID: policyID, + AgentPolicyID: policyID, + PolicyRevisionIdx: 1, + }, + req: &CheckinRequest{ + AgentPolicyId: &policyID, + PolicyRevisionIdx: &revIDX2, + }, + getPolicyMonitor: func() *mockPolicyMonitor { + pm := &mockPolicyMonitor{} + pm.On("LatestRev", mock.Anything, policyID).Return(int64(2)).Once() + return pm + }, + revIDX: 2, + returnsOpts: true, + err: nil, + }, { + name: "checkin revision is greater than the policy's latest revision", + agent: &model.Agent{ + Agent: &model.AgentMetadata{ + ID: "agent-id", + }, + PolicyID: policyID, + AgentPolicyID: policyID, + PolicyRevisionIdx: 1, + }, + req: &CheckinRequest{ + AgentPolicyId: &policyID, + PolicyRevisionIdx: &revIDX2, + }, + getPolicyMonitor: func() *mockPolicyMonitor { + pm := &mockPolicyMonitor{} + pm.On("LatestRev", mock.Anything, policyID).Return(int64(1)).Once() + return pm + }, + revIDX: 0, + returnsOpts: true, + err: nil, + }, { + name: "agent_policy_id has changed", + agent: &model.Agent{ + Agent: &model.AgentMetadata{ + ID: "agent-id", + }, + PolicyID: policyID, + AgentPolicyID: "old-policy-id", + PolicyRevisionIdx: 1, + }, + req: &CheckinRequest{ + AgentPolicyId: &policyID, + PolicyRevisionIdx: &revIDX2, + }, + getPolicyMonitor: func() *mockPolicyMonitor { + pm := &mockPolicyMonitor{} + pm.On("LatestRev", mock.Anything, policyID).Return(int64(2)).Once() + return pm + }, + revIDX: 2, + returnsOpts: true, + err: nil, + }, { + name: "agent does not have agent_policy_id present", + agent: &model.Agent{ + Agent: &model.AgentMetadata{ + ID: "agent-id", + }, + PolicyID: policyID, + PolicyRevisionIdx: 2, + }, + req: &CheckinRequest{ + AgentPolicyId: &policyID, + PolicyRevisionIdx: &revIDX2, + }, + getPolicyMonitor: func() *mockPolicyMonitor { + pm := &mockPolicyMonitor{} + pm.On("LatestRev", mock.Anything, policyID).Return(int64(2)).Once() + return pm + }, + revIDX: 2, + returnsOpts: true, + err: nil, + }, { + name: "details present with no changes from agent doc", + agent: &model.Agent{ + Agent: &model.AgentMetadata{ + ID: "agent-id", + }, + AgentPolicyID: policyID, + PolicyID: policyID, + PolicyRevisionIdx: revIDX2, + }, + req: &CheckinRequest{ + AgentPolicyId: &policyID, + PolicyRevisionIdx: &revIDX2, + }, + getPolicyMonitor: func() *mockPolicyMonitor { + pm := &mockPolicyMonitor{} + pm.On("LatestRev", mock.Anything, policyID).Return(int64(2)).Once() + return pm + }, + revIDX: 2, + returnsOpts: false, + err: nil, + }} + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + logger := testlog.SetLogger(t) + pm := tc.getPolicyMonitor() + checkin := &CheckinT{ + cfg: &config.Server{}, + bulker: ftesting.NewMockBulk(), + pm: pm, + } + + revIDX, opts, err := checkin.processPolicyDetails(t.Context(), logger, tc.agent, tc.req) + assert.Equal(t, tc.revIDX, revIDX) + if tc.returnsOpts { + assert.NotEmpty(t, opts) + } else { + assert.Empty(t, opts) + } + if tc.err != nil { + assert.ErrorIs(t, tc.err, err) + } else { + assert.NoError(t, err) + } + pm.AssertExpectations(t) + }) + } + + t.Run("IgnoreCheckinPolicyID flag is set", func(t *testing.T) { + logger := testlog.SetLogger(t) + checkin := &CheckinT{ + cfg: &config.Server{ + Features: config.FeatureFlags{ + IgnoreCheckinPolicyID: true, + }, + }, + } + revIDX, opts, err := checkin.processPolicyDetails(t.Context(), logger, + &model.Agent{ + PolicyID: policyID, + PolicyRevisionIdx: 1, + }, + &CheckinRequest{ + AgentPolicyId: &policyID, + PolicyRevisionIdx: &revIDX2, + }, + ) + assert.NoError(t, err) + assert.Equal(t, int64(1), revIDX) + assert.Empty(t, opts) + }) +} diff --git a/internal/pkg/api/handleStatus_test.go b/internal/pkg/api/handleStatus_test.go index 8ec5208b82..b68436aa15 100644 --- a/internal/pkg/api/handleStatus_test.go +++ b/internal/pkg/api/handleStatus_test.go @@ -34,15 +34,15 @@ func withAuthFunc(authfn AuthFunc) OptFunc { } } -type mockPolicyMonitor struct { +type mockStatusPolicyMonitor struct { state client.UnitState } -func (pm *mockPolicyMonitor) Run(ctx context.Context) error { +func (pm *mockStatusPolicyMonitor) Run(ctx context.Context) error { return nil } -func (pm *mockPolicyMonitor) State() client.UnitState { +func (pm *mockStatusPolicyMonitor) State() client.UnitState { return pm.state } @@ -86,7 +86,7 @@ func TestHandleStatus(t *testing.T) { ctx = logger.WithContext(ctx) state := client.UnitState(k) r := apiServer{ - st: NewStatusT(cfg, nil, c, withAuthFunc(tc.AuthFn), WithSelfMonitor(&mockPolicyMonitor{state}), WithBuildInfo(fbuild.Info{ + st: NewStatusT(cfg, nil, c, withAuthFunc(tc.AuthFn), WithSelfMonitor(&mockStatusPolicyMonitor{state}), WithBuildInfo(fbuild.Info{ Version: "8.1.0", Commit: "4eff928", BuildTime: time.Now(), diff --git a/internal/pkg/api/openapi.gen.go b/internal/pkg/api/openapi.gen.go index 87b3e96618..0bc7b6d9d3 100644 --- a/internal/pkg/api/openapi.gen.go +++ b/internal/pkg/api/openapi.gen.go @@ -314,6 +314,9 @@ type CheckinRequest struct { // Translated to a sequence number in fleet-server in order to retrieve any new actions for the agent from the last checkin. AckToken *string `json:"ack_token,omitempty"` + // AgentPolicyId The ID of the policy that the agent is currently running. + AgentPolicyId *string `json:"agent_policy_id,omitempty"` + // Components An embedded JSON object that holds component information that the agent is running. // Defined in fleet-server as a `json.RawMessage`, defined as an object in the elastic-agent. // fleet-server will update the components in an agent record if they differ from this object. @@ -328,6 +331,9 @@ type CheckinRequest struct { // Message State message, may be overridden or use the error message of a failing component. Message string `json:"message"` + // PolicyRevisionIdx The revision of the policy that the agent is currently running. + PolicyRevisionIdx *int64 `json:"policy_revision_idx,omitempty"` + // PollTimeout An optional timeout value that informs fleet-server of when a client will time out on it's checkin request. // If not specified fleet-server will use the timeout values specified in the config (defaults to 5m polling and a 10m write timeout). // The value, if specified is expected to be a string that is parsable by [time.ParseDuration](https://pkg.go.dev/time#ParseDuration). diff --git a/internal/pkg/checkin/bulk.go b/internal/pkg/checkin/bulk.go index ae74dcb922..b68692ab8b 100644 --- a/internal/pkg/checkin/bulk.go +++ b/internal/pkg/checkin/bulk.go @@ -113,6 +113,18 @@ func WithDeleteAudit(del bool) Option { } } +func WithAgentPolicyID(id string) Option { + return func(pending *pendingT) { + pending.agentPolicyID = id + } +} + +func WithPolicyRevisionIDX(idx int64) Option { + return func(pending *pendingT) { + pending.revisionIDX = idx + } +} + type extraT struct { meta []byte seqNo sqn.SeqNo @@ -128,6 +140,8 @@ type pendingT struct { ts string status string message string + agentPolicyID string // may be empty + revisionIDX int64 extra *extraT unhealthyReason *[]string } @@ -314,6 +328,10 @@ func toUpdateBody(now string, pending pendingT) ([]byte, error) { dl.FieldLastCheckinMessage: pending.message, // Set the status message dl.FieldUnhealthyReason: pending.unhealthyReason, } + if pending.agentPolicyID != "" { + fields[dl.FieldAgentPolicyID] = pending.agentPolicyID + fields[dl.FieldPolicyRevisionIdx] = pending.revisionIDX + } if pending.extra != nil { // If the agent version is not empty it needs to be updated // Assuming the agent can by upgraded keeping the same id, but incrementing the version @@ -353,11 +371,13 @@ func encodeParams(now string, data pendingT) (map[string]json.RawMessage, error) reason json.RawMessage // optional attributes below - ver json.RawMessage - meta json.RawMessage - components json.RawMessage - isSet json.RawMessage - seqNo json.RawMessage + policyID json.RawMessage + revisionIDX json.RawMessage + ver json.RawMessage + meta json.RawMessage + components json.RawMessage + isSet json.RawMessage + seqNo json.RawMessage err error ) @@ -371,6 +391,10 @@ func encodeParams(now string, data pendingT) (map[string]json.RawMessage, error) Err = errors.Join(Err, err) reason, err = json.Marshal(data.unhealthyReason) Err = errors.Join(Err, err) + policyID, err = json.Marshal(data.agentPolicyID) + Err = errors.Join(Err, err) + revisionIDX, err = json.Marshal(data.revisionIDX) + Err = errors.Join(Err, err) ver, err = json.Marshal(data.extra.ver) Err = errors.Join(Err, err) isSet, err = json.Marshal(data.extra.seqNo.IsSet()) @@ -394,6 +418,8 @@ func encodeParams(now string, data pendingT) (map[string]json.RawMessage, error) "Status": status, "Message": message, "UnhealthyReason": reason, + "PolicyID": policyID, + "RevisionIDX": revisionIDX, "Ver": ver, "Meta": meta, "Components": components, diff --git a/internal/pkg/checkin/bulk_test.go b/internal/pkg/checkin/bulk_test.go index af22075319..d0cb344d54 100644 --- a/internal/pkg/checkin/bulk_test.go +++ b/internal/pkg/checkin/bulk_test.go @@ -5,7 +5,6 @@ package checkin import ( - "bytes" "context" "encoding/json" "testing" @@ -19,7 +18,9 @@ import ( "github.com/google/go-cmp/cmp" "github.com/rs/xid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) // Test simple, @@ -43,38 +44,36 @@ func matchOp(tb testing.TB, c testcase, ts time.Time) func(ops []bulk.MultiOp) b // Decode and match operation // NOTE putting the extra validation here seems strange, maybe we should read the args in the test body intstead? type updateT struct { - LastCheckin string `json:"last_checkin"` - Status string `json:"last_checkin_status"` - UpdatedAt string `json:"updated_at"` - Meta json.RawMessage `json:"local_metadata"` - SeqNo sqn.SeqNo `json:"action_seq_no"` + LastCheckin string `json:"last_checkin"` + Status string `json:"last_checkin_status"` + UpdatedAt string `json:"updated_at"` + AgentPolicyID string `json:"agent_policy_id,omitempty"` + RevisionIDX int64 `json:"policy_revision_idx,omitempty"` + Meta json.RawMessage `json:"local_metadata"` + SeqNo sqn.SeqNo `json:"action_seq_no"` } m := make(map[string]updateT) - if err := json.Unmarshal(ops[0].Body, &m); err != nil { - tb.Fatalf("unable to validate operation: %v", err) - } + err := json.Unmarshal(ops[0].Body, &m) + require.NoErrorf(tb, err, "unable to validate operation body %s", string(ops[0].Body)) sub, ok := m["doc"] - if !ok { - tb.Fatal("unable to validate operation: expected doc") - } + require.True(tb, ok, "unable to validate operation: expected doc") + validateTimestamp(tb, ts.Truncate(time.Second), sub.LastCheckin) validateTimestamp(tb, ts.Truncate(time.Second), sub.UpdatedAt) + assert.Equal(tb, c.policyID, sub.AgentPolicyID) + assert.Equal(tb, c.revisionIDX, sub.RevisionIDX) if c.seqno != nil { if cdiff := cmp.Diff(c.seqno, sub.SeqNo); cdiff != "" { tb.Error(cdiff) } } - if c.meta != nil && !bytes.Equal(c.meta, sub.Meta) { - tb.Error("meta doesn't match up") + if c.meta != nil { + assert.Equal(tb, json.RawMessage(c.meta), sub.Meta) } - - if c.status != sub.Status { - tb.Error("status mismatch") - } - + assert.Equal(tb, c.status, sub.Status) return true } } @@ -84,6 +83,8 @@ type testcase struct { id string status string message string + policyID string + revisionIDX int64 meta []byte components []byte seqno sqn.SeqNo @@ -95,107 +96,73 @@ func TestBulkSimple(t *testing.T) { start := time.Now() const ver = "8.9.0" - cases := []testcase{ - { - "Simple case", - "simpleId", - "online", - "message", - nil, - nil, - nil, - "", - nil, - }, - { - "has meta with fips attribute", - "metaCaseID", - "online", - "message", - []byte(`{"fips":true,"snapshot":false}`), - nil, - nil, - "", - nil, - }, - { - "Singled field case", - "singleFieldId", - "online", - "message", - []byte(`{"hey":"now"}`), - []byte(`[{"id":"winlog-default"}]`), - nil, - "", - nil, - }, - { - "Multi field case", - "multiFieldId", - "online", - "message", - []byte(`{"hey":"now","brown":"cow"}`), - []byte(`[{"id":"winlog-default","type":"winlog"}]`), - nil, - ver, - nil, - }, - { - "Multi field nested case", - "multiFieldNestedId", - "online", - "message", - []byte(`{"hey":"now","wee":{"little":"doggie"}}`), - []byte(`[{"id":"winlog-default","type":"winlog"}]`), - nil, - "", - nil, - }, - { - "Simple case with seqNo", - "simpleseqno", - "online", - "message", - nil, - nil, - sqn.SeqNo{1, 2, 3, 4}, - ver, - nil, - }, - { - "Field case with seqNo", - "simpleseqno", - "online", - "message", - []byte(`{"uncle":"fester"}`), - []byte(`[{"id":"log-default"}]`), - sqn.SeqNo{5, 6, 7, 8}, - ver, - nil, - }, - { - "Unusual status", - "singleFieldId", - "unusual", - "message", - nil, - nil, - nil, - "", - nil, - }, - { - "Empty status", - "singleFieldId", - "", - "message", - nil, - nil, - nil, - "", - nil, - }, - } + cases := []testcase{{ + name: "Simple case", + id: "simpleId", + status: "online", + message: "message", + }, { + name: "Simple case with policy id and revision idx", + id: "simpleId", + status: "online", + message: "message", + policyID: "testPolicy", + revisionIDX: 1, + }, { + name: "has meta with fips attribute", + id: "metaCaseID", + status: "online", + message: "message", + meta: []byte(`{"fips":true,"snapshot":false}`), + }, { + name: "Singled field case", + id: "singleFieldId", + status: "online", + message: "message", + meta: []byte(`{"hey":"now"}`), + components: []byte(`[{"id":"winlog-default"}]`), + }, { + name: "Multi field case", + id: "multiFieldId", + status: "online", + message: "message", + meta: []byte(`{"hey":"now","brown":"cow"}`), + components: []byte(`[{"id":"winlog-default","type":"winlog"}]`), + ver: ver, + }, { + name: "Multi field nested case", + id: "multiFieldNestedId", + status: "online", + message: "message", + meta: []byte(`{"hey":"now","wee":{"little":"doggie"}}`), + components: []byte(`[{"id":"winlog-default","type":"winlog"}]`), + }, { + name: "Simple case with seqNo", + id: "simpleseqno", + status: "online", + message: "message", + seqno: sqn.SeqNo{1, 2, 3, 4}, + ver: ver, + }, { + name: "Field case with seqNo", + id: "simpleseqno", + status: "online", + message: "message", + meta: []byte(`{"uncle":"fester"}`), + components: []byte(`[{"id":"log-default"}]`), + seqno: sqn.SeqNo{5, 6, 7, 8}, + ver: ver, + }, { + name: "Unusual status", + id: "singleFieldId", + status: "unusual", + message: "message", + }, { + name: "Empty status", + id: "singleFieldId", + status: "", + message: "message", + }} for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -205,6 +172,9 @@ func TestBulkSimple(t *testing.T) { bc := NewBulk(mockBulk) opts := []Option{WithStatus(c.status), WithMessage(c.message)} + if c.policyID != "" { + opts = append(opts, WithAgentPolicyID(c.policyID), WithPolicyRevisionIDX(c.revisionIDX)) + } if c.meta != nil { opts = append(opts, WithMeta(c.meta)) } @@ -221,13 +191,10 @@ func TestBulkSimple(t *testing.T) { opts = append(opts, WithUnhealthyReason(c.unhealthyReason)) } - if err := bc.CheckIn(c.id, opts...); err != nil { - t.Fatal(err) - } - - if err := bc.flush(ctx); err != nil { - t.Fatal(err) - } + err := bc.CheckIn(c.id, opts...) + require.NoError(t, err) + err = bc.flush(ctx) + require.NoError(t, err) mockBulk.AssertExpectations(t) }) @@ -235,11 +202,9 @@ func TestBulkSimple(t *testing.T) { } func validateTimestamp(tb testing.TB, start time.Time, ts string) { - if t1, err := time.Parse(time.RFC3339, ts); err != nil { - tb.Error("expected rfc3999") - } else if start.After(t1) { - tb.Error("timestamp in the past") - } + t1, err := time.Parse(time.RFC3339, ts) + require.NoErrorf(tb, err, "expected %q to be in RFC 3339 format", ts) + require.False(tb, start.After(t1), "timestamp in the past") } func benchmarkBulk(n int, b *testing.B) { diff --git a/internal/pkg/checkin/deleteAuditFieldsOnCheckin.painless b/internal/pkg/checkin/deleteAuditFieldsOnCheckin.painless index 6588cd1ce8..3b0f111865 100644 --- a/internal/pkg/checkin/deleteAuditFieldsOnCheckin.painless +++ b/internal/pkg/checkin/deleteAuditFieldsOnCheckin.painless @@ -3,6 +3,10 @@ ctx._source.updated_at = params.Now; ctx._source.last_checkin_status = params.Status; ctx._source.last_checkin_message = params.Message; ctx._source.unhealthy_reason = params.UnhealthyReason; +if (params.PolicyID != "") { + ctx._source.agent_policy_id = params.PolicyID; + ctx._source.policy_revision_idx = params.RevisionIDX; +} if (params.Ver != "") { ctx._source.agent.version = params.Ver; } diff --git a/internal/pkg/config/input.go b/internal/pkg/config/input.go index ee4e59a377..64c341ac54 100644 --- a/internal/pkg/config/input.go +++ b/internal/pkg/config/input.go @@ -77,6 +77,7 @@ type ( StaticPolicyTokens StaticPolicyTokens `config:"static_policy_tokens"` PGP PGP `config:"pgp"` PDKDF2 PBKDF2 `config:"pdkdf2"` + Features FeatureFlags `config:"feature_flags"` } StaticPolicyTokens struct { @@ -91,6 +92,13 @@ type ( TokenKey string `config:"token_key"` PolicyID string `config:"policy_id"` } + + // FeatureFlags contains toggles to enable new behaviour, or restore old behaviour. + FeatureFlags struct { + // IgnoreCheckinPolicyID when true will ignore the agent_policy_id and policy_revision_idx attributes in checkin request bodies. + // This setting restores previous behaviour where all POLICY_CHANGE actions need an explicit ack. + IgnoreCheckinPolicyID bool `config:"ignore_checkin_policy_id"` + } ) // InitDefaults initializes the defaults for the configuration. diff --git a/internal/pkg/dl/constants.go b/internal/pkg/dl/constants.go index 9c191955aa..7937d62877 100644 --- a/internal/pkg/dl/constants.go +++ b/internal/pkg/dl/constants.go @@ -34,6 +34,7 @@ const ( FieldLastCheckinMessage = "last_checkin_message" FieldLocalMetadata = "local_metadata" FieldComponents = "components" + FieldAgentPolicyID = "agent_policy_id" FieldPolicyID = "policy_id" FieldPolicyOutputAPIKey = "api_key" FieldPolicyOutputAPIKeyID = "api_key_id" diff --git a/internal/pkg/model/schema.go b/internal/pkg/model/schema.go index 8b88937cc7..92e78df86e 100644 --- a/internal/pkg/model/schema.go +++ b/internal/pkg/model/schema.go @@ -132,6 +132,9 @@ type Agent struct { Active bool `json:"active"` Agent *AgentMetadata `json:"agent,omitempty"` + // The policy ID that the Elastic Agent is currently running. + AgentPolicyID string `json:"agent_policy_id,omitempty"` + // Agent reason for unenroll/uninstall annotation. AuditUnenrolledReason string `json:"audit_unenrolled_reason,omitempty"` @@ -183,7 +186,7 @@ type Agent struct { // The current policy coordinator for the Elastic Agent PolicyCoordinatorIdx int64 `json:"policy_coordinator_idx,omitempty"` - // The policy ID for the Elastic Agent + // The policy ID that the Elastic Agent should run. PolicyID string `json:"policy_id,omitempty"` // Deprecated. Use Outputs instead. The policy output permissions hash diff --git a/internal/pkg/policy/monitor.go b/internal/pkg/policy/monitor.go index 55563b902b..59a979360b 100644 --- a/internal/pkg/policy/monitor.go +++ b/internal/pkg/policy/monitor.go @@ -62,6 +62,9 @@ type Monitor interface { // Unsubscribe removes the current subscription. Unsubscribe(sub Subscription) error + + // LatestRev returns the latest revision idx for the specified policy. + LatestRev(ctx context.Context, policyID string) int64 } type policyFetcher func(ctx context.Context, bulker bulk.Bulk, opt ...dl.Option) ([]model.Policy, error) @@ -557,3 +560,34 @@ func (m *monitorT) Unsubscribe(sub Subscription) error { return nil } + +// LatestRev returns the revision_idx for the passed policy ID. +// If the policy does not exist in the map, then all policies are foribly reloaded. +// On an error with the reload, or if the policy does not exist a 0 is returned. +func (m *monitorT) LatestRev(ctx context.Context, id string) int64 { + if id == "" { + return 0 + } + + m.mut.Lock() + p, ok := m.policies[id] + m.mut.Unlock() + + if !ok { + // We've not seen this policy before, force load. + err := m.loadPolicies(ctx) + if err != nil { + m.log.Error().Err(err).Str(ecs.PolicyID, id).Msg("Unable to load policies.") + return 0 + } + + m.mut.Lock() + p, ok = m.policies[id] + m.mut.Unlock() + if !ok { + m.log.Warn().Str(ecs.PolicyID, id).Msg("Unable to find policy after load.") + return 0 + } + } + return p.pp.Policy.RevisionIdx +} diff --git a/internal/pkg/policy/monitor_test.go b/internal/pkg/policy/monitor_test.go index 24cf5f0260..44f9141297 100644 --- a/internal/pkg/policy/monitor_test.go +++ b/internal/pkg/policy/monitor_test.go @@ -9,6 +9,7 @@ package policy import ( "context" "encoding/json" + "fmt" "sync" "testing" "time" @@ -549,3 +550,76 @@ LOOP: ms.AssertExpectations(t) mm.AssertExpectations(t) } + +func TestMonitor_LatestRev(t *testing.T) { + t.Run("empty policy id", func(t *testing.T) { + pm := &monitorT{} + idx := pm.LatestRev(t.Context(), "") + assert.Equal(t, int64(0), idx) + }) + + t.Run("policy load error", func(t *testing.T) { + bulker := ftesting.NewMockBulk() + mm := mmock.NewMockMonitor() + monitor := NewMonitor(bulker, mm, config.ServerLimits{}) + pm := monitor.(*monitorT) + pm.policyF = func(ctx context.Context, bulker bulk.Bulk, opt ...dl.Option) ([]model.Policy, error) { + return nil, fmt.Errorf("policy fetch error") + } + + idx := pm.LatestRev(t.Context(), "test-id") + assert.Equal(t, int64(0), idx) + }) + + t.Run("policy not found", func(t *testing.T) { + bulker := ftesting.NewMockBulk() + mm := mmock.NewMockMonitor() + monitor := NewMonitor(bulker, mm, config.ServerLimits{}) + pm := monitor.(*monitorT) + pm.policyF = func(ctx context.Context, bulker bulk.Bulk, opt ...dl.Option) ([]model.Policy, error) { + return []model.Policy{}, nil + } + idx := pm.LatestRev(t.Context(), "test-id") + assert.Equal(t, int64(0), idx) + }) + + t.Run("policy found after load", func(t *testing.T) { + bulker := ftesting.NewMockBulk() + mm := mmock.NewMockMonitor() + monitor := NewMonitor(bulker, mm, config.ServerLimits{}) + pm := monitor.(*monitorT) + policyId := uuid.Must(uuid.NewV4()).String() + rId := xid.New().String() + policy := model.Policy{ + ESDocument: model.ESDocument{ + Id: rId, + Version: 1, + SeqNo: 1, + }, + PolicyID: policyId, + Data: policyDataDefault, + RevisionIdx: 2, + } + pm.policyF = func(ctx context.Context, bulker bulk.Bulk, opt ...dl.Option) ([]model.Policy, error) { + return []model.Policy{policy}, nil + } + idx := pm.LatestRev(t.Context(), policyId) + assert.Equal(t, int64(2), idx) + }) + + t.Run("policy found", func(t *testing.T) { + pm := &monitorT{ + policies: map[string]policyT{ + "test-id": policyT{ + pp: ParsedPolicy{ + Policy: model.Policy{ + RevisionIdx: 1, + }, + }, + }, + }, + } + idx := pm.LatestRev(t.Context(), "test-id") + assert.Equal(t, int64(1), idx) + }) +} diff --git a/model/openapi.yml b/model/openapi.yml index d09e89ac72..6d10389128 100644 --- a/model/openapi.yml +++ b/model/openapi.yml @@ -422,6 +422,15 @@ components: format: duration upgrade_details: $ref: "#/components/schemas/upgrade_details" + agent_policy_id: + description: | + The ID of the policy that the agent is currently running. + type: string + policy_revision_idx: + description: | + The revision of the policy that the agent is currently running. + type: integer + format: int64 actionSignature: description: Optional action signing data. type: object diff --git a/model/schema.json b/model/schema.json index 1b187bd8a8..754d9c99d9 100644 --- a/model/schema.json +++ b/model/schema.json @@ -579,7 +579,12 @@ "format": "raw" }, "policy_id": { - "description": "The policy ID for the Elastic Agent", + "description": "The policy ID that the Elastic Agent should run.", + "type": "string", + "format": "uuid" + }, + "agent_policy_id": { + "description": "The policy ID that the Elastic Agent is currently running.", "type": "string", "format": "uuid" }, diff --git a/pkg/api/types.gen.go b/pkg/api/types.gen.go index 8b3c1b732c..9f9f78e24f 100644 --- a/pkg/api/types.gen.go +++ b/pkg/api/types.gen.go @@ -311,6 +311,9 @@ type CheckinRequest struct { // Translated to a sequence number in fleet-server in order to retrieve any new actions for the agent from the last checkin. AckToken *string `json:"ack_token,omitempty"` + // AgentPolicyId The ID of the policy that the agent is currently running. + AgentPolicyId *string `json:"agent_policy_id,omitempty"` + // Components An embedded JSON object that holds component information that the agent is running. // Defined in fleet-server as a `json.RawMessage`, defined as an object in the elastic-agent. // fleet-server will update the components in an agent record if they differ from this object. @@ -325,6 +328,9 @@ type CheckinRequest struct { // Message State message, may be overridden or use the error message of a failing component. Message string `json:"message"` + // PolicyRevisionIdx The revision of the policy that the agent is currently running. + PolicyRevisionIdx *int64 `json:"policy_revision_idx,omitempty"` + // PollTimeout An optional timeout value that informs fleet-server of when a client will time out on it's checkin request. // If not specified fleet-server will use the timeout values specified in the config (defaults to 5m polling and a 10m write timeout). // The value, if specified is expected to be a string that is parsable by [time.ParseDuration](https://pkg.go.dev/time#ParseDuration). diff --git a/testing/e2e/api_version/client_api_current.go b/testing/e2e/api_version/client_api_current.go index 77ce1432a5..7ed80cccdf 100644 --- a/testing/e2e/api_version/client_api_current.go +++ b/testing/e2e/api_version/client_api_current.go @@ -33,6 +33,9 @@ import ( "strings" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/elastic/fleet-server/pkg/api" "github.com/elastic/fleet-server/testing/e2e/scaffold" "github.com/elastic/fleet-server/v7/version" @@ -513,3 +516,257 @@ func (tester *ClientAPITester) TestEnrollUpgradeAction_MetadataDownloadRate_Stri _, _, statusCode = tester.Checkin(ctx, agentKey, agentID, ackToken, &dur, body) tester.Require().Equal(http.StatusOK, statusCode, "Expected status code 200 for successful checkin") } + +func (tester *ClientAPITester) TestCheckinWithPolicyIDRevision() { + ctx, cancel := context.WithTimeout(tester.T().Context(), 4*time.Minute) + defer cancel() + dur := "60s" // 60s is the min poll duraton fleet-server allows + + tester.T().Log("Enroll an agent") + agentID, agentKey := tester.Enroll(ctx, tester.enrollmentKey) + tester.VerifyAgentInKibana(ctx, agentID) + + client, err := api.NewClientWithResponses(tester.endpoint, api.WithHTTPClient(tester.Client), api.WithRequestEditorFn(func(ctx context.Context, req *http.Request) error { + req.Header.Set("Authorization", "ApiKey "+agentKey) + return nil + })) + tester.Require().NoError(err) + + tester.T().Logf("test checkin 1: retrieve POLICY_CHANGE action for agent %s", agentID) + resp, err := client.AgentCheckinWithResponse(ctx, agentID, &api.AgentCheckinParams{UserAgent: "elastic agent " + version.DefaultVersion}, api.AgentCheckinJSONRequestBody{ + Status: api.CheckinRequestStatusOnline, + Message: "test checkin", + }) + tester.Require().NoError(err) + tester.Require().Equal(http.StatusOK, resp.StatusCode()) + + checkin := resp.JSON200 + tester.Require().NotEmpty(checkin.Actions) + var policyChange api.ActionPolicyChange + found := false + for _, action := range checkin.Actions { + if action.Type == api.POLICYCHANGE { + policyChange, err = action.Data.AsActionPolicyChange() + tester.Require().NoError(err) + found = true + break + } + } + tester.Require().True(found, "unable to find POLICY_CHANGE action in 1st checkin response") + policyID := policyChange.Policy.Id + revIDX := int64(policyChange.Policy.Revision) // TODO change mapping in openapi? + + // Checkin with policyID revIDX + // No actions should be returned + // Manage any API keys if present + tester.T().Logf("test checkin 2: agent %s with policy %s:%d in request body", agentID, policyID, revIDX) + resp, err = client.AgentCheckinWithResponse(ctx, agentID, &api.AgentCheckinParams{UserAgent: "elastic agent " + version.DefaultVersion}, api.AgentCheckinJSONRequestBody{ + Status: api.CheckinRequestStatusOnline, + Message: "test checkin", + PollTimeout: &dur, + AgentPolicyId: &policyID, + PolicyRevisionIdx: &revIDX, + }) + tester.Require().NoError(err) + tester.Require().Equal(http.StatusOK, resp.StatusCode()) + checkin = resp.JSON200 + tester.Require().Empty(checkin.Actions, "Unexpected action in response") + + tester.Require().EventuallyWithT(func(c *assert.CollectT) { + agent := tester.GetAgent(ctx, agentID) + assert.Equal(c, policyID, agent.AgentPolicyID) + assert.Equal(c, revIDX, int64(agent.Revision)) + }, time.Second*10, time.Second) + + // Check in with revIDX that does not exist + // POLICY_CHANGE should be returned + // No API keys changed + // agent doc will be updated with sent values + newRevIDX := revIDX + 1 + tester.T().Logf("test checkin 3: agent %s with revision_idx+1 %d (fast forward)", agentID, newRevIDX) + resp, err = client.AgentCheckinWithResponse(ctx, agentID, &api.AgentCheckinParams{UserAgent: "elastic agent " + version.DefaultVersion}, api.AgentCheckinJSONRequestBody{ + Status: api.CheckinRequestStatusOnline, + Message: "test checkin", + PollTimeout: &dur, + AgentPolicyId: &policyID, + PolicyRevisionIdx: &newRevIDX, + }) + tester.Require().NoError(err) + tester.Require().Equal(http.StatusOK, resp.StatusCode()) + + checkin = resp.JSON200 + found = false + for _, action := range checkin.Actions { + if action.Type == api.POLICYCHANGE { + policyChange, err = action.Data.AsActionPolicyChange() + tester.Require().NoError(err) + found = true + break + } + } + tester.Require().True(found, "unable to find POLICY_CHANGE action in 3rd checkin response") + tester.Require().Equal(policyID, policyChange.Policy.Id) + tester.Require().Equal(revIDX, int64(policyChange.Policy.Revision)) + + tester.Require().EventuallyWithT(func(c *assert.CollectT) { + agent := tester.GetAgent(ctx, agentID) + assert.Equal(c, policyID, agent.AgentPolicyID) + assert.Equal(c, newRevIDX, int64(agent.Revision)) + }, time.Second*10, time.Second) + + // Update policy + // Get the policy then "update" it without changing anything - revision ID should increment + tester.T().Logf("Update policy %s", policyID) + rawPolicy := tester.GetPolicy(ctx, policyID) + var obj map[string]any + err = json.Unmarshal(rawPolicy, &obj) + tester.Require().NoError(err) + item, ok := obj["item"] + tester.Require().True(ok, "Expected item in object: %v", obj) + obj, ok = item.(map[string]any) + tester.Require().True(ok, "Expected item to be object: %T", item) + reqObj := make(map[string]any) + // Copy some attributes - name and namespace are required. + for _, k := range []string{"name", "namespace", "id", "space_ids", "inactivity_timeout"} { + reqObj[k] = obj[k] + } + rawPolicy, err = json.Marshal(reqObj) + tester.Require().NoError(err) + + tester.UpdatePolicy(ctx, policyID, rawPolicy) + rawPolicy = tester.GetPolicy(ctx, policyID) + + // Verify that the revision has incremented + err = json.Unmarshal(rawPolicy, &obj) + tester.Require().NoError(err) + item, ok = obj["item"] + tester.Require().True(ok, "Expected item in object: %v", obj) + obj, ok = item.(map[string]any) + tester.Require().True(ok, "Expected item to be object: %T", item) + oRev, ok := obj["revision"] + tester.Require().True(ok, "revision not found in: %v", obj) + iRev, ok := oRev.(float64) // numbers will serialize to float64 by default + tester.Require().True(ok, "revision is not a float64: %T", oRev) + tester.Require().Equal(revIDX+1, int64(iRev), "Expected policy revision to be exactly one greater than last revision.") + tester.T().Logf("Policy has been updated to revision %d.", int64(iRev)) + + // Do a checkin with revIDX (policy.revision - 1) + // Last checkin should have already recorded the agent as running policy_revision, but this checkin must return a POLICY_CHANGE action. + // Note that API keys (if any) would be managed here + tester.T().Logf("test checkin 4: agent %s with policy.revision-1 %d", agentID, revIDX) + resp, err = client.AgentCheckinWithResponse(ctx, agentID, &api.AgentCheckinParams{UserAgent: "elastic agent " + version.DefaultVersion}, api.AgentCheckinJSONRequestBody{ + Status: api.CheckinRequestStatusOnline, + Message: "test checkin", + PollTimeout: &dur, + AgentPolicyId: &policyID, + PolicyRevisionIdx: &revIDX, + }) + tester.Require().NoError(err) + tester.Require().Equal(http.StatusOK, resp.StatusCode()) + checkin = resp.JSON200 + tester.Require().NotEmpty(checkin.Actions, "Expected an action in the response") + found = false + for _, action := range checkin.Actions { + if action.Type == api.POLICYCHANGE { + policyChange, err = action.Data.AsActionPolicyChange() + tester.Require().NoError(err) + found = true + break + } + } + tester.Require().True(found, "unable to find POLICY_CHANGE action in 4th checkin response") + revIDX = int64(policyChange.Policy.Revision) + tester.Require().Equal(int64(iRev), revIDX, "Expected POLICY_CHANGE action to be for updated policy revision") + + tester.Require().EventuallyWithT(func(c *assert.CollectT) { + agent := tester.GetAgent(ctx, agentID) + require.Equal(c, policyID, agent.AgentPolicyID) + require.Equal(c, revIDX, int64(agent.Revision)) + }, time.Second*10, time.Second) + + // Do a normal checkin to "reset" to latest revision_idx + // no actions are returned + // Manage any API keys if present + tester.T().Logf("test checkin 5: agent %s with policy %s:%d in request body", agentID, policyID, revIDX) + resp, err = client.AgentCheckinWithResponse(ctx, agentID, &api.AgentCheckinParams{UserAgent: "elastic agent " + version.DefaultVersion}, api.AgentCheckinJSONRequestBody{ + Status: api.CheckinRequestStatusOnline, + Message: "test checkin", + PollTimeout: &dur, + AgentPolicyId: &policyID, + PolicyRevisionIdx: &revIDX, + }) + tester.Require().NoError(err) + tester.Require().Equal(http.StatusOK, resp.StatusCode()) + checkin = resp.JSON200 + tester.Require().Empty(checkin.Actions, "Unexpected action in response") + + tester.Require().EventuallyWithT(func(c *assert.CollectT) { + agent := tester.GetAgent(ctx, agentID) + require.Equal(c, policyID, agent.AgentPolicyID) + require.Equal(c, revIDX, int64(agent.Revision)) + }, time.Second*10, time.Second) + + // Test that if the agent is "restored" to an earlier revIDX a policy_change is sent + prevRev := revIDX - 1 + tester.T().Logf("test checkin 6: agent %s with policy %s:%d (rewind)", agentID, policyID, prevRev) + resp, err = client.AgentCheckinWithResponse(ctx, agentID, &api.AgentCheckinParams{UserAgent: "elastic agent " + version.DefaultVersion}, api.AgentCheckinJSONRequestBody{ + Status: api.CheckinRequestStatusOnline, + Message: "test checkin", + PollTimeout: &dur, + AgentPolicyId: &policyID, + PolicyRevisionIdx: &prevRev, + }) + tester.Require().NoError(err) + tester.Require().Equal(http.StatusOK, resp.StatusCode()) + checkin = resp.JSON200 + tester.Require().NotEmpty(checkin.Actions, "Expected action in response") + + tester.Require().EventuallyWithT(func(c *assert.CollectT) { + agent := tester.GetAgent(ctx, agentID) + require.Equal(c, policyID, agent.AgentPolicyID) + require.Equal(c, prevRev, int64(agent.Revision)) + }, time.Second*10, time.Second) + + // agent is now recorded as on a previous revision - check to make sure a checkin without AgentPolicyId and revision result in a POLICY_CHANGE action + tester.T().Logf("test checkin 7: agent %s with no policy or revision", agentID) + resp, err = client.AgentCheckinWithResponse(ctx, agentID, &api.AgentCheckinParams{UserAgent: "elastic agent " + version.DefaultVersion}, api.AgentCheckinJSONRequestBody{ + Status: api.CheckinRequestStatusOnline, + Message: "test checkin", + PollTimeout: &dur, + }) + tester.Require().NoError(err) + tester.Require().Equal(http.StatusOK, resp.StatusCode()) + checkin = resp.JSON200 + tester.Require().NotEmpty(checkin.Actions, "Expected action in response") + actionID := "" + for _, action := range checkin.Actions { + if action.Type == api.POLICYCHANGE { + actionID = action.Id + break + } + } + tester.Require().NotEmptyf(actionID, "expected to find POLICY_CHANGE action id in %+v", checkin.Actions) + + tester.T().Log("Ack the POLICY_CHANGE action") + tester.Acks(ctx, agentKey, agentID, []string{actionID}) + + tester.T().Logf("test checkin 8: agent %s with no policy or revision should not recieve action", agentID) + resp, err = client.AgentCheckinWithResponse(ctx, agentID, &api.AgentCheckinParams{UserAgent: "elastic agent " + version.DefaultVersion}, api.AgentCheckinJSONRequestBody{ + Status: api.CheckinRequestStatusOnline, + Message: "test checkin", + PollTimeout: &dur, + }) + tester.Require().NoError(err) + tester.Require().Equal(http.StatusOK, resp.StatusCode()) + checkin = resp.JSON200 + tester.Require().Empty(checkin.Actions, "Unexpected action in response") + + tester.Require().EventuallyWithT(func(c *assert.CollectT) { + agent := tester.GetAgent(ctx, agentID) + assert.Equal(c, policyID, agent.AgentPolicyID) + assert.Equal(c, revIDX, int64(agent.Revision)) + }, time.Second*10, time.Second) + + // sanity check agent status in kibana + tester.AgentIsOnline(ctx, agentID) +} diff --git a/testing/e2e/scaffold/scaffold.go b/testing/e2e/scaffold/scaffold.go index f12150b2e4..7d54b0c3c2 100644 --- a/testing/e2e/scaffold/scaffold.go +++ b/testing/e2e/scaffold/scaffold.go @@ -307,6 +307,30 @@ type KibanaAgent struct { Status string `json:"status"` } +type ESAgentDoc struct { + Revision int `json:"policy_revision_idx"` + PolicyID string `json:"policy_id"` + AgentPolicyID string `json:"agent_policy_id"` +} + +func (s *Scaffold) GetAgent(ctx context.Context, id string) ESAgentDoc { + // NOTE we use ES instead of Kibana here as Kibana does not support agent_policy_id yet + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:9200/.fleet-agents/_doc/"+id, nil) + s.Require().NoError(err) + req.SetBasicAuth(s.ElasticUser, s.ElasticPass) + + resp, err := s.Client.Do(req) + s.Require().NoError(err) + defer resp.Body.Close() + s.Require().Equal(http.StatusOK, resp.StatusCode) + var obj struct { + Source ESAgentDoc `json:"_source"` + } + err = json.NewDecoder(resp.Body).Decode(&obj) + s.Require().NoError(err) + return obj.Source +} + func (s *Scaffold) GetAgents(ctx context.Context) (int, []KibanaAgent) { // TODO handle pagination if needed in the future req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost:5601/api/fleet/agents", nil) @@ -644,7 +668,11 @@ func (s *Scaffold) AddPolicyOverrides(ctx context.Context, id string, overrides } p, err := json.Marshal(&body) s.Require().NoError(err) - req, err := http.NewRequestWithContext(ctx, http.MethodPut, fmt.Sprintf("http://localhost:5601/api/fleet/agent_policies/%s", id), bytes.NewReader(p)) + s.UpdatePolicy(ctx, id, p) +} + +func (s *Scaffold) UpdatePolicy(ctx context.Context, id string, body []byte) { + req, err := http.NewRequestWithContext(ctx, http.MethodPut, fmt.Sprintf("http://localhost:5601/api/fleet/agent_policies/%s", id), bytes.NewReader(body)) s.Require().NoError(err) req.SetBasicAuth(s.ElasticUser, s.ElasticPass) req.Header.Set("Content-Type", "application/json") @@ -654,3 +682,17 @@ func (s *Scaffold) AddPolicyOverrides(ctx context.Context, id string, overrides defer resp.Body.Close() s.Require().Equal(http.StatusOK, resp.StatusCode) } + +func (s *Scaffold) GetPolicy(ctx context.Context, id string) []byte { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:5601/api/fleet/agent_policies/%s", id), nil) + s.Require().NoError(err) + req.SetBasicAuth(s.ElasticUser, s.ElasticPass) + req.Header.Set("kbn-xsrf", "e2e-test") + resp, err := s.Client.Do(req) + s.Require().NoError(err) + defer resp.Body.Close() + s.Require().Equal(http.StatusOK, resp.StatusCode) + p, err := io.ReadAll(resp.Body) + s.Require().NoError(err) + return p +}