Skip to content

Commit 840b60d

Browse files
Tenermarcoandrediniscthach
authored
Refactor cloud instance watchers to use generic Fetcher interface (#62210)
* Refactor cloud instance watchers to use generic Fetcher interface * Correct channel initialization order * fix license corruption * remove redundant channel initialization * Refactor watcher to use utils.SyncMap * Increase timeout duration * Simplify setup * remove TODO * Update lib/srv/server/watcher.go Co-authored-by: Marco Dinis <marco.dinis@goteleport.com> * document NewWatcher function in watcher.go * Document SetFetchers and DeleteFetchers methods. * Remove unused parameter from NewWatcher function in EC2 watcher test * Update lib/srv/discovery/discovery_test.go Co-authored-by: Chris Thach <chris.thach@protonmail.com> * Update lib/srv/server/watcher.go Co-authored-by: Chris Thach <chris.thach@protonmail.com> * Update lib/srv/discovery/discovery_test.go Co-authored-by: Chris Thach <chris.thach@protonmail.com> * Use utils.SyncMap instead of *utils.SyncMap * `AzureInstances` -> `*AzureInstances` * `EC2Instances` -> `*EC2Instances` * `GCPInstances` -> `*GCPInstances` * Fix initialization race condition. * post-merge fixes --------- Co-authored-by: Marco Dinis <marco.dinis@goteleport.com> Co-authored-by: Chris Thach <chris.thach@protonmail.com>
1 parent c6fab90 commit 840b60d

File tree

8 files changed

+250
-359
lines changed

8 files changed

+250
-359
lines changed

lib/srv/discovery/discovery.go

Lines changed: 91 additions & 177 deletions
Large diffs are not rendered by default.

lib/srv/discovery/discovery_test.go

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3231,15 +3231,19 @@ func TestAzureVMDiscovery(t *testing.T) {
32313231
emitter.t = t
32323232

32333233
if tc.discoveryConfig != nil {
3234+
sub := server.newDiscoveryConfigChangedSub()
3235+
32343236
_, err := tlsServer.Auth().DiscoveryConfigs.CreateDiscoveryConfig(ctx, tc.discoveryConfig)
32353237
require.NoError(t, err)
32363238

3237-
// Wait for the DiscoveryConfig to be added to the dynamic matchers
3238-
require.Eventually(t, func() bool {
3239-
server.muDynamicServerAzureFetchers.RLock()
3240-
defer server.muDynamicServerAzureFetchers.RUnlock()
3241-
return len(server.dynamicServerAzureFetchers) > 0
3242-
}, 1*time.Second, 50*time.Millisecond)
3239+
// wait for discovery config update
3240+
select {
3241+
case <-sub:
3242+
case <-time.After(3 * time.Second):
3243+
require.Fail(t, "timed out waiting for an update")
3244+
case <-t.Context().Done():
3245+
require.Fail(t, "test context done while waiting for an update")
3246+
}
32433247
}
32443248

32453249
require.NoError(t, server.Start())
@@ -3544,18 +3548,22 @@ func TestGCPVMDiscovery(t *testing.T) {
35443548
emitter.t = t
35453549

35463550
if tc.discoveryConfig != nil {
3551+
sub := server.newDiscoveryConfigChangedSub()
3552+
35473553
_, err := tlsServer.Auth().DiscoveryConfigs.CreateDiscoveryConfig(ctx, tc.discoveryConfig)
35483554
require.NoError(t, err)
35493555

3550-
// Wait for the DiscoveryConfig to be added to the dynamic matchers
3551-
require.Eventually(t, func() bool {
3552-
server.muDynamicServerGCPFetchers.RLock()
3553-
defer server.muDynamicServerGCPFetchers.RUnlock()
3554-
return len(server.dynamicServerGCPFetchers) > 0
3555-
}, 1*time.Second, 100*time.Millisecond)
3556+
// wait for discovery config update
3557+
select {
3558+
case <-sub:
3559+
case <-time.After(3 * time.Second):
3560+
t.Fatal("timed out waiting for channel update")
3561+
case <-t.Context().Done():
3562+
require.Fail(t, "test context done while waiting for an update")
3563+
}
35563564
}
35573565

3558-
go server.Start()
3566+
require.NoError(t, server.Start())
35593567
t.Cleanup(server.Stop)
35603568

35613569
if len(tc.wantInstalledInstances) > 0 {

lib/srv/server/azure_watcher.go

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@ import (
2222
"context"
2323
"log/slog"
2424
"slices"
25-
"time"
2625

2726
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
2827
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6"
2928
"github.com/gravitational/trace"
30-
"github.com/jonboulle/clockwork"
3129

3230
usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1"
3331
"github.com/gravitational/teleport/api/types"
@@ -73,27 +71,9 @@ func (instances *AzureInstances) MakeEvents() map[string]*usageeventsv1.Resource
7371

7472
type azureClientGetter func(ctx context.Context, integration string) (azure.Clients, error)
7573

76-
// NewAzureWatcher creates a new Azure watcher instance.
77-
func NewAzureWatcher(ctx context.Context, fetchersFn func() []Fetcher, opts ...Option) (*Watcher, error) {
78-
cancelCtx, cancelFn := context.WithCancel(ctx)
79-
watcher := Watcher{
80-
fetchersFn: fetchersFn,
81-
ctx: cancelCtx,
82-
cancel: cancelFn,
83-
pollInterval: time.Minute,
84-
clock: clockwork.NewRealClock(),
85-
triggerFetchC: make(<-chan struct{}),
86-
InstancesC: make(chan Instances),
87-
}
88-
for _, opt := range opts {
89-
opt(&watcher)
90-
}
91-
return &watcher, nil
92-
}
93-
9474
// MatchersToAzureInstanceFetchers converts a list of Azure VM Matchers into a list of Azure VM Fetchers.
95-
func MatchersToAzureInstanceFetchers(logger *slog.Logger, matchers []types.AzureMatcher, getClient azureClientGetter, discoveryConfigName string) []Fetcher {
96-
ret := make([]Fetcher, 0)
75+
func MatchersToAzureInstanceFetchers(logger *slog.Logger, matchers []types.AzureMatcher, getClient azureClientGetter, discoveryConfigName string) []Fetcher[*AzureInstances] {
76+
ret := make([]Fetcher[*AzureInstances], 0)
9777
for _, matcher := range matchers {
9878
for _, subscription := range matcher.Subscriptions {
9979
for _, resourceGroup := range matcher.ResourceGroups {
@@ -147,7 +127,7 @@ func newAzureInstanceFetcher(cfg azureFetcherConfig) *azureInstanceFetcher {
147127
}
148128
}
149129

150-
func (*azureInstanceFetcher) GetMatchingInstances(_ context.Context, _ []types.Server, _ bool) ([]Instances, error) {
130+
func (*azureInstanceFetcher) GetMatchingInstances(_ context.Context, _ []types.Server, _ bool) ([]*AzureInstances, error) {
151131
return nil, trace.NotImplemented("not implemented for azure fetchers")
152132
}
153133

@@ -167,7 +147,7 @@ type resourceGroupLocation struct {
167147
}
168148

169149
// GetInstances fetches all Azure virtual machines matching configured filters.
170-
func (f *azureInstanceFetcher) GetInstances(ctx context.Context, _ bool) ([]Instances, error) {
150+
func (f *azureInstanceFetcher) GetInstances(ctx context.Context, _ bool) ([]*AzureInstances, error) {
171151
azureClients, err := f.AzureClientGetter(ctx, f.IntegrationName())
172152
if err != nil {
173153
return nil, trace.Wrap(err)
@@ -228,16 +208,16 @@ func (f *azureInstanceFetcher) GetInstances(ctx context.Context, _ bool) ([]Inst
228208
instancesByRegionAndResourceGroup[batchGroup] = append(instancesByRegionAndResourceGroup[batchGroup], vm)
229209
}
230210

231-
var instances []Instances
211+
var instances []*AzureInstances
232212
for batchGroup, vms := range instancesByRegionAndResourceGroup {
233-
instances = append(instances, Instances{Azure: &AzureInstances{
213+
instances = append(instances, &AzureInstances{
234214
SubscriptionID: f.Subscription,
235215
Region: batchGroup.location,
236216
ResourceGroup: batchGroup.resourceGroup,
237217
Instances: vms,
238218
Integration: f.Integration,
239219
InstallerParams: f.InstallerParams,
240-
}})
220+
})
241221
}
242222

243223
return instances, nil

lib/srv/server/azure_watcher_test.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,15 @@ func TestAzureWatcher(t *testing.T) {
156156
t.Run(tc.name, func(t *testing.T) {
157157
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
158158
t.Cleanup(cancel)
159-
watcher, err := NewAzureWatcher(ctx, func() []Fetcher {
160-
return MatchersToAzureInstanceFetchers(logger, []types.AzureMatcher{tc.matcher}, func(ctx context.Context, integration string) (azure.Clients, error) {
161-
return &clients, nil
162-
}, "" /* discovery config */)
163-
})
164-
require.NoError(t, err)
159+
watcher := NewWatcher[*AzureInstances](ctx)
160+
161+
const noDiscoveryConfig = ""
162+
watcher.SetFetchers(noDiscoveryConfig,
163+
MatchersToAzureInstanceFetchers(logger, []types.AzureMatcher{tc.matcher},
164+
func(ctx context.Context, integration string) (azure.Clients, error) {
165+
return &clients, nil
166+
}, noDiscoveryConfig),
167+
)
165168

166169
go watcher.Run()
167170
t.Cleanup(watcher.Stop)
@@ -171,13 +174,13 @@ func TestAzureWatcher(t *testing.T) {
171174
for len(vmIDs) < len(tc.wantVMs) {
172175
select {
173176
case results := <-watcher.InstancesC:
174-
for _, vm := range results.Azure.Instances {
177+
for _, vm := range results.Instances {
175178
parsedResource, err := arm.ParseResourceID(*vm.ID)
176179
require.NoError(t, err)
177180
vmID := parsedResource.Name
178181
vmIDs = append(vmIDs, vmID)
179182
}
180-
require.NotEqual(t, "*", results.Azure.ResourceGroup)
183+
require.NotEqual(t, "*", results.ResourceGroup)
181184
case <-ctx.Done():
182185
require.Fail(t, "Expected %v VMs, got %v", tc.wantVMs, len(vmIDs))
183186
}

lib/srv/server/ec2_watcher.go

Lines changed: 13 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"log/slog"
2424
"strings"
2525
"sync"
26-
"time"
2726

2827
"github.com/aws/aws-sdk-go-v2/aws"
2928
"github.com/aws/aws-sdk-go-v2/aws/arn"
@@ -32,7 +31,6 @@ import (
3231
"github.com/aws/aws-sdk-go-v2/service/ec2"
3332
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
3433
"github.com/gravitational/trace"
35-
"github.com/jonboulle/clockwork"
3634

3735
usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1"
3836
"github.com/gravitational/teleport/api/types"
@@ -144,17 +142,6 @@ func (i *EC2Instances) ServerInfos() ([]types.ServerInfo, error) {
144142
return serverInfos, nil
145143
}
146144

147-
// Option is a functional option for the Watcher.
148-
type Option func(*Watcher)
149-
150-
// WithPollInterval sets the interval at which the watcher will fetch
151-
// instances from AWS.
152-
func WithPollInterval(interval time.Duration) Option {
153-
return func(w *Watcher) {
154-
w.pollInterval = interval
155-
}
156-
}
157-
158145
// MakeEvents generates ResourceCreateEvents for these instances.
159146
func (instances *EC2Instances) MakeEvents() map[string]*usageeventsv1.ResourceCreateEvent {
160147
resourceType := types.DiscoveredResourceNode
@@ -180,25 +167,6 @@ func (instances *EC2Instances) MakeEvents() map[string]*usageeventsv1.ResourceCr
180167
return events
181168
}
182169

183-
// NewEC2Watcher creates a new EC2 watcher instance.
184-
func NewEC2Watcher(ctx context.Context, fetchersFn func() []Fetcher, missedRotation <-chan []types.Server, opts ...Option) (*Watcher, error) {
185-
cancelCtx, cancelFn := context.WithCancel(ctx)
186-
watcher := Watcher{
187-
fetchersFn: fetchersFn,
188-
ctx: cancelCtx,
189-
cancel: cancelFn,
190-
clock: clockwork.NewRealClock(),
191-
pollInterval: time.Minute,
192-
triggerFetchC: make(<-chan struct{}),
193-
InstancesC: make(chan Instances),
194-
missedRotation: missedRotation,
195-
}
196-
for _, opt := range opts {
197-
opt(&watcher)
198-
}
199-
return &watcher, nil
200-
}
201-
202170
// EC2ClientGetter gets an AWS EC2 client for the given region.
203171
type EC2ClientGetter func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (ec2.DescribeInstancesAPIClient, error)
204172

@@ -231,8 +199,8 @@ type MatcherToEC2FetcherParams struct {
231199
}
232200

233201
// MatchersToEC2InstanceFetchers converts a list of AWS EC2 Matchers into a list of AWS EC2 Fetchers.
234-
func MatchersToEC2InstanceFetchers(ctx context.Context, matcherParams MatcherToEC2FetcherParams) ([]Fetcher, error) {
235-
ret := []Fetcher{}
202+
func MatchersToEC2InstanceFetchers(ctx context.Context, matcherParams MatcherToEC2FetcherParams) ([]Fetcher[*EC2Instances], error) {
203+
var ret []Fetcher[*EC2Instances]
236204
for _, matcher := range matcherParams.Matchers {
237205
fetcher := newEC2InstanceFetcher(ec2FetcherConfig{
238206
Matcher: matcher,
@@ -400,7 +368,7 @@ func ssmRunCommandParameters(ctx context.Context, cfg ec2FetcherConfig) (map[str
400368
}
401369

402370
// GetMatchingInstances returns a list of EC2 instances from a list of matching Teleport nodes
403-
func (f *ec2InstanceFetcher) GetMatchingInstances(ctx context.Context, nodes []types.Server, rotation bool) ([]Instances, error) {
371+
func (f *ec2InstanceFetcher) GetMatchingInstances(ctx context.Context, nodes []types.Server, rotation bool) ([]*EC2Instances, error) {
404372
ssmRunParams, err := ssmRunCommandParameters(ctx, f.ec2FetcherConfig)
405373
if err != nil {
406374
return nil, trace.Wrap(err)
@@ -459,12 +427,12 @@ func (f *ec2InstanceFetcher) GetMatchingInstances(ctx context.Context, nodes []t
459427

460428
// chunkInstances splits instances into chunks of 50.
461429
// This is required because SSM SendCommand API calls only accept up to 50 instance IDs at a time.
462-
func chunkInstances(instancesByRegion map[string]EC2Instances) []Instances {
463-
var instColl []Instances
430+
func chunkInstances(instancesByRegion map[string]EC2Instances) []*EC2Instances {
431+
var instColl []*EC2Instances
464432
for _, insts := range instancesByRegion {
465433
for i := 0; i < len(insts.Instances); i += awsEC2APIChunkSize {
466434
end := min(i+awsEC2APIChunkSize, len(insts.Instances))
467-
inst := EC2Instances{
435+
inst := &EC2Instances{
468436
AccountID: insts.AccountID,
469437
Region: insts.Region,
470438
DocumentName: insts.DocumentName,
@@ -474,7 +442,7 @@ func chunkInstances(instancesByRegion map[string]EC2Instances) []Instances {
474442
Integration: insts.Integration,
475443
DiscoveryConfigName: insts.DiscoveryConfigName,
476444
}
477-
instColl = append(instColl, Instances{EC2: &inst})
445+
instColl = append(instColl, inst)
478446
}
479447
}
480448
return instColl
@@ -606,14 +574,14 @@ func (f *ec2InstanceFetcher) allAssumeRoles(ctx context.Context) ([]assumeRoleWi
606574
}
607575

608576
// GetInstances fetches all EC2 instances matching configured filters.
609-
func (f *ec2InstanceFetcher) GetInstances(ctx context.Context, rotation bool) ([]Instances, error) {
577+
func (f *ec2InstanceFetcher) GetInstances(ctx context.Context, rotation bool) ([]*EC2Instances, error) {
610578
ssmRunParams, err := ssmRunCommandParameters(ctx, f.ec2FetcherConfig)
611579
if err != nil {
612580
return nil, trace.Wrap(err)
613581
}
614582

615583
f.cachedInstances.clear()
616-
var allInstances []Instances
584+
var allInstances []*EC2Instances
617585

618586
accountRolesToAssume, err := f.allAssumeRoles(ctx)
619587
if err != nil {
@@ -672,13 +640,13 @@ type getInstancesInRegionParams struct {
672640
}
673641

674642
// getInstancesInRegion fetches all EC2 instances in a given region.
675-
func (f *ec2InstanceFetcher) getInstancesInRegion(ctx context.Context, params getInstancesInRegionParams) ([]Instances, error) {
643+
func (f *ec2InstanceFetcher) getInstancesInRegion(ctx context.Context, params getInstancesInRegionParams) ([]*EC2Instances, error) {
676644
ec2Client, err := f.EC2ClientGetter(ctx, params.region, params.awsOpts...)
677645
if err != nil {
678646
return nil, trace.Wrap(err)
679647
}
680648

681-
var instances []Instances
649+
var instances []*EC2Instances
682650

683651
paginator := ec2.NewDescribeInstancesPaginator(ec2Client, &ec2.DescribeInstancesInput{
684652
Filters: f.Filters,
@@ -694,7 +662,7 @@ func (f *ec2InstanceFetcher) getInstancesInRegion(ctx context.Context, params ge
694662
for i := 0; i < len(res.Instances); i += awsEC2APIChunkSize {
695663
end := min(i+awsEC2APIChunkSize, len(res.Instances))
696664
ownerID := aws.ToString(res.OwnerId)
697-
inst := EC2Instances{
665+
inst := &EC2Instances{
698666
AccountID: ownerID,
699667
Region: params.region,
700668
DocumentName: f.Matcher.SSM.DocumentName,
@@ -710,7 +678,7 @@ func (f *ec2InstanceFetcher) getInstancesInRegion(ctx context.Context, params ge
710678
for _, ec2inst := range res.Instances[i:end] {
711679
f.cachedInstances.add(ownerID, aws.ToString(ec2inst.InstanceId))
712680
}
713-
instances = append(instances, Instances{EC2: &inst})
681+
instances = append(instances, inst)
714682
}
715683
}
716684
}

0 commit comments

Comments
 (0)