Skip to content

Commit f031c9d

Browse files
committed
improve unit test code
1 parent 665b953 commit f031c9d

File tree

1 file changed

+106
-112
lines changed

1 file changed

+106
-112
lines changed

collector/internal/telemetryapi/listener_test.go

Lines changed: 106 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,64 @@ import (
2424
"testing"
2525
"time"
2626

27-
"github.com/golang-collections/go-datastructures/queue"
27+
"github.com/stretchr/testify/assert"
2828
"github.com/stretchr/testify/require"
2929
"go.uber.org/zap/zaptest"
3030
)
3131

32-
func WithEnv(t *testing.T, key, value string) {
32+
func withEnv(t *testing.T, key, value string) {
3333
t.Helper()
3434
require.NoError(t, os.Setenv(key, value))
3535
t.Cleanup(func() {
3636
require.NoError(t, os.Unsetenv(key))
3737
})
3838
}
3939

40+
func setupListener(t *testing.T) (*Listener, string) {
41+
t.Helper()
42+
withEnv(t, "AWS_SAM_LOCAL", "true")
43+
logger := zaptest.NewLogger(t)
44+
listener := NewListener(logger)
45+
46+
address, err := listener.Start()
47+
require.NoError(t, err)
48+
49+
t.Cleanup(func() {
50+
listener.Shutdown()
51+
})
52+
53+
return listener, address
54+
}
55+
56+
func submitEvents(t *testing.T, address string, events []Event) {
57+
t.Helper()
58+
body, err := json.Marshal(events)
59+
require.NoError(t, err)
60+
61+
resp, err := http.Post(address, "application/json", bytes.NewReader(body))
62+
require.NoError(t, err)
63+
require.NoError(t, resp.Body.Close())
64+
}
65+
66+
func assertWaitBlocks(t *testing.T, waitDone <-chan error, timeout time.Duration) {
67+
t.Helper()
68+
select {
69+
case err := <-waitDone:
70+
t.Fatalf("Wait() unexpectedly completed with error: %v", err)
71+
case <-time.After(timeout):
72+
}
73+
}
74+
75+
func assertWaitCompletes(t *testing.T, waitDone <-chan error, timeout time.Duration) {
76+
t.Helper()
77+
select {
78+
case err := <-waitDone:
79+
require.NoError(t, err)
80+
case <-time.After(timeout):
81+
t.Fatal("Wait() timed out")
82+
}
83+
}
84+
4085
type TestEventBuilder struct {
4186
requestID string
4287
timestamp time.Time
@@ -87,21 +132,10 @@ func TestNewListener(t *testing.T) {
87132
logger := zaptest.NewLogger(t)
88133
listener := NewListener(logger)
89134

90-
if listener == nil {
91-
t.Fatal("NewListener returned nil")
92-
}
93-
94-
if listener.httpServer != nil {
95-
t.Error("httpServer should initialized to nil")
96-
}
97-
98-
if listener.logger == nil {
99-
t.Error("logger should not be nil")
100-
}
101-
102-
if listener.queue == nil {
103-
t.Error("queue should not be nil")
104-
}
135+
require.NotNil(t, listener, "NewListener() returned nil listener")
136+
require.Nil(t, listener.httpServer, "httpServer should be initially nil")
137+
require.NotNil(t, listener.logger, "logger should not be nil")
138+
require.NotNil(t, listener.queue, "queue should not be nil")
105139
}
106140

107141
func TestListenOnAddress(t *testing.T) {
@@ -148,15 +182,7 @@ func TestListenOnAddress(t *testing.T) {
148182
}
149183

150184
func TestListener_StartAndShutdown(t *testing.T) {
151-
WithEnv(t, "AWS_SAM_LOCAL", "true")
152-
logger := zaptest.NewLogger(t)
153-
listener := NewListener(logger)
154-
155-
address, err := listener.Start()
156-
if err != nil {
157-
t.Fatalf("Failed to start listener: %v", err)
158-
}
159-
185+
listener, address := setupListener(t)
160186
require.NotEqual(t, address, "", "Start() should not return an empty address")
161187
require.True(t, strings.HasPrefix(address, "http://"), "Address should start with http://")
162188
require.NotNil(t, listener.httpServer, "httpServer should not be nil")
@@ -175,26 +201,17 @@ func TestListener_StartAndShutdown(t *testing.T) {
175201
func TestListener_Shutdown_NotStarted(t *testing.T) {
176202
logger := zaptest.NewLogger(t)
177203
listener := NewListener(logger)
178-
179204
listener.Shutdown()
180-
181205
require.Nil(t, listener.httpServer, "httpServer should be nil after Shutdown()")
182206
}
183207

184208
func TestListener_httpHandler(t *testing.T) {
185-
WithEnv(t, "AWS_SAM_LOCAL", "true")
186-
logger := zaptest.NewLogger(t)
187-
listener := NewListener(logger)
188-
189-
address, err := listener.Start()
190-
require.NoError(t, err, "Failed to start listener: %v", err)
191-
defer listener.Shutdown()
192209
eventBuilder := NewTestEventBuilder("test-request")
193210

194211
testCases := []struct {
195212
name string
196213
events []Event
197-
expectedCount int
214+
expectedCount int64
198215
}{
199216
{
200217
name: "single event",
@@ -223,32 +240,17 @@ func TestListener_httpHandler(t *testing.T) {
223240

224241
for _, test := range testCases {
225242
t.Run(test.name, func(t *testing.T) {
226-
listener.queue.Dispose()
227-
listener.queue = queue.New(initialQueueSize)
228-
229-
body, err := json.Marshal(test.events)
230-
require.NoError(t, err, "Failed to marshal events: %v", err)
231-
232-
resp, err := http.Post(address, "application/json", bytes.NewReader(body))
233-
require.NoError(t, err, "Failed to post events: %v", err)
234-
require.NoError(t, resp.Body.Close())
235-
236-
deadline := time.Now().Add(1 * time.Second)
237-
for time.Now().Before(deadline) {
238-
queueLen := listener.queue.Len()
239-
if queueLen == int64(test.expectedCount) {
240-
return
241-
}
242-
time.Sleep(10 * time.Millisecond)
243-
}
244-
queueLen := listener.queue.Len()
245-
require.Equal(t, test.expectedCount, queueLen, "Event queue length does not match")
243+
listener, address := setupListener(t)
244+
submitEvents(t, address, test.events)
245+
require.EventuallyWithT(t, func(c *assert.CollectT) {
246+
require.Equal(c, test.expectedCount, listener.queue.Len())
247+
}, 1*time.Second, 50*time.Millisecond)
246248
})
247249
}
248250
}
249251

250252
func TestListener_httpHandler_InvalidJSON(t *testing.T) {
251-
WithEnv(t, "AWS_SAM_LOCAL", "true")
253+
withEnv(t, "AWS_SAM_LOCAL", "true")
252254
logger := zaptest.NewLogger(t)
253255
listener := NewListener(logger)
254256

@@ -266,66 +268,58 @@ func TestListener_httpHandler_InvalidJSON(t *testing.T) {
266268
}
267269

268270
func TestListener_Wait_Success(t *testing.T) {
269-
WithEnv(t, "AWS_SAM_LOCAL", "true")
270-
logger := zaptest.NewLogger(t)
271-
listener := NewListener(logger)
272-
273-
address, err := listener.Start()
274-
require.NoError(t, err, "Failed to start listener: %v", err)
275-
defer listener.Shutdown()
276271
eventBuilder := NewTestEventBuilder("target-request")
277272

278-
events := []Event{
279-
eventBuilder.PlatformStart(),
280-
eventBuilder.FunctionLog("INFO", "Received request"),
281-
eventBuilder.FunctionLog("INFO", "Processing request"),
282-
eventBuilder.FunctionLog("INFO", "Finished processing request"),
283-
eventBuilder.PlatformRuntimeDone(),
273+
testCases := []struct {
274+
name string
275+
events []Event
276+
}{
277+
{
278+
name: "simple request",
279+
events: []Event{
280+
eventBuilder.PlatformStart(),
281+
eventBuilder.FunctionLog("INFO", "Received request"),
282+
eventBuilder.FunctionLog("INFO", "Processing request"),
283+
eventBuilder.FunctionLog("INFO", "Finished processing request"),
284+
eventBuilder.PlatformRuntimeDone(),
285+
},
286+
},
287+
{
288+
name: "skips wrong request id",
289+
events: []Event{
290+
NewTestEventBuilder("other-request-1").PlatformRuntimeDone(),
291+
eventBuilder.PlatformStart(),
292+
eventBuilder.FunctionLog("INFO", "Received request"),
293+
NewTestEventBuilder("other-request-2").PlatformRuntimeDone(),
294+
eventBuilder.FunctionLog("INFO", "Processing request"),
295+
eventBuilder.FunctionLog("INFO", "Finished processing request"),
296+
NewTestEventBuilder("other-request-3").PlatformRuntimeDone(),
297+
eventBuilder.PlatformRuntimeDone(),
298+
},
299+
},
284300
}
285301

286-
body, err := json.Marshal(events)
287-
require.NoError(t, err, "Failed to marshal events: %v", err)
288-
289-
resp, err := http.Post(address, "application/json", bytes.NewReader(body))
290-
require.NoError(t, err, "Failed to post events: %v", err)
291-
require.NoError(t, resp.Body.Close())
292-
293-
ctx := context.Background()
294-
waitErr := listener.Wait(ctx, "target-request")
295-
require.NoError(t, waitErr, "Failed to wait for target-req")
296-
}
297-
298-
func TestListener_Wait_SkipsWrongRequestId(t *testing.T) {
299-
WithEnv(t, "AWS_SAM_LOCAL", "true")
300-
logger := zaptest.NewLogger(t)
301-
listener := NewListener(logger)
302-
303-
address, err := listener.Start()
304-
require.NoError(t, err, "Failed to start listener: %v", err)
305-
defer listener.Shutdown()
306-
eventBuilder := NewTestEventBuilder("target-request")
307-
308-
events := []Event{
309-
NewTestEventBuilder("other-request-1").PlatformRuntimeDone(),
310-
eventBuilder.PlatformStart(),
311-
eventBuilder.FunctionLog("INFO", "Received request"),
312-
NewTestEventBuilder("other-request-2").PlatformRuntimeDone(),
313-
eventBuilder.FunctionLog("INFO", "Processing request"),
314-
eventBuilder.FunctionLog("INFO", "Finished processing request"),
315-
NewTestEventBuilder("other-request-3").PlatformRuntimeDone(),
316-
eventBuilder.PlatformRuntimeDone(),
302+
for _, test := range testCases {
303+
t.Run(test.name, func(t *testing.T) {
304+
listener, address := setupListener(t)
305+
306+
waitDone := make(chan error, 1)
307+
go func() {
308+
ctx := context.Background()
309+
waitDone <- listener.Wait(ctx, "target-request")
310+
}()
311+
312+
assertWaitBlocks(t, waitDone, 50*time.Millisecond)
313+
for i, event := range test.events {
314+
submitEvents(t, address, []Event{event})
315+
if i < len(test.events)-1 {
316+
assertWaitBlocks(t, waitDone, 50*time.Millisecond)
317+
} else {
318+
assertWaitCompletes(t, waitDone, 1*time.Second)
319+
}
320+
}
321+
})
317322
}
318-
319-
body, err := json.Marshal(events)
320-
require.NoError(t, err, "Failed to marshal events: %v", err)
321-
322-
resp, err := http.Post(address, "application/json", bytes.NewReader(body))
323-
require.NoError(t, err, "Failed to post events: %v", err)
324-
require.NoError(t, resp.Body.Close())
325-
326-
ctx := context.Background()
327-
waitErr := listener.Wait(ctx, "target-request")
328-
require.NoError(t, waitErr, "Failed to wait for target-request")
329323
}
330324

331325
func TestListener_Wait_ContextCanceled(t *testing.T) {

0 commit comments

Comments
 (0)