Skip to content

Commit f580eef

Browse files
committed
discovery: Add eks audit log tests
Add tests for `eksAuditLogWatcher` and `eksAuditLogFetcher`. Copy the grpc stream testing util from the access graph repo into teleport as it is useful for the bidirectional streaming methods used by access graph, and makes it easier to test on the client side.
1 parent 53da2e5 commit f580eef

File tree

3 files changed

+547
-0
lines changed

3 files changed

+547
-0
lines changed
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
package discovery
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"log/slog"
8+
"testing"
9+
"testing/synctest"
10+
"time"
11+
12+
"github.com/aws/aws-sdk-go-v2/aws"
13+
cwltypes "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types"
14+
"github.com/stretchr/testify/require"
15+
16+
accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
17+
"github.com/gravitational/teleport/lib/utils/testutils/grpctest"
18+
)
19+
20+
// eksAuditLogFetcherFixture bundles everything needed to drive an
// eksAuditLogFetcher in a test: a cancellable context, the test-controlled
// server end of the gRPC stream, the fake CloudWatch log source, and the
// error eventually returned by the fetcher's Run goroutine.
type eksAuditLogFetcherFixture struct {
	ctx    context.Context
	cancel context.CancelFunc
	// server is the server half of the kube audit log stream; tests use it
	// to receive the fetcher's requests and send back responses.
	server kalsServer
	// fetcherErr records the result of logFetcher.Run. Only read after
	// cancel() followed by synctest.Wait() in End, so there is no data race
	// inside the synctest bubble.
	fetcherErr error
	// cluster is the EKS cluster the fetcher under test is configured with.
	cluster *accessgraphv1alpha.AWSEKSClusterV1
	// fakeLogFetcher feeds CloudWatch events or errors to the unit under test.
	fakeLogFetcher *fakeCloudWatchLogFetcher
}
29+
// Start wires up the fixture and launches the fetcher's Run loop in a
// goroutine. Must be called inside a synctest bubble so that the clock and
// goroutine scheduling are virtualized.
func (f *eksAuditLogFetcherFixture) Start(t *testing.T) {
	t.Helper()

	f.ctx, f.cancel = context.WithCancel(t.Context())
	// The tester provides paired client/server stream halves, letting the
	// test play the role of the access graph gRPC service.
	tester := grpctest.NewGRPCTester[kalsRequest, kalsResponse](f.ctx)
	f.server = tester.NewServerStream()
	logger := slog.New(slog.DiscardHandler)
	f.fakeLogFetcher = newFakeCloudWatchLogFetcher()
	f.cluster = &accessgraphv1alpha.AWSEKSClusterV1{
		Name: "cluster-name",
		Arn:  "cluster-arn",
	}
	logFetcher := newEKSAuditLogFetcher(f.fakeLogFetcher, f.cluster, tester.NewClientStream(), logger)
	// Run until the fixture's context is cancelled; the exit error is
	// checked in End.
	go func() { f.fetcherErr = logFetcher.Run(f.ctx) }()
}
45+
46+
// End cancels the fixture's context, waits for all goroutines in the bubble
// to become durably blocked or exit, and asserts the fetcher's Run returned
// context.Canceled. Must be called inside the synctest bubble.
func (f *eksAuditLogFetcherFixture) End(t *testing.T) {
	t.Helper()
	f.cancel()
	synctest.Wait()
	require.ErrorIs(t, f.fetcherErr, context.Canceled)
}
53+
54+
// testInitializeNewStream performs the stream handshake: it expects the
// fetcher to send a NewStream action carrying an initial EKS cursor for the
// fixture's cluster, then replies with a ResumeState echoing that cursor so
// the fetcher starts polling from it.
func (f *eksAuditLogFetcherFixture) testInitializeNewStream(t *testing.T) {
	t.Helper()

	// Wait for a NewStream action, and verify it contains what we expect
	msg, err := f.server.Recv()
	require.NoError(t, err)
	newStream := msg.GetNewStream()
	require.NotNil(t, newStream)
	cursor := newStream.GetInitial()
	require.NotNil(t, cursor)
	require.Equal(t, accessgraphv1alpha.KubeAuditLogCursor_KUBE_AUDIT_LOG_SOURCE_EKS, cursor.GetLogSource())
	require.Equal(t, f.cluster.GetArn(), cursor.GetClusterId())

	// Send back a ResumeState
	err = f.server.Send(newKubeAuditLogResponseResumeState(cursor))
	require.NoError(t, err)
}
71+
72+
// TestEKSAuditLogFetcher_NewStream tests that when a new log stream
// is set up for a cluster, logs start being fetched from the cursor returned
// by the grpc service.
func TestEKSAuditLogFetcher_NewStream(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		f := &eksAuditLogFetcherFixture{}
		f.Start(t)
		f.testInitializeNewStream(t)
		f.End(t)
	})
}
83+
84+
// TestEKSAuditLogFetcher_Batching verifies that the fetcher polls on the
// configured interval and forwards each batch of CloudWatch events over the
// stream, advancing the cursor to the newest event in every batch.
func TestEKSAuditLogFetcher_Batching(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		startTime := time.Now().UTC()
		// Test events are timestamped a week before the (virtual) start time.
		logEpoch := startTime.Add(-7 * 24 * time.Hour)
		f := &eksAuditLogFetcherFixture{}
		f.Start(t)
		f.testInitializeNewStream(t)

		// Satisfy the first fetch with an empty batch.
		f.fakeLogFetcher.events <- nil
		// Wait for a polling loop to occur. As there are no logs left,
		// the time should now be the synctest epoch plus the poll interval
		time.Sleep(logPollInterval)
		synctest.Wait()
		require.Equal(t, startTime.Add(logPollInterval), time.Now().UTC())

		// Wait for an Events action with the log listed. Verify the log and cursor.
		f.fakeLogFetcher.events <- []cwltypes.FilteredLogEvent{
			makeEvent(logEpoch, 0, "{}"),
			makeEvent(logEpoch.Add(time.Second), 1, `{"log": "value"}`),
		}
		msg, err := f.server.Recv()
		require.NoError(t, err)
		events := msg.GetEvents()
		require.NotNil(t, events)
		require.Len(t, events.GetEvents(), 2)
		require.Len(t, events.GetEvents()[0].GetFields(), 0)
		require.Len(t, events.GetEvents()[1].GetFields(), 1)
		// The cursor must point at the last (newest) event of the batch.
		cursor := events.GetCursor()
		require.NotNil(t, cursor)
		require.Equal(t, accessgraphv1alpha.KubeAuditLogCursor_KUBE_AUDIT_LOG_SOURCE_EKS, cursor.GetLogSource())
		require.Equal(t, f.cluster.GetArn(), cursor.GetClusterId())
		require.Equal(t, "event-id-1", cursor.GetEventId())
		require.Equal(t, logEpoch.Add(time.Second), cursor.GetLastEventTime().AsTime())

		// A second batch should be forwarded the same way, with the cursor
		// advanced to its newest event.
		f.fakeLogFetcher.events <- []cwltypes.FilteredLogEvent{
			makeEvent(logEpoch.Add(time.Second), 2, `{"log": "value2"}`),
			makeEvent(logEpoch.Add(2*time.Second), 3, `{}`),
		}
		msg, err = f.server.Recv()
		require.NoError(t, err)
		events = msg.GetEvents()
		require.NotNil(t, events)
		require.Len(t, events.GetEvents(), 2)
		require.Len(t, events.GetEvents()[0].GetFields(), 1)
		require.Len(t, events.GetEvents()[1].GetFields(), 0)
		cursor = events.GetCursor()
		require.NotNil(t, cursor)
		require.Equal(t, accessgraphv1alpha.KubeAuditLogCursor_KUBE_AUDIT_LOG_SOURCE_EKS, cursor.GetLogSource())
		require.Equal(t, f.cluster.GetArn(), cursor.GetClusterId())
		require.Equal(t, "event-id-3", cursor.GetEventId())
		require.Equal(t, logEpoch.Add(2*time.Second), cursor.GetLastEventTime().AsTime())

		f.End(t)
	})
}
139+
140+
// TestEKSAuditLogFetcher_ContinueOnError verifies that a CloudWatch fetch
// error does not terminate the fetcher: after the error it keeps polling and
// forwards the next successful batch with a correct cursor.
func TestEKSAuditLogFetcher_ContinueOnError(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		startTime := time.Now().UTC()
		// Test events are timestamped a week before the (virtual) start time.
		logEpoch := startTime.Add(-7 * 24 * time.Hour)
		f := &eksAuditLogFetcherFixture{}
		f.Start(t)
		f.testInitializeNewStream(t)

		// Make the first fetch fail; the fetcher should log and retry rather
		// than exit.
		f.fakeLogFetcher.err <- errors.New("oh noes. something went wrong")
		// Wait for a polling loop to occur. As there are no logs left,
		// the time should now be the synctest epoch plus the poll interval
		time.Sleep(logPollInterval)
		synctest.Wait()
		require.Equal(t, startTime.Add(logPollInterval), time.Now().UTC())

		// Wait for an Events action with the log listed. Verify the log and cursor.
		f.fakeLogFetcher.events <- []cwltypes.FilteredLogEvent{
			makeEvent(logEpoch, 0, "{}"),
			makeEvent(logEpoch.Add(time.Second), 1, `{"log": "value"}`),
		}
		msg, err := f.server.Recv()
		require.NoError(t, err)
		events := msg.GetEvents()
		require.NotNil(t, events)
		require.Len(t, events.GetEvents(), 2)
		require.Len(t, events.GetEvents()[0].GetFields(), 0)
		require.Len(t, events.GetEvents()[1].GetFields(), 1)
		cursor := events.GetCursor()
		require.NotNil(t, cursor)
		require.Equal(t, accessgraphv1alpha.KubeAuditLogCursor_KUBE_AUDIT_LOG_SOURCE_EKS, cursor.GetLogSource())
		require.Equal(t, f.cluster.GetArn(), cursor.GetClusterId())
		require.Equal(t, "event-id-1", cursor.GetEventId())
		require.Equal(t, logEpoch.Add(time.Second), cursor.GetLastEventTime().AsTime())

		f.End(t)
	})
}
177+
178+
func newKubeAuditLogResponseResumeState(cursor *accessgraphv1alpha.KubeAuditLogCursor) *kalsResponse {
179+
return &kalsResponse{
180+
State: &accessgraphv1alpha.KubeAuditLogStreamResponse_ResumeState{
181+
ResumeState: &accessgraphv1alpha.KubeAuditLogResumeState{
182+
Cursor: cursor,
183+
},
184+
},
185+
}
186+
}
187+
188+
func makeEvent(t time.Time, id int, msg string) cwltypes.FilteredLogEvent {
189+
return cwltypes.FilteredLogEvent{
190+
EventId: aws.String(fmt.Sprintf("event-id-%d", id)),
191+
IngestionTime: aws.Int64(t.UnixMilli()),
192+
Timestamp: aws.Int64(t.UnixMilli()),
193+
LogStreamName: aws.String("kube-apiserver-audit-12345678"),
194+
Message: aws.String(msg),
195+
}
196+
}
197+
198+
func newFakeCloudWatchLogFetcher() *fakeCloudWatchLogFetcher {
199+
return &fakeCloudWatchLogFetcher{
200+
events: make(chan []cwltypes.FilteredLogEvent),
201+
err: make(chan error),
202+
}
203+
}
204+
205+
// fakeCloudWatchLogFetcher is a cloudwatch log fetcher that waits on channels
// for the data to return. This allows the unit under test to rendezvous with
// the tests, allowing the tests to advance the state of the fetcher as it
// needs.
type fakeCloudWatchLogFetcher struct {
	// events delivers the next batch of log events to return from a fetch
	// call (nil is a valid empty batch).
	events chan []cwltypes.FilteredLogEvent
	// err delivers an error to return from a fetch call instead of events.
	err chan error
}
213+
214+
// FetchEKSAuditLogs blocks until the test supplies either a batch of events
// or an error on the fake's channels, or until ctx is cancelled. The cluster
// and cursor arguments are ignored by the fake.
func (f *fakeCloudWatchLogFetcher) FetchEKSAuditLogs(
	ctx context.Context,
	cluster *accessgraphv1alpha.AWSEKSClusterV1,
	cursor *accessgraphv1alpha.KubeAuditLogCursor,
) ([]cwltypes.FilteredLogEvent, error) {
	select {
	case events := <-f.events:
		return events, nil
	case err := <-f.err:
		return nil, err
	case <-ctx.Done():
		// Unblocks cleanly when the fixture is torn down.
		return nil, ctx.Err()
	}
}

0 commit comments

Comments
 (0)