diff --git a/collector/internal/telemetryapi/listener.go b/collector/internal/telemetryapi/listener.go index 62a6461cec..742c9915b1 100644 --- a/collector/internal/telemetryapi/listener.go +++ b/collector/internal/telemetryapi/listener.go @@ -17,14 +17,11 @@ package telemetryapi import ( "context" "encoding/json" - "errors" "fmt" "io" - "math/rand" "net" "net/http" "os" - "syscall" "time" "github.com/golang-collections/go-datastructures/queue" @@ -33,17 +30,8 @@ import ( const ( initialQueueSize = 5 - maxRetries = 5 - // Define ephemeral port range (typical range is 49152-65535) - minPort = 49152 - maxPort = 65535 ) -// getRandomPort returns a random port number within the ephemeral range -func getRandomPort() string { - return fmt.Sprintf("%d", rand.Intn(maxPort-minPort)+minPort) -} - // Listener is used to listen to the Telemetry API type Listener struct { httpServer *http.Server @@ -60,46 +48,39 @@ func NewListener(logger *zap.Logger) *Listener { } } -func (s *Listener) tryBindPort() (net.Listener, string, error) { - for i := 0; i < maxRetries; i++ { - port := getRandomPort() - address := listenOnAddress(port) - - l, err := net.Listen("tcp", address) - if err != nil { - if errors.Is(err, syscall.EADDRINUSE) { - s.logger.Debug("Port in use, trying another", - zap.String("address", address)) - continue - } - return nil, "", err - } - return l, address, nil +func (s *Listener) bindListener() (net.Listener, string, error) { + listenerAddr := listenOnAddress() + l, err := net.Listen("tcp", listenerAddr+":0") + if err != nil { + return nil, "", err } - - return nil, "", fmt.Errorf("failed to find available port after %d attempts", maxRetries) + addr := fmt.Sprintf("%s:%d", listenerAddr, l.Addr().(*net.TCPAddr).Port) + return l, addr, nil } -func listenOnAddress(port string) string { +func listenOnAddress() string { envAwsLocal, ok := os.LookupEnv("AWS_SAM_LOCAL") - var addr string if ok && envAwsLocal == "true" { - addr = ":" + port + return "" } else { - addr = "sandbox.localdomain:" + port + return "sandbox.localdomain" } - return addr } // Start the server in a goroutine where the log events will be sent func (s *Listener) Start() (string, error) { - listener, address, err := s.tryBindPort() + listener, address, err := s.bindListener() if err != nil { return "", fmt.Errorf("failed to find available port: %w", err) } s.logger.Info("Listening for requests", zap.String("address", address)) - s.httpServer = &http.Server{Addr: address} - http.HandleFunc("/", s.httpHandler) + mux := http.NewServeMux() + s.httpServer = &http.Server{ + Addr: address, + Handler: mux, + } + mux.HandleFunc("/", s.httpHandler) + go func() { err := s.httpServer.Serve(listener) if err != http.ErrServerClosed { diff --git a/collector/internal/telemetryapi/listener_test.go b/collector/internal/telemetryapi/listener_test.go new file mode 100644 index 0000000000..230bd919cc --- /dev/null +++ b/collector/internal/telemetryapi/listener_test.go @@ -0,0 +1,334 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package telemetryapi + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func withEnv(t *testing.T, key, value string) { + t.Helper() + require.NoError(t, os.Setenv(key, value)) + t.Cleanup(func() { + require.NoError(t, os.Unsetenv(key)) + }) +} + +func setupListener(t *testing.T) (*Listener, string) { + t.Helper() + withEnv(t, "AWS_SAM_LOCAL", "true") + logger := zaptest.NewLogger(t) + listener := NewListener(logger) + + address, err := listener.Start() + require.NoError(t, err) + + t.Cleanup(func() { + listener.Shutdown() + }) + + return listener, address +} + +func submitEvents(t *testing.T, address string, events []Event) { + t.Helper() + body, err := json.Marshal(events) + require.NoError(t, err) + + resp, err := http.Post(address, "application/json", bytes.NewReader(body)) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) +} + +func assertWaitBlocks(t *testing.T, waitDone <-chan error, timeout time.Duration) { + t.Helper() + select { + case err := <-waitDone: + t.Fatalf("Wait() unexpectedly completed with error: %v", err) + case <-time.After(timeout): + } +} + +func assertWaitCompletes(t *testing.T, waitDone <-chan error, timeout time.Duration) { + t.Helper() + select { + case err := <-waitDone: + require.NoError(t, err) + case <-time.After(timeout): + t.Fatal("Wait() timed out") + } +} + +type TestEventBuilder struct { + requestID string + timestamp time.Time +} + +func NewTestEventBuilder(requestID string) *TestEventBuilder { + return &TestEventBuilder{ + requestID: requestID, + timestamp: time.Now(), + } +} + +func (b *TestEventBuilder) PlatformStart() Event { + return Event{ + Type: "platform.start", + Time: b.timestamp.Format(time.RFC3339), + Record: map[string]interface{}{ + "requestId": b.requestID, + "version": "$LATEST", + }, + } +} + +func (b *TestEventBuilder) PlatformRuntimeDone() Event { + return Event{ + Type: "platform.runtimeDone", + Time: b.timestamp.Format(time.RFC3339), + Record: map[string]interface{}{ + "requestId": b.requestID, + "status": "success", + }, + } +} + +func (b *TestEventBuilder) FunctionLog(logLevel, message string) Event { + return Event{ + Type: "function", + Time: b.timestamp.Format(time.RFC3339), + Record: map[string]interface{}{ + "requestId": b.requestID, + "type": logLevel, + "message": message, + }, + } +} + +func TestNewListener(t *testing.T) { + logger := zaptest.NewLogger(t) + listener := NewListener(logger) + + require.NotNil(t, listener, "NewListener() returned nil listener") + require.Nil(t, listener.httpServer, "httpServer should be initially nil") + require.NotNil(t, listener.logger, "logger should not be nil") + require.NotNil(t, listener.queue, "queue should not be nil") +} + +func TestListenOnAddress(t *testing.T) { + testCases := []struct { + name string + envValue string + setEnv bool + expectedAddr string + }{ + { + name: "AWS_SAM_LOCAL not set", + setEnv: false, + expectedAddr: "sandbox.localdomain", + }, + { + name: "AWS_SAM_LOCAL set to true", + envValue: "true", + setEnv: true, + expectedAddr: "", + }, + { + name: "AWS_SAM_LOCAL set to false", + envValue: "false", + setEnv: true, + expectedAddr: "sandbox.localdomain", + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + require.NoError(t, os.Unsetenv("AWS_SAM_LOCAL")) + + if test.setEnv { + require.NoError(t, os.Setenv("AWS_SAM_LOCAL", test.envValue)) + defer func() { + require.NoError(t, os.Unsetenv("AWS_SAM_LOCAL")) + }() + } + + addr := listenOnAddress() + require.Equal(t, test.expectedAddr, addr) + }) + } +} + +func TestListener_StartAndShutdown(t *testing.T) { + listener, address := setupListener(t) + require.NotEqual(t, address, "", "Start() should not return an empty address") + require.True(t, strings.HasPrefix(address, "http://"), "Address should start with http://") + require.NotNil(t, listener.httpServer, "httpServer should not be nil") + + resp, err := http.Get(address) + if err != nil { + t.Errorf("Failed to connect to listener: %v", err) + } else { + require.NoError(t, resp.Body.Close()) + } + listener.Shutdown() + + require.Nil(t, listener.httpServer, "httpServer should be nil after Shutdown()") +} + +func TestListener_Shutdown_NotStarted(t *testing.T) { + logger := zaptest.NewLogger(t) + listener := NewListener(logger) + listener.Shutdown() + require.Nil(t, listener.httpServer, "httpServer should be nil after Shutdown()") +} + +func TestListener_httpHandler(t *testing.T) { + eventBuilder := NewTestEventBuilder("test-request") + + testCases := []struct { + name string + events []Event + expectedCount int64 + }{ + { + name: "single event", + events: []Event{ + eventBuilder.PlatformStart(), + }, + expectedCount: 1, + }, + { + name: "multiple events", + events: []Event{ + eventBuilder.PlatformStart(), + eventBuilder.FunctionLog("INFO", "Received request"), + eventBuilder.FunctionLog("INFO", "Processing request"), + eventBuilder.FunctionLog("INFO", "Finished processing request"), + eventBuilder.PlatformRuntimeDone(), + }, + expectedCount: 5, + }, + { + name: "empty events array", + events: []Event{}, + expectedCount: 0, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + listener, address := setupListener(t) + submitEvents(t, address, test.events) + require.EventuallyWithT(t, func(c *assert.CollectT) { + require.Equal(c, test.expectedCount, listener.queue.Len()) + }, 1*time.Second, 50*time.Millisecond) + }) + } +} + +func TestListener_httpHandler_InvalidJSON(t *testing.T) { + withEnv(t, "AWS_SAM_LOCAL", "true") + logger := zaptest.NewLogger(t) + listener := NewListener(logger) + + address, err := listener.Start() + require.NoError(t, err, "Failed to start listener: %v", err) + defer listener.Shutdown() + + invalidJSON := []byte(`{"invalid": json}`) + resp, err := http.Post(address, "application/json", bytes.NewReader(invalidJSON)) + require.NoError(t, err, "Failed to post invalid JSON: %v", err) + require.NoError(t, resp.Body.Close(), "Failed to close response body") + + time.Sleep(50 * time.Millisecond) + require.Equal(t, listener.queue.Len(), int64(0), "Queue should be empty after invalid JSON") +} + +func TestListener_Wait_Success(t *testing.T) { + eventBuilder := NewTestEventBuilder("target-request") + + testCases := []struct { + name string + events []Event + }{ + { + name: "simple request", + events: []Event{ + eventBuilder.PlatformStart(), + eventBuilder.FunctionLog("INFO", "Received request"), + eventBuilder.FunctionLog("INFO", "Processing request"), + eventBuilder.FunctionLog("INFO", "Finished processing request"), + eventBuilder.PlatformRuntimeDone(), + }, + }, + { + name: "skips wrong request id", + events: []Event{ + NewTestEventBuilder("other-request-1").PlatformRuntimeDone(), + eventBuilder.PlatformStart(), + eventBuilder.FunctionLog("INFO", "Received request"), + NewTestEventBuilder("other-request-2").PlatformRuntimeDone(), + eventBuilder.FunctionLog("INFO", "Processing request"), + eventBuilder.FunctionLog("INFO", "Finished processing request"), + NewTestEventBuilder("other-request-3").PlatformRuntimeDone(), + eventBuilder.PlatformRuntimeDone(), + }, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + listener, address := setupListener(t) + + waitDone := make(chan error, 1) + go func() { + ctx := context.Background() + waitDone <- listener.Wait(ctx, "target-request") + }() + + assertWaitBlocks(t, waitDone, 50*time.Millisecond) + for i, event := range test.events { + submitEvents(t, address, []Event{event}) + if i < len(test.events)-1 { + assertWaitBlocks(t, waitDone, 50*time.Millisecond) + } else { + assertWaitCompletes(t, waitDone, 1*time.Second) + } + } + }) + } +} + +func TestListener_Wait_ContextCanceled(t *testing.T) { + logger := zaptest.NewLogger(t) + listener := NewListener(logger) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := listener.Wait(ctx, "any-req") + require.Equal(t, context.Canceled, err, "Context should have been canceled") +}