diff --git a/pkg/registry/core/serviceaccount/storage/storage_test.go b/pkg/registry/core/serviceaccount/storage/storage_test.go index ecf13431f5ca2..a67c27126832f 100644 --- a/pkg/registry/core/serviceaccount/storage/storage_test.go +++ b/pkg/registry/core/serviceaccount/storage/storage_test.go @@ -138,7 +138,7 @@ func TestCreate_Token_SetsCredentialIDAuditAnnotation(t *testing.T) { } auditContext := audit.AuditContextFrom(ctx) - issuedCredentialID, ok := auditContext.Event.Annotations["authentication.kubernetes.io/issued-credential-id"] + issuedCredentialID, ok := auditContext.GetEventAnnotation("authentication.kubernetes.io/issued-credential-id") if !ok || len(issuedCredentialID) == 0 { t.Errorf("did not find issued-credential-id in audit event annotations") } diff --git a/staging/src/k8s.io/apiserver/pkg/admission/audit.go b/staging/src/k8s.io/apiserver/pkg/admission/audit.go index 7c0993f0908f7..f9f90cd024757 100644 --- a/staging/src/k8s.io/apiserver/pkg/admission/audit.go +++ b/staging/src/k8s.io/apiserver/pkg/admission/audit.go @@ -83,7 +83,7 @@ func ensureAnnotationGetter(a Attributes) error { } func (handler *auditHandler) logAnnotations(ctx context.Context, a Attributes) { - ae := audit.AuditEventFrom(ctx) + ae := audit.AuditContextFrom(ctx) if ae == nil { return } @@ -91,9 +91,9 @@ func (handler *auditHandler) logAnnotations(ctx context.Context, a Attributes) { var annotations map[string]string switch a := a.(type) { case privateAnnotationsGetter: - annotations = a.getAnnotations(ae.Level) + annotations = a.getAnnotations(ae.GetEventLevel()) case AnnotationsGetter: - annotations = a.GetAnnotations(ae.Level) + annotations = a.GetAnnotations(ae.GetEventLevel()) default: // this will never happen, because we have already checked it in ensureAnnotationGetter } diff --git a/staging/src/k8s.io/apiserver/pkg/admission/audit_test.go b/staging/src/k8s.io/apiserver/pkg/admission/audit_test.go index 36c7e719d1f96..a3c6f003e5ba4 100644 --- a/staging/src/k8s.io/apiserver/pkg/admission/audit_test.go +++ b/staging/src/k8s.io/apiserver/pkg/admission/audit_test.go @@ -144,8 +144,10 @@ func TestWithAudit(t *testing.T) { var handler Interface = fakeHandler{tc.admit, tc.admitAnnotations, tc.validate, tc.validateAnnotations, tc.handles} ctx := audit.WithAuditContext(context.Background()) ac := audit.AuditContextFrom(ctx) - ae := &ac.Event - ae.Level = auditinternal.LevelMetadata + if err := ac.Init(audit.RequestAuditConfig{Level: auditinternal.LevelMetadata}, nil); err != nil { + t.Fatal(err) + } + auditHandler := WithAudit(handler) a := attributes() @@ -171,9 +173,9 @@ func TestWithAudit(t *testing.T) { annotations[k] = v } if len(annotations) == 0 { - assert.Nil(t, ae.Annotations, tcName+": unexptected annotations set in audit event") + assert.Nil(t, ac.GetEventAnnotations(), tcName+": unexptected annotations set in audit event") } else { - assert.Equal(t, annotations, ae.Annotations, tcName+": unexptected annotations set in audit event") + assert.Equal(t, annotations, ac.GetEventAnnotations(), tcName+": unexptected annotations set in audit event") } } } @@ -187,8 +189,6 @@ func TestWithAuditConcurrency(t *testing.T) { } var handler Interface = fakeHandler{admitAnnotations: admitAnnotations, handles: true} ctx := audit.WithAuditContext(context.Background()) - ac := audit.AuditContextFrom(ctx) - ac.Event.Level = auditinternal.LevelMetadata auditHandler := WithAudit(handler) a := attributes() diff --git a/staging/src/k8s.io/apiserver/pkg/audit/context.go b/staging/src/k8s.io/apiserver/pkg/audit/context.go index 9648587378ecd..5b93d594bffa8 100644 --- a/staging/src/k8s.io/apiserver/pkg/audit/context.go +++ b/staging/src/k8s.io/apiserver/pkg/audit/context.go @@ -18,10 +18,18 @@ package audit import ( "context" + "errors" + "maps" "sync" + "sync/atomic" + "time" + authnv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" auditinternal "k8s.io/apiserver/pkg/apis/audit" + "k8s.io/apiserver/pkg/authentication/user" genericapirequest "k8s.io/apiserver/pkg/endpoints/request" "k8s.io/klog/v2" ) @@ -35,22 +43,223 @@ const auditKey key = iota // AuditContext holds the information for constructing the audit events for the current request. type AuditContext struct { - // RequestAuditConfig is the audit configuration that applies to the request - RequestAuditConfig RequestAuditConfig - - // Event is the audit Event object that is being captured to be written in + // initialized indicates whether requestAuditConfig and sink have been populated and are safe to read unguarded. + // This should only be set via Init(). + initialized atomic.Bool + // requestAuditConfig is the audit configuration that applies to the request. + // This should only be written via Init(RequestAuditConfig, Sink), and only read when initialized.Load() is true. + requestAuditConfig RequestAuditConfig + // sink is the sink to use when processing event stages. + // This should only be written via Init(RequestAuditConfig, Sink), and only read when initialized.Load() is true. + sink Sink + + // lock guards event + lock sync.Mutex + + // event is the audit Event object that is being captured to be written in // the API audit log. - Event auditinternal.Event + event auditinternal.Event - // annotationMutex guards event.Annotations - annotationMutex sync.Mutex + // unguarded copy of auditID from the event + auditID atomic.Value } // Enabled checks whether auditing is enabled for this audit context. func (ac *AuditContext) Enabled() bool { - // Note: An unset Level should be considered Enabled, so that request data (e.g. annotations) - // can still be captured before the audit policy is evaluated. - return ac != nil && ac.RequestAuditConfig.Level != auditinternal.LevelNone + if ac == nil { + // protect against nil pointers + return false + } + if !ac.initialized.Load() { + // Note: An unset Level should be considered Enabled, so that request data (e.g. annotations) + // can still be captured before the audit policy is evaluated. + return true + } + return ac.requestAuditConfig.Level != auditinternal.LevelNone +} + +func (ac *AuditContext) Init(requestAuditConfig RequestAuditConfig, sink Sink) error { + ac.lock.Lock() + defer ac.lock.Unlock() + if ac.initialized.Load() { + return errors.New("audit context was already initialized") + } + ac.requestAuditConfig = requestAuditConfig + ac.sink = sink + ac.event.Level = requestAuditConfig.Level + ac.initialized.Store(true) + return nil +} + +func (ac *AuditContext) AuditID() types.UID { + // return the unguarded copy of the auditID + id, _ := ac.auditID.Load().(types.UID) + return id +} + +func (ac *AuditContext) visitEvent(f func(event *auditinternal.Event)) { + ac.lock.Lock() + defer ac.lock.Unlock() + f(&ac.event) +} + +// ProcessEventStage returns true on success, false if there was an error processing the stage. +func (ac *AuditContext) ProcessEventStage(ctx context.Context, stage auditinternal.Stage) bool { + if ac == nil || !ac.initialized.Load() { + return true + } + if ac.sink == nil { + return true + } + for _, omitStage := range ac.requestAuditConfig.OmitStages { + if stage == omitStage { + return true + } + } + + processed := false + ac.visitEvent(func(event *auditinternal.Event) { + event.Stage = stage + if stage == auditinternal.StageRequestReceived { + event.StageTimestamp = event.RequestReceivedTimestamp + } else { + event.StageTimestamp = metav1.NewMicroTime(time.Now()) + } + + ObserveEvent(ctx) + processed = ac.sink.ProcessEvents(event) + }) + return processed +} + +func (ac *AuditContext) LogImpersonatedUser(user user.Info) { + ac.visitEvent(func(ev *auditinternal.Event) { + if ev == nil || ev.Level.Less(auditinternal.LevelMetadata) { + return + } + ev.ImpersonatedUser = &authnv1.UserInfo{ + Username: user.GetName(), + } + ev.ImpersonatedUser.Groups = user.GetGroups() + ev.ImpersonatedUser.UID = user.GetUID() + ev.ImpersonatedUser.Extra = map[string]authnv1.ExtraValue{} + for k, v := range user.GetExtra() { + ev.ImpersonatedUser.Extra[k] = authnv1.ExtraValue(v) + } + }) +} + +func (ac *AuditContext) LogResponseObject(status *metav1.Status, obj *runtime.Unknown) { + ac.visitEvent(func(ae *auditinternal.Event) { + if status != nil { + // selectively copy the bounded fields. + ae.ResponseStatus = &metav1.Status{ + Status: status.Status, + Message: status.Message, + Reason: status.Reason, + Details: status.Details, + Code: status.Code, + } + } + if ae.Level.Less(auditinternal.LevelRequestResponse) { + return + } + ae.ResponseObject = obj + }) +} + +// LogRequestPatch fills in the given patch as the request object into an audit event. +func (ac *AuditContext) LogRequestPatch(patch []byte) { + ac.visitEvent(func(ae *auditinternal.Event) { + ae.RequestObject = &runtime.Unknown{ + Raw: patch, + ContentType: runtime.ContentTypeJSON, + } + }) +} + +func (ac *AuditContext) GetEventAnnotation(key string) (string, bool) { + var val string + var ok bool + ac.visitEvent(func(event *auditinternal.Event) { + val, ok = event.Annotations[key] + }) + return val, ok +} + +func (ac *AuditContext) GetEventLevel() auditinternal.Level { + var level auditinternal.Level + ac.visitEvent(func(event *auditinternal.Event) { + level = event.Level + }) + return level +} + +func (ac *AuditContext) SetEventStage(stage auditinternal.Stage) { + ac.visitEvent(func(event *auditinternal.Event) { + event.Stage = stage + }) +} + +func (ac *AuditContext) GetEventStage() auditinternal.Stage { + var stage auditinternal.Stage + ac.visitEvent(func(event *auditinternal.Event) { + stage = event.Stage + }) + return stage +} + +func (ac *AuditContext) SetEventStageTimestamp(timestamp metav1.MicroTime) { + ac.visitEvent(func(event *auditinternal.Event) { + event.StageTimestamp = timestamp + }) +} + +func (ac *AuditContext) GetEventResponseStatus() *metav1.Status { + var status *metav1.Status + ac.visitEvent(func(event *auditinternal.Event) { + status = event.ResponseStatus + }) + return status +} + +func (ac *AuditContext) GetEventRequestReceivedTimestamp() metav1.MicroTime { + var timestamp metav1.MicroTime + ac.visitEvent(func(event *auditinternal.Event) { + timestamp = event.RequestReceivedTimestamp + }) + return timestamp +} + +func (ac *AuditContext) GetEventStageTimestamp() metav1.MicroTime { + var timestamp metav1.MicroTime + ac.visitEvent(func(event *auditinternal.Event) { + timestamp = event.StageTimestamp + }) + return timestamp +} + +func (ac *AuditContext) SetEventResponseStatus(status *metav1.Status) { + ac.visitEvent(func(event *auditinternal.Event) { + event.ResponseStatus = status + }) +} + +func (ac *AuditContext) SetEventResponseStatusCode(statusCode int32) { + ac.visitEvent(func(event *auditinternal.Event) { + if event.ResponseStatus == nil { + event.ResponseStatus = &metav1.Status{} + } + event.ResponseStatus.Code = statusCode + }) +} + +func (ac *AuditContext) GetEventAnnotations() map[string]string { + var annotations map[string]string + ac.visitEvent(func(event *auditinternal.Event) { + annotations = maps.Clone(event.Annotations) + }) + return annotations } // AddAuditAnnotation sets the audit annotation for the given key, value pair. @@ -66,8 +275,8 @@ func AddAuditAnnotation(ctx context.Context, key, value string) { return } - ac.annotationMutex.Lock() - defer ac.annotationMutex.Unlock() + ac.lock.Lock() + defer ac.lock.Unlock() addAuditAnnotationLocked(ac, key, value) } @@ -81,8 +290,8 @@ func AddAuditAnnotations(ctx context.Context, keysAndValues ...string) { return } - ac.annotationMutex.Lock() - defer ac.annotationMutex.Unlock() + ac.lock.Lock() + defer ac.lock.Unlock() if len(keysAndValues)%2 != 0 { klog.Errorf("Dropping mismatched audit annotation %q", keysAndValues[len(keysAndValues)-1]) @@ -100,8 +309,8 @@ func AddAuditAnnotationsMap(ctx context.Context, annotations map[string]string) return } - ac.annotationMutex.Lock() - defer ac.annotationMutex.Unlock() + ac.lock.Lock() + defer ac.lock.Unlock() for k, v := range annotations { addAuditAnnotationLocked(ac, k, v) @@ -110,8 +319,7 @@ func AddAuditAnnotationsMap(ctx context.Context, annotations map[string]string) // addAuditAnnotationLocked records the audit annotation on the event. func addAuditAnnotationLocked(ac *AuditContext, key, value string) { - ae := &ac.Event - + ae := &ac.event if ae.Annotations == nil { ae.Annotations = make(map[string]string) } @@ -128,15 +336,11 @@ func WithAuditContext(parent context.Context) context.Context { return parent // Avoid double registering. } - return genericapirequest.WithValue(parent, auditKey, &AuditContext{}) -} - -// AuditEventFrom returns the audit event struct on the ctx -func AuditEventFrom(ctx context.Context) *auditinternal.Event { - if ac := AuditContextFrom(ctx); ac.Enabled() { - return &ac.Event - } - return nil + return genericapirequest.WithValue(parent, auditKey, &AuditContext{ + event: auditinternal.Event{ + Stage: auditinternal.StageResponseStarted, + }, + }) } // AuditContextFrom returns the pair of the audit configuration object @@ -154,7 +358,10 @@ func WithAuditID(ctx context.Context, auditID types.UID) { return } if ac := AuditContextFrom(ctx); ac != nil { - ac.Event.AuditID = auditID + ac.visitEvent(func(event *auditinternal.Event) { + ac.auditID.Store(auditID) + event.AuditID = auditID + }) } } @@ -162,7 +369,8 @@ func WithAuditID(ctx context.Context, auditID types.UID) { // auditing is enabled. func AuditIDFrom(ctx context.Context) (types.UID, bool) { if ac := AuditContextFrom(ctx); ac != nil { - return ac.Event.AuditID, true + id, _ := ac.auditID.Load().(types.UID) + return id, true } return "", false } diff --git a/staging/src/k8s.io/apiserver/pkg/audit/context_test.go b/staging/src/k8s.io/apiserver/pkg/audit/context_test.go index 2bb3d39dd0187..9606d395cdbf4 100644 --- a/staging/src/k8s.io/apiserver/pkg/audit/context_test.go +++ b/staging/src/k8s.io/apiserver/pkg/audit/context_test.go @@ -40,16 +40,34 @@ func TestEnabled(t *testing.T) { ctx: &AuditContext{}, expectEnabled: true, // An AuditContext should be considered enabled before the level is set }, { - name: "level None", - ctx: &AuditContext{RequestAuditConfig: RequestAuditConfig{Level: auditinternal.LevelNone}}, + name: "level None", + ctx: func() *AuditContext { + ctx := &AuditContext{} + if err := ctx.Init(RequestAuditConfig{Level: auditinternal.LevelNone}, nil); err != nil { + t.Fatal(err) + } + return ctx + }(), expectEnabled: false, }, { - name: "level Metadata", - ctx: &AuditContext{RequestAuditConfig: RequestAuditConfig{Level: auditinternal.LevelMetadata}}, + name: "level Metadata", + ctx: func() *AuditContext { + ctx := &AuditContext{} + if err := ctx.Init(RequestAuditConfig{Level: auditinternal.LevelMetadata}, nil); err != nil { + t.Fatal(err) + } + return ctx + }(), expectEnabled: true, }, { - name: "level RequestResponse", - ctx: &AuditContext{RequestAuditConfig: RequestAuditConfig{Level: auditinternal.LevelRequestResponse}}, + name: "level RequestResponse", + ctx: func() *AuditContext { + ctx := &AuditContext{} + if err := ctx.Init(RequestAuditConfig{Level: auditinternal.LevelRequestResponse}, nil); err != nil { + t.Fatal(err) + } + return ctx + }(), expectEnabled: true, }} @@ -72,7 +90,7 @@ func TestAddAuditAnnotation(t *testing.T) { assert.Len(t, annotations, numAnnotations) } - ctxWithAnnotation := withAuditContextAndLevel(context.Background(), auditinternal.LevelMetadata) + ctxWithAnnotation := withAuditContextAndLevel(context.Background(), t, auditinternal.LevelMetadata) AddAuditAnnotation(ctxWithAnnotation, fmt.Sprintf(annotationKeyTemplate, 0), annotationExtraValue) tests := []struct { @@ -89,28 +107,28 @@ func TestAddAuditAnnotation(t *testing.T) { // Annotations should be retained. ctx: WithAuditContext(context.Background()), validator: func(t *testing.T, ctx context.Context) { - ev := AuditContextFrom(ctx).Event + ev := AuditContextFrom(ctx).event expectAnnotations(t, ev.Annotations) }, }, { description: "with metadata level", - ctx: withAuditContextAndLevel(context.Background(), auditinternal.LevelMetadata), + ctx: withAuditContextAndLevel(context.Background(), t, auditinternal.LevelMetadata), validator: func(t *testing.T, ctx context.Context) { - ev := AuditContextFrom(ctx).Event + ev := AuditContextFrom(ctx).event expectAnnotations(t, ev.Annotations) }, }, { description: "with none level", - ctx: withAuditContextAndLevel(context.Background(), auditinternal.LevelNone), + ctx: withAuditContextAndLevel(context.Background(), t, auditinternal.LevelNone), validator: func(t *testing.T, ctx context.Context) { - ev := AuditContextFrom(ctx).Event + ev := AuditContextFrom(ctx).event assert.Empty(t, ev.Annotations) }, }, { description: "with overlapping annotations", ctx: ctxWithAnnotation, validator: func(t *testing.T, ctx context.Context) { - ev := AuditContextFrom(ctx).Event + ev := AuditContextFrom(ctx).event expectAnnotations(t, ev.Annotations) // Verify that the pre-existing annotation is not overwritten. assert.Equal(t, annotationExtraValue, ev.Annotations[fmt.Sprintf(annotationKeyTemplate, 0)]) @@ -144,8 +162,8 @@ func TestAuditAnnotationsWithAuditLoggingSetup(t *testing.T) { AddAuditAnnotation(ctx, "before-evaluation", "1") // policy evaluated, audit logging enabled - if ac := AuditContextFrom(ctx); ac != nil { - ac.RequestAuditConfig.Level = auditinternal.LevelMetadata + if err := AuditContextFrom(ctx).Init(RequestAuditConfig{Level: auditinternal.LevelMetadata}, nil); err != nil { + t.Fatal(err) } AddAuditAnnotation(ctx, "after-evaluation", "2") @@ -153,13 +171,14 @@ func TestAuditAnnotationsWithAuditLoggingSetup(t *testing.T) { "before-evaluation": "1", "after-evaluation": "2", } - actual := AuditContextFrom(ctx).Event.Annotations + actual := AuditContextFrom(ctx).event.Annotations assert.Equal(t, expected, actual) } -func withAuditContextAndLevel(ctx context.Context, l auditinternal.Level) context.Context { +func withAuditContextAndLevel(ctx context.Context, t *testing.T, l auditinternal.Level) context.Context { ctx = WithAuditContext(ctx) - ac := AuditContextFrom(ctx) - ac.RequestAuditConfig.Level = l + if err := AuditContextFrom(ctx).Init(RequestAuditConfig{Level: l}, nil); err != nil { + t.Fatal(err) + } return ctx } diff --git a/staging/src/k8s.io/apiserver/pkg/audit/request.go b/staging/src/k8s.io/apiserver/pkg/audit/request.go index 9185278f06fbd..d5f9c730f518e 100644 --- a/staging/src/k8s.io/apiserver/pkg/audit/request.go +++ b/staging/src/k8s.io/apiserver/pkg/audit/request.go @@ -40,110 +40,73 @@ const ( userAgentTruncateSuffix = "...TRUNCATED" ) -func LogRequestMetadata(ctx context.Context, req *http.Request, requestReceivedTimestamp time.Time, level auditinternal.Level, attribs authorizer.Attributes) { +func LogRequestMetadata(ctx context.Context, req *http.Request, requestReceivedTimestamp time.Time, attribs authorizer.Attributes) { ac := AuditContextFrom(ctx) if !ac.Enabled() { return } - ev := &ac.Event - - ev.RequestReceivedTimestamp = metav1.NewMicroTime(requestReceivedTimestamp) - ev.Verb = attribs.GetVerb() - ev.RequestURI = req.URL.RequestURI() - ev.UserAgent = maybeTruncateUserAgent(req) - ev.Level = level - - ips := utilnet.SourceIPs(req) - ev.SourceIPs = make([]string, len(ips)) - for i := range ips { - ev.SourceIPs[i] = ips[i].String() - } - if user := attribs.GetUser(); user != nil { - ev.User.Username = user.GetName() - ev.User.Extra = map[string]authnv1.ExtraValue{} - for k, v := range user.GetExtra() { - ev.User.Extra[k] = authnv1.ExtraValue(v) + ac.visitEvent(func(ev *auditinternal.Event) { + ev.RequestReceivedTimestamp = metav1.NewMicroTime(requestReceivedTimestamp) + ev.Verb = attribs.GetVerb() + ev.RequestURI = req.URL.RequestURI() + ev.UserAgent = maybeTruncateUserAgent(req) + + ips := utilnet.SourceIPs(req) + ev.SourceIPs = make([]string, len(ips)) + for i := range ips { + ev.SourceIPs[i] = ips[i].String() } - ev.User.Groups = user.GetGroups() - ev.User.UID = user.GetUID() - } - if attribs.IsResourceRequest() { - ev.ObjectRef = &auditinternal.ObjectReference{ - Namespace: attribs.GetNamespace(), - Name: attribs.GetName(), - Resource: attribs.GetResource(), - Subresource: attribs.GetSubresource(), - APIGroup: attribs.GetAPIGroup(), - APIVersion: attribs.GetAPIVersion(), + if user := attribs.GetUser(); user != nil { + ev.User.Username = user.GetName() + ev.User.Extra = map[string]authnv1.ExtraValue{} + for k, v := range user.GetExtra() { + ev.User.Extra[k] = authnv1.ExtraValue(v) + } + ev.User.Groups = user.GetGroups() + ev.User.UID = user.GetUID() } - } + + if attribs.IsResourceRequest() { + ev.ObjectRef = &auditinternal.ObjectReference{ + Namespace: attribs.GetNamespace(), + Name: attribs.GetName(), + Resource: attribs.GetResource(), + Subresource: attribs.GetSubresource(), + APIGroup: attribs.GetAPIGroup(), + APIVersion: attribs.GetAPIVersion(), + } + } + }) } // LogImpersonatedUser fills in the impersonated user attributes into an audit event. -func LogImpersonatedUser(ae *auditinternal.Event, user user.Info) { - if ae == nil || ae.Level.Less(auditinternal.LevelMetadata) { +func LogImpersonatedUser(ctx context.Context, user user.Info) { + ac := AuditContextFrom(ctx) + if !ac.Enabled() { return } - ae.ImpersonatedUser = &authnv1.UserInfo{ - Username: user.GetName(), - } - ae.ImpersonatedUser.Groups = user.GetGroups() - ae.ImpersonatedUser.UID = user.GetUID() - ae.ImpersonatedUser.Extra = map[string]authnv1.ExtraValue{} - for k, v := range user.GetExtra() { - ae.ImpersonatedUser.Extra[k] = authnv1.ExtraValue(v) - } + ac.LogImpersonatedUser(user) } // LogRequestObject fills in the request object into an audit event. The passed runtime.Object // will be converted to the given gv. func LogRequestObject(ctx context.Context, obj runtime.Object, objGV schema.GroupVersion, gvr schema.GroupVersionResource, subresource string, s runtime.NegotiatedSerializer) { - ae := AuditEventFrom(ctx) - if ae == nil || ae.Level.Less(auditinternal.LevelMetadata) { + ac := AuditContextFrom(ctx) + if !ac.Enabled() { return } - - // complete ObjectRef - if ae.ObjectRef == nil { - ae.ObjectRef = &auditinternal.ObjectReference{} - } - - // meta.Accessor is more general than ObjectMetaAccessor, but if it fails, we can just skip setting these bits - if meta, err := meta.Accessor(obj); err == nil { - if len(ae.ObjectRef.Namespace) == 0 { - ae.ObjectRef.Namespace = meta.GetNamespace() - } - if len(ae.ObjectRef.Name) == 0 { - ae.ObjectRef.Name = meta.GetName() - } - if len(ae.ObjectRef.UID) == 0 { - ae.ObjectRef.UID = meta.GetUID() - } - if len(ae.ObjectRef.ResourceVersion) == 0 { - ae.ObjectRef.ResourceVersion = meta.GetResourceVersion() - } - } - if len(ae.ObjectRef.APIVersion) == 0 { - ae.ObjectRef.APIGroup = gvr.Group - ae.ObjectRef.APIVersion = gvr.Version - } - if len(ae.ObjectRef.Resource) == 0 { - ae.ObjectRef.Resource = gvr.Resource - } - if len(ae.ObjectRef.Subresource) == 0 { - ae.ObjectRef.Subresource = subresource - } - - if ae.Level.Less(auditinternal.LevelRequest) { + if ac.GetEventLevel().Less(auditinternal.LevelMetadata) { return } - if shouldOmitManagedFields(ctx) { + // meta.Accessor is more general than ObjectMetaAccessor, but if it fails, we can just skip setting these bits + objMeta, _ := meta.Accessor(obj) + if shouldOmitManagedFields(ac) { copy, ok, err := copyWithoutManagedFields(obj) if err != nil { - klog.ErrorS(err, "Error while dropping managed fields from the request", "auditID", ae.AuditID) + klog.ErrorS(err, "Error while dropping managed fields from the request", "auditID", ac.AuditID()) } if ok { obj = copy @@ -151,54 +114,75 @@ func LogRequestObject(ctx context.Context, obj runtime.Object, objGV schema.Grou } // TODO(audit): hook into the serializer to avoid double conversion - var err error - ae.RequestObject, err = encodeObject(obj, objGV, s) + requestObject, err := encodeObject(obj, objGV, s) if err != nil { // TODO(audit): add error slice to audit event struct - klog.ErrorS(err, "Encoding failed of request object", "auditID", ae.AuditID, "gvr", gvr.String(), "obj", obj) + klog.ErrorS(err, "Encoding failed of request object", "auditID", ac.AuditID(), "gvr", gvr.String(), "obj", obj) return } + + ac.visitEvent(func(ae *auditinternal.Event) { + if ae.ObjectRef == nil { + ae.ObjectRef = &auditinternal.ObjectReference{} + } + + if objMeta != nil { + if len(ae.ObjectRef.Namespace) == 0 { + ae.ObjectRef.Namespace = objMeta.GetNamespace() + } + if len(ae.ObjectRef.Name) == 0 { + ae.ObjectRef.Name = objMeta.GetName() + } + if len(ae.ObjectRef.UID) == 0 { + ae.ObjectRef.UID = objMeta.GetUID() + } + if len(ae.ObjectRef.ResourceVersion) == 0 { + ae.ObjectRef.ResourceVersion = objMeta.GetResourceVersion() + } + } + if len(ae.ObjectRef.APIVersion) == 0 { + ae.ObjectRef.APIGroup = gvr.Group + ae.ObjectRef.APIVersion = gvr.Version + } + if len(ae.ObjectRef.Resource) == 0 { + ae.ObjectRef.Resource = gvr.Resource + } + if len(ae.ObjectRef.Subresource) == 0 { + ae.ObjectRef.Subresource = subresource + } + + if ae.Level.Less(auditinternal.LevelRequest) { + return + } + ae.RequestObject = requestObject + }) } // LogRequestPatch fills in the given patch as the request object into an audit event. func LogRequestPatch(ctx context.Context, patch []byte) { - ae := AuditEventFrom(ctx) - if ae == nil || ae.Level.Less(auditinternal.LevelRequest) { + ac := AuditContextFrom(ctx) + if ac.GetEventLevel().Less(auditinternal.LevelRequest) { return } - - ae.RequestObject = &runtime.Unknown{ - Raw: patch, - ContentType: runtime.ContentTypeJSON, - } + ac.LogRequestPatch(patch) } // LogResponseObject fills in the response object into an audit event. The passed runtime.Object // will be converted to the given gv. func LogResponseObject(ctx context.Context, obj runtime.Object, gv schema.GroupVersion, s runtime.NegotiatedSerializer) { - ae := AuditEventFrom(ctx) - if ae == nil || ae.Level.Less(auditinternal.LevelMetadata) { + ac := AuditContextFrom(WithAuditContext(ctx)) + status, _ := obj.(*metav1.Status) + if ac.GetEventLevel().Less(auditinternal.LevelMetadata) { return - } - if status, ok := obj.(*metav1.Status); ok { - // selectively copy the bounded fields. - ae.ResponseStatus = &metav1.Status{ - Status: status.Status, - Message: status.Message, - Reason: status.Reason, - Details: status.Details, - Code: status.Code, - } - } - - if ae.Level.Less(auditinternal.LevelRequestResponse) { + } else if ac.GetEventLevel().Less(auditinternal.LevelRequestResponse) { + ac.LogResponseObject(status, nil) return } - if shouldOmitManagedFields(ctx) { + if shouldOmitManagedFields(ac) { copy, ok, err := copyWithoutManagedFields(obj) if err != nil { - klog.ErrorS(err, "Error while dropping managed fields from the response", "auditID", ae.AuditID) + klog.ErrorS(err, "Error while dropping managed fields from the response", "auditID", ac.AuditID()) } if ok { obj = copy @@ -207,10 +191,11 @@ func LogResponseObject(ctx context.Context, obj runtime.Object, gv schema.GroupV // TODO(audit): hook into the serializer to avoid double conversion var err error - ae.ResponseObject, err = encodeObject(obj, gv, s) + responseObject, err := encodeObject(obj, gv, s) if err != nil { - klog.ErrorS(err, "Encoding failed of response object", "auditID", ae.AuditID, "obj", obj) + klog.ErrorS(err, "Encoding failed of response object", "auditID", ac.AuditID(), "obj", obj) } + ac.LogResponseObject(status, responseObject) } func encodeObject(obj runtime.Object, gv schema.GroupVersion, serializer runtime.NegotiatedSerializer) (*runtime.Unknown, error) { @@ -301,9 +286,9 @@ func removeManagedFields(obj runtime.Object) error { return nil } -func shouldOmitManagedFields(ctx context.Context) bool { - if auditContext := AuditContextFrom(ctx); auditContext != nil { - return auditContext.RequestAuditConfig.OmitManagedFields +func shouldOmitManagedFields(ac *AuditContext) bool { + if ac != nil && ac.initialized.Load() && ac.requestAuditConfig.OmitManagedFields { + return true } // If we can't decide, return false to maintain current behavior which is diff --git a/staging/src/k8s.io/apiserver/pkg/audit/request_log_test.go b/staging/src/k8s.io/apiserver/pkg/audit/request_log_test.go new file mode 100644 index 0000000000000..6ef3316d1df12 --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/audit/request_log_test.go @@ -0,0 +1,362 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package audit + +import ( + "context" + "io" + "strings" + "testing" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/runtime/serializer" + auditinternal "k8s.io/apiserver/pkg/apis/audit" +) + +func TestLogResponseObjectWithPod(t *testing.T) { + testPod := &corev1.Pod{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "v1", + Kind: "Pod", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod", + Namespace: "test-namespace", + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "test-image", + }, + }, + }, + } + + scheme := runtime.NewScheme() + if err := corev1.AddToScheme(scheme); err != nil { + t.Fatalf("Failed to add core/v1 to scheme: %v", err) + } + codecs := serializer.NewCodecFactory(scheme) + negotiatedSerializer := codecs.WithoutConversion() + + // Create audit context with RequestResponse level + ctx := WithAuditContext(context.Background()) + ac := AuditContextFrom(ctx) + + captureSink := &capturingAuditSink{} + if err := ac.Init(RequestAuditConfig{Level: auditinternal.LevelRequestResponse}, captureSink); err != nil { + t.Fatalf("Failed to initialize audit context: %v", err) + } + + LogResponseObject(ctx, testPod, schema.GroupVersion{Group: "", Version: "v1"}, negotiatedSerializer) + ac.ProcessEventStage(ctx, auditinternal.StageResponseComplete) + + if len(captureSink.events) != 1 { + t.Fatalf("Expected one audit event to be captured, got %d", len(captureSink.events)) + } + event := captureSink.events[0] + if event.ResponseObject == nil { + t.Fatal("Expected ResponseObject to be set, but it was nil") + } + if event.ResponseObject.ContentType != runtime.ContentTypeJSON { + t.Errorf("Expected ContentType to be %q, got %q", runtime.ContentTypeJSON, event.ResponseObject.ContentType) + } + if len(event.ResponseObject.Raw) == 0 { + t.Error("Expected ResponseObject.Raw to contain data, but it was empty") + } + + responseJSON := string(event.ResponseObject.Raw) + expectedFields := []string{"test-pod", "test-namespace", "test-container", "test-image"} + for _, field := range expectedFields { + if !strings.Contains(responseJSON, field) { + t.Errorf("Response should contain %q but didn't. Response: %s", field, responseJSON) + } + } + + if event.ResponseStatus != nil { + t.Errorf("Expected ResponseStatus to be nil for regular object, got: %+v", event.ResponseStatus) + } +} + +func TestLogResponseObjectWithStatus(t *testing.T) { + testCases := []struct { + name string + level auditinternal.Level + status *metav1.Status + shouldEncode bool + expectResponseObj bool + expectStatusFields bool + }{ + { + name: "RequestResponse level should encode and log status fields", + level: auditinternal.LevelRequestResponse, + status: &metav1.Status{Status: "Success", Message: "Test message", Code: 200}, + shouldEncode: true, + expectResponseObj: true, + expectStatusFields: true, + }, + { + name: "Metadata level should log status fields without encoding", + level: auditinternal.LevelMetadata, + status: &metav1.Status{Status: "Success", Message: "Test message", Code: 200}, + shouldEncode: false, + expectResponseObj: false, + expectStatusFields: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + scheme := runtime.NewScheme() + if err := metav1.AddMetaToScheme(scheme); err != nil { + t.Fatalf("Failed to add meta to scheme: %v", err) + } + scheme.AddKnownTypes(schema.GroupVersion{Version: "v1"}, &metav1.Status{}) + codecs := serializer.NewCodecFactory(scheme) + negotiatedSerializer := codecs.WithoutConversion() + + ctx := WithAuditContext(context.Background()) + ac := AuditContextFrom(ctx) + + captureSink := &capturingAuditSink{} + if err := ac.Init(RequestAuditConfig{Level: tc.level}, captureSink); err != nil { + t.Fatalf("Failed to initialize audit context: %v", err) + } + + LogResponseObject(ctx, tc.status, schema.GroupVersion{Group: "", Version: "v1"}, negotiatedSerializer) + ac.ProcessEventStage(ctx, auditinternal.StageResponseComplete) + + if len(captureSink.events) != 1 { + t.Fatalf("Expected one audit event to be captured, got %d", len(captureSink.events)) + } + event := captureSink.events[0] + + if tc.expectResponseObj { + if event.ResponseObject == nil { + t.Error("Expected ResponseObject to be set, but it was nil") + } + } else { + if event.ResponseObject != nil { + t.Error("Expected ResponseObject to be nil") + } + } + + if tc.expectStatusFields { + if event.ResponseStatus == nil { + t.Fatal("Expected ResponseStatus to be set, but it was nil") + } + if event.ResponseStatus.Status != tc.status.Status { + t.Errorf("Expected ResponseStatus.Status to be %q, got %q", tc.status.Status, event.ResponseStatus.Status) + } + if event.ResponseStatus.Message != tc.status.Message { + t.Errorf("Expected ResponseStatus.Message to be %q, got %q", tc.status.Message, event.ResponseStatus.Message) + } + if event.ResponseStatus.Code != tc.status.Code { + t.Errorf("Expected ResponseStatus.Code to be %d, got %d", tc.status.Code, event.ResponseStatus.Code) + } + } else if event.ResponseStatus != nil { + t.Error("Expected ResponseStatus to be nil") + } + }) + } +} + +func TestLogResponseObjectLevelCheck(t *testing.T) { + testCases := []struct { + name string + level auditinternal.Level + obj runtime.Object + shouldEncode bool + expectResponseObj bool + expectStatusFields bool + }{ + { + name: "None level should not encode or log anything", + level: auditinternal.LevelNone, + obj: &corev1.Pod{}, + shouldEncode: false, + expectResponseObj: false, + expectStatusFields: false, + }, + { + name: "Metadata level should not encode or log anything", + level: auditinternal.LevelMetadata, + obj: &corev1.Pod{}, + shouldEncode: false, + expectResponseObj: false, + expectStatusFields: false, + }, + { + name: "Metadata level with Status should log status fields without encoding", + level: auditinternal.LevelMetadata, + obj: &metav1.Status{ + Status: "Success", + Message: "Test message", + Code: 200, + }, + shouldEncode: false, + expectResponseObj: false, + expectStatusFields: true, + }, + { + name: "Request level with Pod should not encode or log", + level: auditinternal.LevelRequest, + obj: &corev1.Pod{}, + shouldEncode: false, + expectResponseObj: false, + expectStatusFields: false, + }, + { + name: "Request level with Status should log status fields without encoding", + level: auditinternal.LevelRequest, + obj: &metav1.Status{ + Status: "Success", + Message: "Test message", + Code: 200, + }, + shouldEncode: false, + expectResponseObj: false, + expectStatusFields: true, + }, + { + name: "RequestResponse level with Pod should encode", + level: auditinternal.LevelRequestResponse, + obj: &corev1.Pod{}, + shouldEncode: true, + expectResponseObj: true, + expectStatusFields: false, + }, + { + name: "RequestResponse level with Status should encode and log status fields", + level: auditinternal.LevelRequestResponse, + obj: &metav1.Status{ + Status: "Success", + Message: "Test message", + Code: 200, + }, + shouldEncode: true, + expectResponseObj: true, + expectStatusFields: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := WithAuditContext(context.Background()) + ac := AuditContextFrom(ctx) + + captureSink := &capturingAuditSink{} + if err := ac.Init(RequestAuditConfig{Level: tc.level}, captureSink); err != nil { + t.Fatalf("Failed to initialize audit context: %v", err) + } + + mockSerializer := &mockNegotiatedSerializer{} + LogResponseObject(ctx, tc.obj, schema.GroupVersion{Group: "", Version: "v1"}, mockSerializer) + ac.ProcessEventStage(ctx, auditinternal.StageResponseComplete) + + if mockSerializer.encodeCalled != tc.shouldEncode { + t.Errorf("Expected encoding to be called: %v, but got: %v", tc.shouldEncode, mockSerializer.encodeCalled) + } + + if len(captureSink.events) != 1 { + t.Fatalf("Expected one audit event to be captured, got %d", len(captureSink.events)) + } + event := captureSink.events[0] + + if tc.expectResponseObj { + if event.ResponseObject == nil { + t.Error("Expected ResponseObject to be set, but it was nil") + } + } else { + if event.ResponseObject != nil { + t.Error("Expected ResponseObject to be nil") + } + } + + // Check ResponseStatus for Status objects + status, isStatus := tc.obj.(*metav1.Status) + if isStatus && tc.expectStatusFields { + if event.ResponseStatus == nil { + t.Error("Expected ResponseStatus to be set for Status object, but it was nil") + } else { + if event.ResponseStatus.Status != status.Status { + t.Errorf("Expected ResponseStatus.Status to be %q, got %q", status.Status, event.ResponseStatus.Status) + } + if event.ResponseStatus.Message != status.Message { + t.Errorf("Expected ResponseStatus.Message to be %q, got %q", status.Message, event.ResponseStatus.Message) + } + if event.ResponseStatus.Code != status.Code { + t.Errorf("Expected ResponseStatus.Code to be %d, got %d", status.Code, event.ResponseStatus.Code) + } + } + } else if event.ResponseStatus != nil { + t.Error("Expected ResponseStatus to be nil") + } + }) + } +} + +type mockNegotiatedSerializer struct { + encodeCalled bool +} + +func (m *mockNegotiatedSerializer) SupportedMediaTypes() []runtime.SerializerInfo { + return []runtime.SerializerInfo{ + { + MediaType: runtime.ContentTypeJSON, + EncodesAsText: true, + Serializer: nil, + PrettySerializer: nil, + StreamSerializer: nil, + }, + } +} + +func (m *mockNegotiatedSerializer) EncoderForVersion(serializer runtime.Encoder, gv runtime.GroupVersioner) runtime.Encoder { + m.encodeCalled = true + return &mockEncoder{} +} + +func (m *mockNegotiatedSerializer) DecoderToVersion(serializer runtime.Decoder, gv runtime.GroupVersioner) runtime.Decoder { + return nil +} + +type mockEncoder struct{} + +func (e *mockEncoder) Encode(obj runtime.Object, w io.Writer) error { + return nil +} + +func (e *mockEncoder) Identifier() runtime.Identifier { + return runtime.Identifier("mock") +} + +type capturingAuditSink struct { + events []*auditinternal.Event +} + +func (s *capturingAuditSink) ProcessEvents(events ...*auditinternal.Event) bool { + for _, event := range events { + eventCopy := event.DeepCopy() + s.events = append(s.events, eventCopy) + } + return true +} diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go index 18167dddc2bfc..9d1556e633659 100644 --- a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go +++ b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go @@ -33,7 +33,6 @@ import ( "golang.org/x/sync/singleflight" apierrors "k8s.io/apimachinery/pkg/api/errors" - auditinternal "k8s.io/apiserver/pkg/apis/audit" "k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/warning" @@ -199,12 +198,9 @@ func (a *cachedTokenAuthenticator) doAuthenticateToken(ctx context.Context, toke ctx = audit.WithAuditContext(ctx) ac := audit.AuditContextFrom(ctx) - // since this is shared work between multiple requests, we have no way of knowing if any - // particular request supports audit annotations. thus we always attempt to record them. - ac.Event.Level = auditinternal.LevelMetadata record.resp, record.ok, record.err = a.authenticator.AuthenticateToken(ctx, token) - record.annotations = ac.Event.Annotations + record.annotations = ac.GetEventAnnotations() record.warnings = recorder.extractWarnings() if !a.cacheErrs && record.err != nil { diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go index 14f9a26eb419d..f1950feca2646 100644 --- a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go +++ b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go @@ -35,7 +35,6 @@ import ( utilrand "k8s.io/apimachinery/pkg/util/rand" "k8s.io/apimachinery/pkg/util/uuid" - auditinternal "k8s.io/apiserver/pkg/apis/audit" "k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/user" @@ -306,7 +305,7 @@ func TestCachedAuditAnnotations(t *testing.T) { ctx := withAudit(context.Background()) _, _, _ = a.AuthenticateToken(ctx, "token") - allAnnotations <- audit.AuditEventFrom(ctx).Annotations + allAnnotations <- audit.AuditContextFrom(ctx).GetEventAnnotations() }() } @@ -343,7 +342,7 @@ func TestCachedAuditAnnotations(t *testing.T) { for i := 0; i < cap(allAnnotations); i++ { ctx := withAudit(context.Background()) _, _, _ = a.AuthenticateToken(ctx, "token") - allAnnotations = append(allAnnotations, audit.AuditEventFrom(ctx).Annotations) + allAnnotations = append(allAnnotations, audit.AuditContextFrom(ctx).GetEventAnnotations()) } if len(allAnnotations) != cap(allAnnotations) { @@ -370,14 +369,14 @@ func TestCachedAuditAnnotations(t *testing.T) { ctx1 := withAudit(context.Background()) _, _, _ = a.AuthenticateToken(ctx1, "token1") - annotations1 := audit.AuditEventFrom(ctx1).Annotations + annotations1 := audit.AuditContextFrom(ctx1).GetEventAnnotations() // guarantee different now times time.Sleep(time.Second) ctx2 := withAudit(context.Background()) _, _, _ = a.AuthenticateToken(ctx2, "token2") - annotations2 := audit.AuditEventFrom(ctx2).Annotations + annotations2 := audit.AuditContextFrom(ctx2).GetEventAnnotations() if ok := len(annotations1) == 1 && len(annotations1["timestamp"]) > 0; !ok { t.Errorf("invalid annotations 1: %v", annotations1) @@ -546,8 +545,6 @@ func (s *singleBenchmark) bench(b *testing.B) { // extraction. func withAudit(ctx context.Context) context.Context { ctx = audit.WithAuditContext(ctx) - ac := audit.AuditContextFrom(ctx) - ac.Event.Level = auditinternal.LevelMetadata return ctx } diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/audit.go b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/audit.go index 6f850f728bfdb..d25bf35ae3af0 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/audit.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/audit.go @@ -44,7 +44,7 @@ func WithAudit(handler http.Handler, sink audit.Sink, policy audit.PolicyRuleEva return handler } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - ac, err := evaluatePolicyAndCreateAuditEvent(req, policy) + ac, err := evaluatePolicyAndCreateAuditEvent(req, policy, sink) if err != nil { utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err)) responsewriters.InternalError(w, req, errors.New("failed to create audit event")) @@ -55,41 +55,37 @@ func WithAudit(handler http.Handler, sink audit.Sink, policy audit.PolicyRuleEva handler.ServeHTTP(w, req) return } - ev := &ac.Event ctx := req.Context() - omitStages := ac.RequestAuditConfig.OmitStages - ev.Stage = auditinternal.StageRequestReceived - if processed := processAuditEvent(ctx, sink, ev, omitStages); !processed { + if processed := ac.ProcessEventStage(ctx, auditinternal.StageRequestReceived); !processed { audit.ApiserverAuditDroppedCounter.WithContext(ctx).Inc() responsewriters.InternalError(w, req, errors.New("failed to store audit event")) return } // intercept the status code - var longRunningSink audit.Sink + isLongRunning := false if longRunningCheck != nil { ri, _ := request.RequestInfoFrom(ctx) if longRunningCheck(req, ri) { - longRunningSink = sink + isLongRunning = true } } - respWriter := decorateResponseWriter(ctx, w, ev, longRunningSink, omitStages) + respWriter := decorateResponseWriter(ctx, w, isLongRunning) // send audit event when we leave this func, either via a panic or cleanly. In the case of long // running requests, this will be the second audit event. defer func() { if r := recover(); r != nil { defer panic(r) - ev.Stage = auditinternal.StagePanic - ev.ResponseStatus = &metav1.Status{ + ac.SetEventResponseStatus(&metav1.Status{ Code: http.StatusInternalServerError, Status: metav1.StatusFailure, Reason: metav1.StatusReasonInternalError, Message: fmt.Sprintf("APIServer panic'd: %v", r), - } - processAuditEvent(ctx, sink, ev, omitStages) + }) + ac.ProcessEventStage(ctx, auditinternal.StagePanic) return } @@ -100,27 +96,25 @@ func WithAudit(handler http.Handler, sink audit.Sink, policy audit.PolicyRuleEva Status: metav1.StatusSuccess, Message: "Connection closed early", } - if ev.ResponseStatus == nil && longRunningSink != nil { - ev.ResponseStatus = fakedSuccessStatus - ev.Stage = auditinternal.StageResponseStarted - processAuditEvent(ctx, longRunningSink, ev, omitStages) - } - - ev.Stage = auditinternal.StageResponseComplete - if ev.ResponseStatus == nil { - ev.ResponseStatus = fakedSuccessStatus + if ac.GetEventResponseStatus() == nil { + ac.SetEventResponseStatus(fakedSuccessStatus) + if isLongRunning { + // A nil ResponseStatus means the writer never processed the ResponseStarted stage, so do that now. + ac.ProcessEventStage(ctx, auditinternal.StageResponseStarted) + } } - processAuditEvent(ctx, sink, ev, omitStages) + writeLatencyToAnnotation(ctx) + ac.ProcessEventStage(ctx, auditinternal.StageResponseComplete) }() handler.ServeHTTP(respWriter, req) }) } // evaluatePolicyAndCreateAuditEvent is responsible for evaluating the audit -// policy configuration applicable to the request and create a new audit -// event that will be written to the API audit log. +// policy configuration applicable to the request and initializing the audit +// context with the audit config for the request, the sink to write to, and the request metadata. // - error if anything bad happened -func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRuleEvaluator) (*audit.AuditContext, error) { +func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRuleEvaluator, sink audit.Sink) (*audit.AuditContext, error) { ctx := req.Context() ac := audit.AuditContextFrom(ctx) if ac == nil { @@ -135,7 +129,10 @@ func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRul rac := policy.EvaluatePolicyRule(attribs) audit.ObservePolicyLevel(ctx, rac.Level) - ac.RequestAuditConfig = rac + err = ac.Init(rac, sink) + if err != nil { + return nil, fmt.Errorf("failed to initialize audit context: %w", err) + } if rac.Level == auditinternal.LevelNone { // Don't audit. return ac, nil @@ -145,7 +142,7 @@ func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRul if !ok { requestReceivedTimestamp = time.Now() } - audit.LogRequestMetadata(ctx, req, requestReceivedTimestamp, rac.Level, attribs) + audit.LogRequestMetadata(ctx, req, requestReceivedTimestamp, attribs) return ac, nil } @@ -153,13 +150,14 @@ func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRul // writeLatencyToAnnotation writes the latency incurred in different // layers of the apiserver to the annotations of the audit object. // it should be invoked after ev.StageTimestamp has been set appropriately. -func writeLatencyToAnnotation(ctx context.Context, ev *auditinternal.Event) { +func writeLatencyToAnnotation(ctx context.Context) { + ac := audit.AuditContextFrom(ctx) // we will track latency in annotation only when the total latency // of the given request exceeds 500ms, this is in keeping with the // traces in rest/handlers for create, delete, update, // get, list, and deletecollection. const threshold = 500 * time.Millisecond - latency := ev.StageTimestamp.Time.Sub(ev.RequestReceivedTimestamp.Time) + latency := ac.GetEventStageTimestamp().Sub(ac.GetEventRequestReceivedTimestamp().Time) if latency <= threshold { return } @@ -177,34 +175,12 @@ func writeLatencyToAnnotation(ctx context.Context, ev *auditinternal.Event) { audit.AddAuditAnnotationsMap(ctx, layerLatencies) } -func processAuditEvent(ctx context.Context, sink audit.Sink, ev *auditinternal.Event, omitStages []auditinternal.Stage) bool { - for _, stage := range omitStages { - if ev.Stage == stage { - return true - } - } - - switch { - case ev.Stage == auditinternal.StageRequestReceived: - ev.StageTimestamp = metav1.NewMicroTime(ev.RequestReceivedTimestamp.Time) - case ev.Stage == auditinternal.StageResponseComplete: - ev.StageTimestamp = metav1.NewMicroTime(time.Now()) - writeLatencyToAnnotation(ctx, ev) - default: - ev.StageTimestamp = metav1.NewMicroTime(time.Now()) - } - - audit.ObserveEvent(ctx) - return sink.ProcessEvents(ev) -} - -func decorateResponseWriter(ctx context.Context, responseWriter http.ResponseWriter, ev *auditinternal.Event, sink audit.Sink, omitStages []auditinternal.Stage) http.ResponseWriter { +func decorateResponseWriter(ctx context.Context, responseWriter http.ResponseWriter, processResponseStartedStage bool) http.ResponseWriter { delegate := &auditResponseWriter{ ctx: ctx, ResponseWriter: responseWriter, - event: ev, - sink: sink, - omitStages: omitStages, + + processResponseStartedStage: processResponseStartedStage, } return responsewriter.WrapForHTTP1Or2(delegate) @@ -217,11 +193,10 @@ var _ responsewriter.UserProvidedDecorator = &auditResponseWriter{} // create immediately an event (for long running requests). type auditResponseWriter struct { http.ResponseWriter - ctx context.Context - event *auditinternal.Event - once sync.Once - sink audit.Sink - omitStages []auditinternal.Stage + ctx context.Context + once sync.Once + + processResponseStartedStage bool } func (a *auditResponseWriter) Unwrap() http.ResponseWriter { @@ -230,14 +205,10 @@ func (a *auditResponseWriter) Unwrap() http.ResponseWriter { func (a *auditResponseWriter) processCode(code int) { a.once.Do(func() { - if a.event.ResponseStatus == nil { - a.event.ResponseStatus = &metav1.Status{} - } - a.event.ResponseStatus.Code = int32(code) - a.event.Stage = auditinternal.StageResponseStarted - - if a.sink != nil { - processAuditEvent(a.ctx, a.sink, a.event, a.omitStages) + ac := audit.AuditContextFrom(a.ctx) + ac.SetEventResponseStatusCode(int32(code)) + if a.processResponseStartedStage { + ac.ProcessEventStage(a.ctx, auditinternal.StageResponseStarted) } }) } diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/audit_test.go b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/audit_test.go index e9c375a9b0566..cffed607b4be7 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/audit_test.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/audit_test.go @@ -18,24 +18,37 @@ package filters import ( "context" + "math/rand" "net/http" "net/http/httptest" + "net/url" "reflect" "sync" "testing" "time" + "unsafe" + "github.com/google/go-cmp/cmp" "github.com/google/uuid" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/serializer" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/wait" auditinternal "k8s.io/apiserver/pkg/apis/audit" + auditv1 "k8s.io/apiserver/pkg/apis/audit/v1" "k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/audit/policy" "k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/endpoints/request" "k8s.io/apiserver/pkg/endpoints/responsewriter" + "k8s.io/apiserver/plugin/pkg/audit/buffered" + "k8s.io/apiserver/plugin/pkg/audit/log" + "k8s.io/apiserver/plugin/pkg/audit/webhook" + "k8s.io/client-go/rest" + "k8s.io/client-go/util/flowcontrol" + "k8s.io/client-go/util/retry" ) type fakeAuditSink struct { @@ -76,7 +89,7 @@ func (s *fakeAuditSink) Pop(timeout time.Duration) (*auditinternal.Event, error) func TestConstructResponseWriter(t *testing.T) { inner := &responsewriter.FakeResponseWriter{} - actual := decorateResponseWriter(context.Background(), inner, nil, nil, nil) + actual := decorateResponseWriter(context.Background(), inner, false) switch v := actual.(type) { case *auditResponseWriter: default: @@ -86,7 +99,7 @@ func TestConstructResponseWriter(t *testing.T) { t.Errorf("Expected the decorator to return the inner http.ResponseWriter object") } - actual = decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriterFlusherCloseNotifier{}, nil, nil, nil) + actual = decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriterFlusherCloseNotifier{}, false) //lint:file-ignore SA1019 Keep supporting deprecated http.CloseNotifier if _, ok := actual.(http.CloseNotifier); !ok { t.Errorf("Expected http.ResponseWriter to implement http.CloseNotifier") @@ -98,7 +111,7 @@ func TestConstructResponseWriter(t *testing.T) { t.Errorf("Expected http.ResponseWriter not to implement http.Hijacker") } - actual = decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriterFlusherCloseNotifierHijacker{}, nil, nil, nil) + actual = decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriterFlusherCloseNotifierHijacker{}, false) //lint:file-ignore SA1019 Keep supporting deprecated http.CloseNotifier if _, ok := actual.(http.CloseNotifier); !ok { t.Errorf("Expected http.ResponseWriter to implement http.CloseNotifier") @@ -112,37 +125,43 @@ func TestConstructResponseWriter(t *testing.T) { } func TestDecorateResponseWriterWithoutChannel(t *testing.T) { - ev := &auditinternal.Event{} - actual := decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriter{}, ev, nil, nil) + ctx := audit.WithAuditContext(context.Background()) + ac := audit.AuditContextFrom(ctx) + actual := decorateResponseWriter(ctx, &responsewriter.FakeResponseWriter{}, false) // write status. This will not block because firstEventSentCh is nil actual.WriteHeader(42) - if ev.ResponseStatus == nil { + if ac.GetEventResponseStatus() == nil { t.Fatalf("Expected ResponseStatus to be non-nil") } - if ev.ResponseStatus.Code != 42 { - t.Errorf("expected status code 42, got %d", ev.ResponseStatus.Code) + if ac.GetEventResponseStatus().Code != 42 { + t.Errorf("expected status code 42, got %d", ac.GetEventResponseStatus().Code) } } func TestDecorateResponseWriterWithImplicitWrite(t *testing.T) { - ev := &auditinternal.Event{} - actual := decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriter{}, ev, nil, nil) + ctx := audit.WithAuditContext(context.Background()) + ac := audit.AuditContextFrom(ctx) + actual := decorateResponseWriter(ctx, &responsewriter.FakeResponseWriter{}, false) // write status. This will not block because firstEventSentCh is nil actual.Write([]byte("foo")) - if ev.ResponseStatus == nil { + if ac.GetEventResponseStatus() == nil { t.Fatalf("Expected ResponseStatus to be non-nil") } - if ev.ResponseStatus.Code != 200 { - t.Errorf("expected status code 200, got %d", ev.ResponseStatus.Code) + if ac.GetEventResponseStatus().Code != 200 { + t.Errorf("expected status code 200, got %d", ac.GetEventResponseStatus().Code) } } func TestDecorateResponseWriterChannel(t *testing.T) { + ctx := audit.WithAuditContext(context.Background()) sink := &fakeAuditSink{} - ev := &auditinternal.Event{} - actual := decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriter{}, ev, sink, nil) + auditContext := audit.AuditContextFrom(ctx) + if err := auditContext.Init(audit.RequestAuditConfig{}, sink); err != nil { + t.Fatal(err) + } + actual := decorateResponseWriter(ctx, &responsewriter.FakeResponseWriter{}, true) done := make(chan struct{}) go func() { @@ -164,8 +183,11 @@ func TestDecorateResponseWriterChannel(t *testing.T) { } t.Logf("Seen event with status %v", ev1.ResponseStatus) - if !reflect.DeepEqual(ev, ev1) { - t.Fatalf("ev1 and ev must be equal") + ev := getAuditContextEvent(auditContext) + if diff := cmp.Diff(ev, ev1, cmp.FilterPath(func(p cmp.Path) bool { + return p.String() == "StageTimestamp" + }, cmp.Ignore())); diff != "" { + t.Fatalf("ev1 and ev must be equal, diff: %s", diff) } <-done @@ -178,6 +200,20 @@ func TestDecorateResponseWriterChannel(t *testing.T) { } } +func getAuditContextEvent(ac *audit.AuditContext) *auditinternal.Event { + // Get the reflect.Value of the AuditContext + val := reflect.ValueOf(ac).Elem() + + // Access the unexported `event` field + eventField := val.FieldByName("event") + + // Use unsafe to get a pointer to the field + eventPtr := unsafe.Pointer(eventField.UnsafeAddr()) + + // Cast the pointer to the correct type + return (*auditinternal.Event)(eventPtr) +} + type fakeHTTPHandler struct{} func (*fakeHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -848,11 +884,130 @@ func withTestContext(req *http.Request, user user.Info, ae *auditinternal.Event) ctx = request.WithUser(ctx, user) } if ae != nil { - ac := audit.AuditContextFrom(ctx) - ac.Event = *ae + ev := getAuditContextEvent(audit.AuditContextFrom(ctx)) + *ev = *ae } if info, err := newTestRequestInfoResolver().NewRequestInfo(req); err == nil { ctx = request.WithRequestInfo(ctx, info) } return req.WithContext(ctx) } + +type fakeAuditFile struct{} + +func (s fakeAuditFile) Write(p []byte) (n int, err error) { + time.Sleep(time.Duration(rand.Int63n(10000))) + return len(p), nil +} + +type fakeAuditWebhookAuditBackend struct { +} + +func (f fakeAuditWebhookAuditBackend) RoundTrip(r *http.Request) (*http.Response, error) { + time.Sleep(time.Duration(rand.Int63n(10000))) + return &http.Response{ + StatusCode: http.StatusOK, + }, nil +} + +// Test case for https://github.com/kubernetes/kubernetes/issues/120507 +// to test for race conditions in audit backends use the following command: +// `go test ./ -race --run=TestAuditBackendRaceCondition -v` +func TestAuditBackendRaceCondition(t *testing.T) { + defaultFakeLogBackend := log.NewBackend(fakeAuditFile{}, log.FormatJson, auditv1.SchemeGroupVersion) + testCases := []struct { + name string + backendBuilder func() audit.Backend + }{ + { + "log audit backend", + func() audit.Backend { + return defaultFakeLogBackend + }, + }, + { + "buffered audit backend", + func() audit.Backend { + backend := buffered.NewBackend(defaultFakeLogBackend, buffered.BatchConfig{ + BufferSize: 10000, + MaxBatchSize: 1, + ThrottleEnable: false, + AsyncDelegate: false, + }) + err := backend.Run(wait.NeverStop) + if err != nil { + t.Fatal(err) + } + return backend + }, + }, + { + name: "webhook audit backend", + backendBuilder: func() audit.Backend { + codecFactory := audit.Codecs + codec := codecFactory.LegacyCodec(auditv1.SchemeGroupVersion) + negotiatedSerializer := serializer.NegotiatedSerializerWrapper(runtime.SerializerInfo{Serializer: codec}) + client, err := rest.NewRESTClient(&url.URL{}, "/hello", rest.ClientContentConfig{ + ContentType: "application/json", + Negotiator: runtime.NewClientNegotiator(negotiatedSerializer, auditv1.SchemeGroupVersion), + }, flowcontrol.NewTokenBucketRateLimiter(100, 200), &http.Client{Transport: fakeAuditWebhookAuditBackend{}}) + if err != nil { + t.Fatal(err) + } + return webhook.NewDynamicBackend(client, retry.DefaultBackoff) + }, + }, + { + "union audit backend", + func() audit.Backend { + return audit.Union(defaultFakeLogBackend, defaultFakeLogBackend) + }, + }, + } + fakeRuleEvaluator := policy.NewFakePolicyRuleEvaluator(auditinternal.LevelRequestResponse, nil) + longRunningCheck := func(r *http.Request, ri *request.RequestInfo) bool { return false } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), wait.ForeverTestTimeout) + defer cancel() + for { + select { + case <-ctx.Done(): + // finished the test + return + default: + } + serveStarted := make(chan struct{}) + req, _ := http.NewRequest(http.MethodGet, "/api/v1/namespaces/default/pods/foo", nil) + req = withTestContext(req, &user.DefaultInfo{Name: "admin"}, nil) + backend := tc.backendBuilder() + go func() { + <-serveStarted + for { + select { + case <-ctx.Done(): + // finished the test + backend.Shutdown() + return + default: + } + audit.AddAuditAnnotations(req.Context(), "a", "b") + } + }() + realHandler := http.HandlerFunc(func(writer http.ResponseWriter, r *http.Request) { + close(serveStarted) + // mock some business logic + time.Sleep(time.Millisecond) + }) + handler := WithAudit(realHandler, backend, fakeRuleEvaluator, longRunningCheck) + handler = WithAuditInit(handler) + serveFinished := make(chan struct{}) + go func() { + defer close(serveFinished) + handler.ServeHTTP(httptest.NewRecorder(), req) + }() + <-serveFinished + } + }) + } +} diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/authn_audit.go b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/authn_audit.go index 4bd6bbc139668..d9cdcd2d62d19 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/authn_audit.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/authn_audit.go @@ -24,7 +24,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" utilruntime "k8s.io/apimachinery/pkg/util/runtime" - auditinternal "k8s.io/apiserver/pkg/apis/audit" "k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/endpoints/handlers/responsewriters" ) @@ -36,7 +35,7 @@ func WithFailedAuthenticationAudit(failedHandler http.Handler, sink audit.Sink, return failedHandler } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - ac, err := evaluatePolicyAndCreateAuditEvent(req, policy) + ac, err := evaluatePolicyAndCreateAuditEvent(req, policy, sink) if err != nil { utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err)) responsewriters.InternalError(w, req, errors.New("failed to create audit event")) @@ -47,13 +46,11 @@ func WithFailedAuthenticationAudit(failedHandler http.Handler, sink audit.Sink, failedHandler.ServeHTTP(w, req) return } - ev := &ac.Event - ev.ResponseStatus = &metav1.Status{} - ev.ResponseStatus.Message = getAuthMethods(req) - ev.Stage = auditinternal.StageResponseStarted - - rw := decorateResponseWriter(req.Context(), w, ev, sink, ac.RequestAuditConfig.OmitStages) + ac.SetEventResponseStatus(&metav1.Status{ + Message: getAuthMethods(req), + }) + rw := decorateResponseWriter(req.Context(), w, true) failedHandler.ServeHTTP(rw, req) }) } diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/authorization_test.go b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/authorization_test.go index deef9054b18af..b2bd49fd6ba39 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/authorization_test.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/authorization_test.go @@ -286,11 +286,21 @@ func TestAuditAnnotation(t *testing.T) { req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil) req = withTestContext(req, nil, &auditinternal.Event{Level: auditinternal.LevelMetadata}) - ae := audit.AuditEventFrom(req.Context()) + ae := audit.AuditContextFrom(req.Context()) req.RemoteAddr = "127.0.0.1" handler.ServeHTTP(httptest.NewRecorder(), req) - assert.Equal(t, tc.decisionAnnotation, ae.Annotations[decisionAnnotationKey], k+": unexpected decision annotation") - assert.Equal(t, tc.reasonAnnotation, ae.Annotations[reasonAnnotationKey], k+": unexpected reason annotation") + + var annotation string + var ok bool + if len(tc.decisionAnnotation) > 0 { + annotation, ok = ae.GetEventAnnotation(decisionAnnotationKey) + assert.True(t, ok, k+": decision annotation not found") + assert.Equal(t, tc.decisionAnnotation, annotation, k+": unexpected decision annotation") + } + + annotation, ok = ae.GetEventAnnotation(reasonAnnotationKey) + assert.True(t, ok, k+": reason annotation not found") + assert.Equal(t, tc.reasonAnnotation, annotation, k+": unexpected reason annotation") } } diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/impersonation.go b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/impersonation.go index a6d293a159081..aa47a7536d016 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/impersonation.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/impersonation.go @@ -166,8 +166,7 @@ func WithImpersonation(handler http.Handler, a authorizer.Authorizer, s runtime. oldUser, _ := request.UserFrom(ctx) httplog.LogOf(req, w).Addf("%v is impersonating %v", userString(oldUser), userString(newUser)) - ae := audit.AuditEventFrom(ctx) - audit.LogImpersonatedUser(ae, newUser) + audit.LogImpersonatedUser(audit.WithAuditContext(ctx), newUser) // clear all the impersonation headers from the request req.Header.Del(authenticationv1.ImpersonateUserHeader) diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/request_deadline.go b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/request_deadline.go index 7497bc38a4240..066d670a2ad1b 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/request_deadline.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/request_deadline.go @@ -108,7 +108,7 @@ func withFailedRequestAudit(failedHandler http.Handler, statusErr *apierrors.Sta return failedHandler } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - ac, err := evaluatePolicyAndCreateAuditEvent(req, policy) + ac, err := evaluatePolicyAndCreateAuditEvent(req, policy, sink) if err != nil { utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err)) responsewriters.InternalError(w, req, errors.New("failed to create audit event")) @@ -119,15 +119,15 @@ func withFailedRequestAudit(failedHandler http.Handler, statusErr *apierrors.Sta failedHandler.ServeHTTP(w, req) return } - ev := &ac.Event - ev.ResponseStatus = &metav1.Status{} - ev.Stage = auditinternal.StageResponseStarted + respStatus := &metav1.Status{} if statusErr != nil { - ev.ResponseStatus.Message = statusErr.Error() + respStatus.Message = statusErr.Error() } + ac.SetEventResponseStatus(respStatus) + ac.SetEventStage(auditinternal.StageResponseStarted) - rw := decorateResponseWriter(req.Context(), w, ev, sink, ac.RequestAuditConfig.OmitStages) + rw := decorateResponseWriter(req.Context(), w, true) failedHandler.ServeHTTP(rw, req) }) } diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/request_deadline_test.go b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/request_deadline_test.go index 6cc1b3c383439..6216429f8d5f1 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/filters/request_deadline_test.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/filters/request_deadline_test.go @@ -22,7 +22,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "reflect" "strings" "testing" "time" @@ -408,21 +407,21 @@ func TestWithFailedRequestAudit(t *testing.T) { t.Errorf("expected an http.ResponseWriter of type: %T but got: %T", &auditResponseWriter{}, rwGot) } - auditEventGot := audit.AuditEventFrom(requestGot.Context()) - if auditEventGot == nil { + auditContext := audit.AuditContextFrom(requestGot.Context()) + if auditContext == nil { t.Fatal("expected an audit event object but got nil") } - if auditEventGot.Stage != auditinternal.StageResponseStarted { - t.Errorf("expected audit event Stage: %s, but got: %s", auditinternal.StageResponseStarted, auditEventGot.Stage) + if auditContext.GetEventStage() != auditinternal.StageResponseStarted { + t.Errorf("expected audit event Stage: %s, but got: %s", auditinternal.StageResponseStarted, auditContext.GetEventStage()) } - if auditEventGot.ResponseStatus == nil { + if auditContext.GetEventResponseStatus() == nil { t.Fatal("expected a ResponseStatus field of the audit event object, but got nil") } - if test.statusCodeExpected != int(auditEventGot.ResponseStatus.Code) { - t.Errorf("expected audit event ResponseStatus.Code: %d, but got: %d", test.statusCodeExpected, auditEventGot.ResponseStatus.Code) + if test.statusCodeExpected != int(auditContext.GetEventResponseStatus().Code) { + t.Errorf("expected audit event ResponseStatus.Code: %d, but got: %d", test.statusCodeExpected, auditContext.GetEventResponseStatus().Code) } - if test.statusErr.Error() != auditEventGot.ResponseStatus.Message { - t.Errorf("expected audit event ResponseStatus.Message: %s, but got: %s", test.statusErr, auditEventGot.ResponseStatus.Message) + if test.statusErr.Error() != auditContext.GetEventResponseStatus().Message { + t.Errorf("expected audit event ResponseStatus.Message: %s, but got: %s", test.statusErr, auditContext.GetEventResponseStatus().Message) } // verify that the audit event from the request context is written to the audit sink. @@ -430,8 +429,12 @@ func TestWithFailedRequestAudit(t *testing.T) { t.Fatalf("expected audit sink to have 1 event, but got: %d", len(fakeSink.events)) } auditEventFromSink := fakeSink.events[0] - if !reflect.DeepEqual(auditEventGot, auditEventFromSink) { - t.Errorf("expected the audit event from the request context to be written to the audit sink, but got diffs: %s", cmp.Diff(auditEventGot, auditEventFromSink)) + eventFromAuditContext := getAuditContextEvent(auditContext) + + if diff := cmp.Diff(eventFromAuditContext, auditEventFromSink, cmp.FilterPath(func(p cmp.Path) bool { + return p.String() == "StageTimestamp" + }, cmp.Ignore())); diff != "" { + t.Errorf("expected the audit event from the request context to be written to the audit sink, but got diffs: %s", diff) } } }) diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/delete_test.go b/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/delete_test.go index 074330cdd8a48..d3cdbaaa5e208 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/delete_test.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/delete_test.go @@ -31,7 +31,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/serializer" - auditapis "k8s.io/apiserver/pkg/apis/audit" + auditinternal "k8s.io/apiserver/pkg/apis/audit" "k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/endpoints/handlers/negotiation" "k8s.io/apiserver/pkg/registry/rest" @@ -66,7 +66,9 @@ func TestDeleteResourceAuditLogRequestObject(t *testing.T) { ctx := audit.WithAuditContext(context.TODO()) ac := audit.AuditContextFrom(ctx) - ac.Event.Level = auditapis.LevelRequestResponse + if err := ac.Init(audit.RequestAuditConfig{Level: auditinternal.LevelRequestResponse}, nil); err != nil { + t.Fatal(err) + } policy := metav1.DeletePropagationBackground deleteOption := &metav1.DeleteOptions{ diff --git a/staging/src/k8s.io/apiserver/pkg/server/config_test.go b/staging/src/k8s.io/apiserver/pkg/server/config_test.go index 75314e2cafa33..5e7a84dfb7a0f 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/config_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/config_test.go @@ -348,8 +348,8 @@ func TestAuthenticationAuditAnnotationsDefaultChain(t *testing.T) { } // confirm that we have an audit event - ae := audit.AuditEventFrom(r.Context()) - if ae == nil { + ac := audit.AuditContextFrom(r.Context()) + if ac == nil { t.Error("unexpected nil audit event") } @@ -373,11 +373,15 @@ func TestAuthenticationAuditAnnotationsDefaultChain(t *testing.T) { } // these should all be the same because the handler chain mutates the event in place want := map[string]string{"pandas": "are awesome", "dogs": "are okay"} + foundResponseComplete := false for _, event := range backend.events { + if event.Stage == auditinternal.StageRequestReceived { + continue + } if event.Stage != auditinternal.StageResponseComplete { t.Errorf("expected event stage to be complete, got: %s", event.Stage) } - + foundResponseComplete = true for wantK, wantV := range want { gotV, ok := event.Annotations[wantK] if !ok { @@ -389,6 +393,9 @@ func TestAuthenticationAuditAnnotationsDefaultChain(t *testing.T) { } } } + if !foundResponseComplete { + t.Errorf("expected to find %s in events", auditinternal.StageResponseComplete) + } } type testBackend struct { diff --git a/staging/src/k8s.io/apiserver/pkg/server/filters/with_early_late_annotations_test.go b/staging/src/k8s.io/apiserver/pkg/server/filters/with_early_late_annotations_test.go index 152a5c377dea3..b35eb361e4bcd 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/filters/with_early_late_annotations_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/filters/with_early_late_annotations_test.go @@ -257,9 +257,7 @@ func TestWithStartupEarlyAnnotation(t *testing.T) { if ac == nil { t.Fatalf("expected audit context inside the request context") } - ac.Event = auditinternal.Event{ - Level: auditinternal.LevelMetadata, - } + ac.Init(audit.RequestAuditConfig{Level: auditinternal.LevelMetadata}, nil) w := httptest.NewRecorder() w.Code = 0 @@ -275,11 +273,11 @@ func TestWithStartupEarlyAnnotation(t *testing.T) { key := "apiserver.k8s.io/startup" switch { case len(test.annotationExpected) == 0: - if valueGot, ok := ac.Event.Annotations[key]; ok { + if valueGot, ok := ac.GetEventAnnotation(key); ok { t.Errorf("did not expect annotation to be added, but got: %s", valueGot) } default: - if valueGot, ok := ac.Event.Annotations[key]; !ok || test.annotationExpected != valueGot { + if valueGot, ok := ac.GetEventAnnotation(key); !ok || test.annotationExpected != valueGot { t.Errorf("expected annotation: %s, but got: %s", test.annotationExpected, valueGot) } } diff --git a/staging/src/k8s.io/apiserver/pkg/util/x509metrics/server_cert_deprecations_test.go b/staging/src/k8s.io/apiserver/pkg/util/x509metrics/server_cert_deprecations_test.go index dfdd565b25c80..eaa17bcf8df18 100644 --- a/staging/src/k8s.io/apiserver/pkg/util/x509metrics/server_cert_deprecations_test.go +++ b/staging/src/k8s.io/apiserver/pkg/util/x509metrics/server_cert_deprecations_test.go @@ -30,7 +30,6 @@ import ( "testing" "github.com/stretchr/testify/require" - auditapi "k8s.io/apiserver/pkg/apis/audit" "k8s.io/apiserver/pkg/audit" "k8s.io/component-base/metrics" "k8s.io/component-base/metrics/testutil" @@ -247,15 +246,14 @@ func TestCheckForHostnameError(t *testing.T) { } req = req.WithContext(audit.WithAuditContext(req.Context())) auditCtx := audit.AuditContextFrom(req.Context()) - auditCtx.Event.Level = auditapi.LevelMetadata _, err = client.Transport.RoundTrip(req) if sanChecker.CheckRoundTripError(err) { sanChecker.IncreaseMetricsCounter(req) - - if len(auditCtx.Event.Annotations["missing-san.invalid-cert.kubernetes.io/"+req.URL.Hostname()]) == 0 { - t.Errorf("expected audit annotations, got %#v", auditCtx.Event.Annotations) + annotations := auditCtx.GetEventAnnotations() + if len(annotations["missing-san.invalid-cert.kubernetes.io/"+req.URL.Hostname()]) == 0 { + t.Errorf("expected audit annotations, got %#v", annotations) } } @@ -390,7 +388,6 @@ func TestCheckForInsecureAlgorithmError(t *testing.T) { } req = req.WithContext(audit.WithAuditContext(req.Context())) auditCtx := audit.AuditContextFrom(req.Context()) - auditCtx.Event.Level = auditapi.LevelMetadata // can't use tlsServer.Client() as it contains the server certificate // in tls.Config.Certificates. The signatures are, however, only checked @@ -414,9 +411,9 @@ func TestCheckForInsecureAlgorithmError(t *testing.T) { if sha1checker.CheckRoundTripError(err) { sha1checker.IncreaseMetricsCounter(req) - - if len(auditCtx.Event.Annotations["insecure-sha1.invalid-cert.kubernetes.io/"+req.URL.Hostname()]) == 0 { - t.Errorf("expected audit annotations, got %#v", auditCtx.Event.Annotations) + annotations := auditCtx.GetEventAnnotations() + if len(annotations["insecure-sha1.invalid-cert.kubernetes.io/"+req.URL.Hostname()]) == 0 { + t.Errorf("expected audit annotations, got %#v", annotations) } }