Skip to content

Commit b628cd9

Browse files
authored
Improve operation collector logic (#39)
Encapsulate go routine management during operation collection. Leverage wait groups and channel lifecycles to simplify collector interface. Track start time in collector to simplify state management for operation poller. Add unit tests for operation collector.
1 parent 645fdf5 commit b628cd9

File tree

5 files changed

+105
-38
lines changed

5 files changed

+105
-38
lines changed

pkg/cloudmap/api.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ func (sdApi *serviceDiscoveryApi) ListServices(ctx context.Context, nsId string)
9393
Name: types.ServiceFilterNameNamespaceId,
9494
Values: []string{nsId},
9595
}
96-
sdApi.log.Info("paginating", "nsId", nsId)
9796

9897
pages := sd.NewListServicesPaginator(sdApi.awsFacade, &sd.ListServicesInput{Filters: []types.ServiceFilter{filter}})
9998

pkg/cloudmap/client.go

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -174,17 +174,15 @@ func (sdc *serviceDiscoveryClient) RegisterEndpoints(ctx context.Context, servic
174174
return err
175175
}
176176

177-
startTime := Now()
178-
opCollector := NewOperationCollector(len(service.Endpoints))
177+
opCollector := NewOperationCollector()
179178

180179
for _, endpt := range service.Endpoints {
181-
go func(endpt *model.Endpoint) {
182-
opId, endptErr := sdc.sdApi.RegisterInstance(ctx, svcId, endpt.Id, endpt.GetCloudMapAttributes())
183-
opCollector.Add(endpt.Id, opId, endptErr)
184-
}(endpt)
180+
opCollector.Add(func() (opId string, err error) {
181+
return sdc.sdApi.RegisterInstance(ctx, svcId, endpt.Id, endpt.GetCloudMapAttributes())
182+
})
185183
}
186184

187-
err = NewRegisterInstancePoller(sdc.sdApi, svcId, opCollector.Collect(), startTime).Poll(ctx)
185+
err = NewRegisterInstancePoller(sdc.sdApi, svcId, opCollector.Collect(), opCollector.GetStartTime()).Poll(ctx)
188186

189187
// Evict cache entry so next list call reflects changes
190188
sdc.evictEndpoints(svcId)
@@ -214,17 +212,15 @@ func (sdc *serviceDiscoveryClient) DeleteEndpoints(ctx context.Context, service
214212
return err
215213
}
216214

217-
startTime := Now()
218-
opCollector := NewOperationCollector(len(service.Endpoints))
215+
opCollector := NewOperationCollector()
219216

220217
for _, endpt := range service.Endpoints {
221-
go func(endpt *model.Endpoint) {
222-
opId, endptErr := sdc.sdApi.DeregisterInstance(ctx, svcId, endpt.Id)
223-
opCollector.Add(endpt.Id, opId, endptErr)
224-
}(endpt)
218+
opCollector.Add(func() (opId string, err error) {
219+
return sdc.sdApi.DeregisterInstance(ctx, svcId, endpt.Id)
220+
})
225221
}
226222

227-
err = NewDeregisterInstancePoller(sdc.sdApi, svcId, opCollector.Collect(), startTime).Poll(ctx)
223+
err = NewDeregisterInstancePoller(sdc.sdApi, svcId, opCollector.Collect(), opCollector.GetStartTime()).Poll(ctx)
228224

229225
// Evict cache entry so next list call reflects changes
230226
sdc.evictEndpoints(svcId)

pkg/cloudmap/operation_collector.go

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,52 +3,66 @@ package cloudmap
33
import (
44
"github.com/go-logr/logr"
55
ctrl "sigs.k8s.io/controller-runtime"
6+
"sync"
67
)
78

8-
// OperationCollector collects a list operations with thread safety.
9+
// OperationCollector collects a list of operation IDs asynchronously with thread safety.
910
type OperationCollector interface {
10-
// Add an operation to poll with thread safety, to be called from within a go routine.
11-
Add(endpointId string, operationId string, operationError error)
11+
// Add calls an operation provider function to asynchronously collect operations to poll.
12+
Add(operationProvider func() (operationId string, err error))
1213

13-
// Collect waits for all operations to be added and returns a list of successfully created operation IDs.
14+
// Collect waits for all create operation results to be provided and returns a list of the successfully created operation IDs.
1415
Collect() []string
1516

17+
// GetStartTime returns the start time range to poll the collected operations.
18+
GetStartTime() int64
19+
1620
// IsAllOperationsCreated returns true if all operations were created successfully.
1721
IsAllOperationsCreated() bool
1822
}
1923

2024
type opCollector struct {
2125
log logr.Logger
2226
opChan chan opResult
23-
opCount int
27+
wg sync.WaitGroup
28+
startTime int64
2429
createOpsSuccess bool
2530
}
2631

2732
type opResult struct {
28-
instId string
29-
opId string
30-
err error
33+
opId string
34+
err error
3135
}
3236

33-
func NewOperationCollector(opCount int) OperationCollector {
37+
func NewOperationCollector() OperationCollector {
3438
return &opCollector{
3539
log: ctrl.Log.WithName("cloudmap"),
3640
opChan: make(chan opResult),
37-
opCount: opCount,
41+
startTime: Now(),
3842
createOpsSuccess: true,
3943
}
4044
}
4145

42-
func (opColl *opCollector) Add(endptId string, opId string, opErr error) {
43-
opColl.opChan <- opResult{endptId, opId, opErr}
46+
func (opColl *opCollector) Add(opProvider func() (opId string, err error)) {
47+
opColl.wg.Add(1)
48+
go func() {
49+
defer opColl.wg.Done()
50+
51+
opId, opErr := opProvider()
52+
opColl.opChan <- opResult{opId, opErr}
53+
}()
4454
}
4555

4656
func (opColl *opCollector) Collect() []string {
4757
opIds := make([]string, 0)
4858

49-
for i := 0; i < opColl.opCount; i++ {
50-
op := <-opColl.opChan
59+
// Run wait in separate go routine to unblock reading from the channel.
60+
go func() {
61+
opColl.wg.Wait()
62+
close(opColl.opChan)
63+
}()
5164

65+
for op := range opColl.opChan {
5266
if op.err != nil {
5367
opColl.log.Info("could not create operation", "error", op.err)
5468
opColl.createOpsSuccess = false
@@ -61,6 +75,10 @@ func (opColl *opCollector) Collect() []string {
6175
return opIds
6276
}
6377

78+
func (opColl *opCollector) GetStartTime() int64 {
79+
return opColl.startTime
80+
}
81+
6482
func (opColl *opCollector) IsAllOperationsCreated() bool {
6583
return opColl.createOpsSuccess
6684
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package cloudmap
2+
3+
import (
4+
"errors"
5+
"github.com/stretchr/testify/assert"
6+
"testing"
7+
"time"
8+
)
9+
10+
func TestOpCollector_HappyCase(t *testing.T) {
11+
oc := NewOperationCollector()
12+
oc.Add(func() (opId string, err error) { return "one", nil })
13+
oc.Add(func() (opId string, err error) { return "two", nil })
14+
15+
result := oc.Collect()
16+
assert.True(t, oc.IsAllOperationsCreated())
17+
assert.Equal(t, 2, len(result))
18+
assert.Contains(t, result, "one")
19+
assert.Contains(t, result, "two")
20+
}
21+
22+
func TestOpCollector_AllFail(t *testing.T) {
23+
oc := NewOperationCollector()
24+
oc.Add(func() (opId string, err error) { return "one", errors.New("fail one") })
25+
oc.Add(func() (opId string, err error) { return "two", errors.New("fail two") })
26+
27+
result := oc.Collect()
28+
assert.False(t, oc.IsAllOperationsCreated())
29+
assert.Equal(t, 0, len(result))
30+
}
31+
32+
func TestOpCollector_MixedSuccess(t *testing.T) {
33+
oc := NewOperationCollector()
34+
oc.Add(func() (opId string, err error) { return "one", errors.New("fail one") })
35+
oc.Add(func() (opId string, err error) { return "two", nil })
36+
37+
result := oc.Collect()
38+
assert.False(t, oc.IsAllOperationsCreated())
39+
assert.Equal(t, []string{"two"}, result)
40+
}
41+
42+
func TestOpCollector_GetStartTime(t *testing.T) {
43+
oc1 := NewOperationCollector()
44+
time.Sleep(time.Second)
45+
oc2 := NewOperationCollector()
46+
47+
assert.Equal(t, oc1.GetStartTime(), oc1.GetStartTime(), "Start time should not change")
48+
assert.NotEqual(t, oc1.GetStartTime(), oc2.GetStartTime(), "Start time should reflect instantiation")
49+
assert.Less(t, oc1.GetStartTime(), oc2.GetStartTime(),
50+
"Start time should increase for later instantiations")
51+
}

pkg/cloudmap/operation_poller.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ type operationPoller struct {
3333

3434
svcId string
3535
opType types.OperationType
36-
start int
36+
start int64
3737
}
3838

39-
func newOperationPoller(sdApi ServiceDiscoveryApi, svcId string, opIds []string, startTime int) operationPoller {
39+
func newOperationPoller(sdApi ServiceDiscoveryApi, svcId string, opIds []string, startTime int64) operationPoller {
4040
return operationPoller{
4141
log: ctrl.Log.WithName("cloudmap"),
4242
sdApi: sdApi,
@@ -48,14 +48,14 @@ func newOperationPoller(sdApi ServiceDiscoveryApi, svcId string, opIds []string,
4848
}
4949

5050
// NewRegisterInstancePoller creates a new operation poller for register instance operations.
51-
func NewRegisterInstancePoller(sdApi ServiceDiscoveryApi, serviceId string, opIds []string, startTime int) OperationPoller {
51+
func NewRegisterInstancePoller(sdApi ServiceDiscoveryApi, serviceId string, opIds []string, startTime int64) OperationPoller {
5252
poller := newOperationPoller(sdApi, serviceId, opIds, startTime)
5353
poller.opType = types.OperationTypeRegisterInstance
5454
return &poller
5555
}
5656

5757
// NewDeregisterInstancePoller creates a new operation poller for de-register instance operations.
58-
func NewDeregisterInstancePoller(sdApi ServiceDiscoveryApi, serviceId string, opIds []string, startTime int) OperationPoller {
58+
func NewDeregisterInstancePoller(sdApi ServiceDiscoveryApi, serviceId string, opIds []string, startTime int64) OperationPoller {
5959
poller := newOperationPoller(sdApi, serviceId, opIds, startTime)
6060
poller.opType = types.OperationTypeDeregisterInstance
6161
return &poller
@@ -124,9 +124,9 @@ func (opPoller *operationPoller) buildFilters() []types.OperationFilter {
124124
Name: types.OperationFilterNameUpdateDate,
125125
Condition: types.FilterConditionBetween,
126126
Values: []string{
127-
strconv.Itoa(opPoller.start),
127+
Itoa(opPoller.start),
128128
// Add one minute to end range in case op updates while list request is in flight
129-
strconv.Itoa(Now() + 60000),
129+
Itoa(Now() + 60000),
130130
},
131131
}
132132

@@ -143,8 +143,11 @@ func (opPoller *operationPoller) getFailedOpReason(ctx context.Context, opId str
143143

144144
return aws.ToString(op.ErrorMessage)
145145
}
146+
func Itoa(i int64) string {
147+
return strconv.FormatInt(i, 10)
148+
}
146149

147-
// Now returns current time with milliseconds, as used by operation UPDATE_DATE field
148-
func Now() int {
149-
return int(time.Now().UnixNano() / 1000000)
150+
// Now returns current time with milliseconds, as used by operation filter UPDATE_DATE field
151+
func Now() int64 {
152+
return time.Now().UnixNano() / 1000000
150153
}

0 commit comments

Comments
 (0)