diff --git a/agent/agent.go b/agent/agent.go index c770544d..3100aecc 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -104,6 +104,9 @@ type Agent struct { cacheRefreshInterval time.Duration clusterCache *appstatecache.Cache + + inflightMu sync.Mutex + inflightLogs map[string]context.CancelFunc } const defaultQueueName = "default" diff --git a/agent/inbound.go b/agent/inbound.go index 2aaf72b2..19d86de9 100644 --- a/agent/inbound.go +++ b/agent/inbound.go @@ -72,6 +72,8 @@ func (a *Agent) processIncomingEvent(ev *event.Event) error { log().WithError(err).Errorf("Unable to process incoming redis event") } }() + case event.TargetContainerLog: + err = a.processIncomingContainerLogRequest(ev) default: err = fmt.Errorf("unknown event target - processIncomingEvent: %s", ev.Target()) } diff --git a/agent/log.go b/agent/log.go new file mode 100644 index 00000000..887a433b --- /dev/null +++ b/agent/log.go @@ -0,0 +1,579 @@ +// Copyright 2024 The argocd-agent Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package agent + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "strings" + "time" + + "github.com/argoproj-labs/argocd-agent/internal/event" + "github.com/argoproj-labs/argocd-agent/pkg/api/grpc/logstreamapi" + "github.com/cenkalti/backoff/v4" + "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + corev1 "k8s.io/api/core/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// processIncomingContainerLogRequest handles container log requests from Principal +func (a *Agent) processIncomingContainerLogRequest(ev *event.Event) error { + logReq, err := ev.ContainerLogRequest() + if err != nil { + return err + } + + logCtx := log().WithFields(logrus.Fields{ + "uuid": logReq.UUID, + "namespace": logReq.Namespace, + "pod": logReq.PodName, + "container": logReq.Container, + "follow": logReq.Follow, + }) + + err = a.startLogStreamIfNew(logReq, logCtx) + if err != nil { + logCtx.WithError(err).Error("Log processing failed") + return err + } + + return nil +} + +// startLogStreamIfNew manages log streaming with duplicate detection +func (a *Agent) startLogStreamIfNew(logReq *event.ContainerLogRequest, logCtx *logrus.Entry) error { + a.inflightMu.Lock() + if a.inflightLogs == nil { + a.inflightLogs = make(map[string]context.CancelFunc) + } + if _, dup := a.inflightLogs[logReq.UUID]; dup { + a.inflightMu.Unlock() + logCtx.WithField("request_uuid", logReq.UUID).Warn("duplicate log request; already streaming") + return nil + } + ctx, cancel := context.WithCancel(a.context) + a.inflightLogs[logReq.UUID] = cancel + a.inflightMu.Unlock() + + defer func() { + a.inflightMu.Lock() + delete(a.inflightLogs, logReq.UUID) + a.inflightMu.Unlock() + }() + + logCtx.WithFields(logrus.Fields{ + "uuid": logReq.UUID, + "namespace": logReq.Namespace, + "pod": logReq.PodName, + "container": logReq.Container, + "follow": logReq.Follow, + }).Info("Processing log request") + + if logReq.Follow { + // Handle live logs with early ACK + return 
a.handleLiveStreaming(ctx, logReq, logCtx) + } + + // Handle static logs with completion ACK + return a.handleStaticLogs(ctx, logReq, logCtx) +} + +// handleStaticLogs handles static log requests (follow=false) with completion ACK +func (a *Agent) handleStaticLogs(ctx context.Context, logReq *event.ContainerLogRequest, logCtx *logrus.Entry) error { + // Create gRPC stream + stream, err := a.createLogStream(ctx) + if err != nil { + return err + } + + err = stream.Send(&logstreamapi.LogStreamData{ + RequestUuid: logReq.UUID, + Data: []byte{}, + Eof: false, + }) + if err != nil { + return err + } + + // Create Kubernetes log stream + rc, err := a.createKubernetesLogStream(ctx, logReq) + if err != nil { + _ = stream.Send(&logstreamapi.LogStreamData{RequestUuid: logReq.UUID, Eof: true, Error: err.Error()}) + _, _ = stream.CloseAndRecv() + return err + } + defer rc.Close() + + err = a.streamLogsToCompletion(ctx, stream, rc, logReq, logCtx) + + if _, cerr := stream.CloseAndRecv(); cerr != nil { + err = cerr + } + + if err != nil { + // Stop immediately on intentional server stops or auth issues + switch status.Code(err) { + case codes.Canceled, codes.NotFound: + return err + case codes.Unauthenticated, codes.PermissionDenied: + logCtx.WithError(err).Warn("Auth/permission failure") + a.SetConnected(false) + return err + default: + logCtx.WithError(err).Warn("Stream error") + } + } + + return err +} + +// handleLiveStreaming handles live log requests (follow=true) with early ACK and resume capability +func (a *Agent) handleLiveStreaming(ctx context.Context, logReq *event.ContainerLogRequest, logCtx *logrus.Entry) error { + // Start streaming with resume capability in background goroutine + go func() { + defer func() { + if r := recover(); r != nil { + logCtx.WithField("panic", r).Error("Panic in live log streaming") + } + }() + + streamCtx := logCtx.WithField("mode", "live_streaming") + a.streamLogsWithResume(ctx, logReq, streamCtx) + }() + + // Return success immediately - this sends the ACK to Principal + return nil +} + +// createLogStream creates a gRPC LogStream to the principal +func (a *Agent) createLogStream(ctx context.Context) (logstreamapi.LogStreamService_StreamLogsClient, error) { + conn := a.remote.Conn() + if conn == nil { + return nil, fmt.Errorf("gRPC connection is nil") + } + + client := logstreamapi.NewLogStreamServiceClient(conn) + return client.StreamLogs(ctx) +} + +// createKubernetesLogStream creates a Kubernetes log stream +func (a *Agent) createKubernetesLogStream(ctx context.Context, logReq *event.ContainerLogRequest) (io.ReadCloser, error) { + logOptions := &corev1.PodLogOptions{ + Container: logReq.Container, + Follow: logReq.Follow, + Timestamps: true, + Previous: logReq.Previous, + InsecureSkipTLSVerifyBackend: logReq.InsecureSkipTLSVerifyBackend, + TailLines: logReq.TailLines, + SinceSeconds: logReq.SinceSeconds, + LimitBytes: logReq.LimitBytes, + } + + // Handle SinceTime if provided + if logReq.SinceTime != "" { + if sinceTime, err := time.Parse(time.RFC3339, logReq.SinceTime); err == nil { + mt := v1.NewTime(sinceTime) + logOptions.SinceTime = &mt + } + } + + request := a.kubeClient.Clientset.CoreV1().Pods(logReq.Namespace).GetLogs(logReq.PodName, logOptions) + return request.Stream(ctx) +} + +// streamLogsToCompletion streams ALL available (static) logs from k8s to the principal. +// It flushes raw data without processing, using chunk size (64KB) or time-based flushing. 
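+// The frame sequence it emits, as implemented below, is:
+//
+//	stream.Send(&logstreamapi.LogStreamData{RequestUuid: logReq.UUID, Data: chunk}) // one frame per flush
+//	stream.Send(&logstreamapi.LogStreamData{RequestUuid: logReq.UUID, Eof: true})  // terminal frame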
+func (a *Agent) streamLogsToCompletion( + ctx context.Context, + stream logstreamapi.LogStreamService_StreamLogsClient, + rc io.ReadCloser, + logReq *event.ContainerLogRequest, + logCtx *logrus.Entry, +) error { + const ( + chunkMax = 64 * 1024 + flushEvery = 50 * time.Millisecond + ) + + defer rc.Close() + + br := bufio.NewReader(rc) + ticker := time.NewTicker(flushEvery) + defer ticker.Stop() + buf := make([]byte, 0, chunkMax) + readBuf := make([]byte, chunkMax) + + flush := func(reason string) error { + if len(buf) == 0 { + return nil + } + if err := stream.Send(&logstreamapi.LogStreamData{ + RequestUuid: logReq.UUID, + Data: buf, // bytes field recommended in proto + }); err != nil { + logCtx.WithFields(logrus.Fields{ + "reason": reason, "bytes": len(buf), "error": err.Error(), + }).Warn("Send failed") + return err + } + buf = buf[:0] + return nil + } + + for { + // Respect cancellations (flush pending bytes first) + select { + case <-ctx.Done(): + _ = flush("ctx_done") + return ctx.Err() + case <-stream.Context().Done(): + return stream.Context().Err() + default: + } + + n, err := br.Read(readBuf) + + // First, handle any bytes we got (even if err == io.EOF) + if n > 0 { + remain := readBuf[:n] + for len(remain) > 0 { + space := chunkMax - len(buf) + if space == 0 { + if err := flush("full"); err != nil { + return err + } + space = chunkMax + } + toCopy := len(remain) + if toCopy > space { + toCopy = space + } + buf = append(buf, remain[:toCopy]...) + remain = remain[toCopy:] + if len(buf) == chunkMax { + if err := flush("full"); err != nil { + return err + } + } + } + } + + if err != nil { + if errors.Is(err, io.EOF) { + _ = flush("eof") + // Clean end of this rc; final flush below. + break + } + // Real error: flush what we have, optionally notify, then return + _ = flush("error_before_return") + logCtx.WithError(err).Error("Error reading log stream") + _ = stream.Send(&logstreamapi.LogStreamData{ + RequestUuid: logReq.UUID, + Error: "log stream read failed", + }) + return err + } + + // Timer-based flush (fires only after a read iteration) + select { + case <-ticker.C: + if err := flush("timer"); err != nil { + return err + } + default: + } + } + + // Final flush on normal EOF + if err := flush("final"); err != nil { + return err + } + + _ = stream.Send(&logstreamapi.LogStreamData{ + RequestUuid: logReq.UUID, + Eof: true, + }) + + return nil +} + +func (a *Agent) streamLogsWithResume(ctx context.Context, logReq *event.ContainerLogRequest, logCtx *logrus.Entry) { + const ( + waitForReconnect = 10 * time.Second // how long we poll IsConnected() after Unauthenticated + pollEvery = 1 * time.Second + ) + + var lastTimestamp *time.Time + + // Configure exponential backoff with jitter + b := backoff.NewExponentialBackOff() + b.InitialInterval = 200 * time.Millisecond + b.Multiplier = 2.0 + b.MaxInterval = 5 * time.Second + b.MaxElapsedTime = 30 * time.Second + bo := backoff.WithContext(b, ctx) + + for { + // Build resume request + resumeReq := *logReq + if lastTimestamp != nil { + t := lastTimestamp.Add(-100 * time.Millisecond) + resumeReq.SinceTime = t.Format(time.RFC3339) + } + + // One attempt to create + stream + attempt := func() (err error) { + stream, err := a.createLogStream(ctx) + if err != nil { + return err + } + // Send empty data to indicate the start of the stream + // Used for health checks + err = stream.Send(&logstreamapi.LogStreamData{ + RequestUuid: logReq.UUID, + Data: []byte{}, + Eof: false, + }) + if err != nil { + _, err = stream.CloseAndRecv() + return err + } + + 
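+			// Open the Kubernetes log stream for this attempt; on failure, report
+			// the error to the principal in a terminal Eof frame before giving up.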
+			rc, err := a.createKubernetesLogStream(ctx, &resumeReq)
+			if err != nil {
+				_ = stream.Send(&logstreamapi.LogStreamData{RequestUuid: logReq.UUID, Eof: true, Error: err.Error()})
+				_, _ = stream.CloseAndRecv()
+				return err
+			}
+			defer rc.Close()
+
+			newLast, runErr := a.streamLogs(ctx, stream, rc, &resumeReq, logCtx)
+			if newLast != nil {
+				lastTimestamp = newLast
+			}
+			if _, cerr := stream.CloseAndRecv(); cerr != nil {
+				runErr = cerr
+			}
+
+			return runErr
+		}
+
+		err := attempt()
+		if err == nil {
+			return
+		}
+
+		switch status.Code(err) {
+		case codes.Canceled, codes.NotFound:
+			// Intentional stop (UI gone / request not found) -> do not retry
+			logCtx.WithError(err).Info("Log stream ended")
+			return
+
+		case codes.Unauthenticated, codes.PermissionDenied:
+			// Do NOT backoff-retry; instead block waiting for connector to become connected.
+			logCtx.WithError(err).Warn("Auth/permission failure")
+			a.SetConnected(false)
+
+			waitCtx, cancel := context.WithTimeout(ctx, waitForReconnect)
+			t := time.NewTicker(pollEvery)
+
+			reconnected := false
+			for !reconnected {
+				select {
+				case <-waitCtx.Done():
+					cancel()
+					t.Stop()
+					return
+				case <-t.C:
+					if a.IsConnected() {
+						reconnected = true
+					}
+				}
+			}
+			cancel()
+			t.Stop()
+			b.Reset()
+			continue
+		default:
+			// Transient or unknown -> exponential backoff
+			d := bo.NextBackOff()
+			if d == backoff.Stop {
+				logCtx.WithError(err).Error("Backoff stopped")
+				return
+			}
+			select {
+			case <-ctx.Done():
+				return
+			case <-time.After(d):
+			}
+		}
+	}
+}
+
+// streamLogs streams logs until the context is done, returning the last seen timestamp.
+// It flushes raw data, using chunk size (64KB) or time-based flushing.
+// Timestamps are extracted from raw lines for resume capability.
+func (a *Agent) streamLogs(ctx context.Context, stream logstreamapi.LogStreamService_StreamLogsClient, rc io.ReadCloser, logReq *event.ContainerLogRequest, logCtx *logrus.Entry) (*time.Time, error) {
+	const (
+		chunkMax   = 64 * 1024 // 64KB chunks
+		flushEvery = 50 * time.Millisecond
+	)
+
+	br := bufio.NewReader(rc)
+	ticker := time.NewTicker(flushEvery)
+	defer ticker.Stop()
+	var lastTimestamp *time.Time
+	// Aggregate buffer that we cap at chunkMax
+	buf := make([]byte, 0, chunkMax)
+	// Reusable read buffer
+	readBuf := make([]byte, chunkMax)
+
+	// Simple flush function for raw data
+	flush := func(reason string) error {
+		if len(buf) == 0 {
+			return nil
+		}
+		if err := stream.Send(&logstreamapi.LogStreamData{
+			RequestUuid: logReq.UUID,
+			Data:        buf,
+		}); err != nil {
+			logCtx.WithFields(logrus.Fields{
+				"reason": reason,
+				"bytes":  len(buf),
+				"error":  err.Error(),
+			}).Warn("Send failed")
+			return err
+		}
+		buf = buf[:0] // Reset buffer
+		return nil
+	}
+
+	for {
+		select {
+		case <-ctx.Done():
+			_ = flush("ctx_done")
+			return lastTimestamp, ctx.Err()
+		case <-stream.Context().Done():
+			return lastTimestamp, stream.Context().Err()
+		default:
+		}
+
+		// Set a read deadline where supported to avoid blocking forever
+		if tcpConn, ok := rc.(interface{ SetReadDeadline(time.Time) error }); ok {
+			_ = tcpConn.SetReadDeadline(time.Now().Add(1 * time.Second))
+		}
+
+		n, err := br.Read(readBuf)
+		if n > 0 {
+			if logReq.Timestamps {
+				b := readBuf[:n]
+				if end := bytes.LastIndexByte(b, '\n'); end >= 0 {
+					start := bytes.LastIndexByte(b[:end], '\n') + 1
+					line := b[start:end]
+					if len(line) > 0 && line[len(line)-1] == '\r' {
+						line = line[:len(line)-1]
+					}
+					if ts := extractTimestamp(string(line)); ts != nil {
+						lastTimestamp = ts
+					}
+				}
+			}
+			data := readBuf[:n]
+			// Append to capped buffer, 
splitting if needed + remain := data + for len(remain) > 0 { + space := chunkMax - len(buf) + if space == 0 { + if err := flush("full"); err != nil { + return lastTimestamp, err + } + space = chunkMax + } + toCopy := len(remain) + if toCopy > space { + toCopy = space + } + buf = append(buf, remain[:toCopy]...) + remain = remain[toCopy:] + if len(buf) == chunkMax { + if err := flush("full"); err != nil { + return lastTimestamp, err + } + } + } + } + // Now handle the read error/result + if err != nil { + logCtx.WithError(err).Error("Error reading log stream") + if errors.Is(err, io.EOF) { + // Final flush on EOF + if fErr := flush("eof"); fErr != nil { + return lastTimestamp, fErr + } + // Don't propagate EOF to principal for follow=true - it may be temporary due to: + // - Pod restart/termination + return lastTimestamp, err + } + _ = stream.Send(&logstreamapi.LogStreamData{ + RequestUuid: logReq.UUID, + Error: err.Error(), + }) + return lastTimestamp, err + } + + // Time-based flush to keep UI moving + select { + case <-ticker.C: + if err := flush("timer"); err != nil { + return lastTimestamp, err + } + default: + } + } +} + +// extractTimestamp extracts timestamp from a log line for resume capability +func extractTimestamp(line string) *time.Time { + if len(line) < 20 { // "2006-01-02T15:04:05Z" is 20 chars + return nil + } + // Grab the first token (up to whitespace). k8s puts a space after the timestamp. + space := strings.IndexAny(line, " \t") + if space == -1 { + // Fall back: try whole line (cheap fast-fail) + space = len(line) + } + // Guard against absurdly long "tokens" + const maxTSLen = 40 // a tad higher than needed; RFC3339Nano+offset is 35 + if space > maxTSLen { + return nil + } + token := line[:space] + + // Try the common RFC3339 flavors (covers with/without fractional seconds and offsets) + if ts, err := time.Parse(time.RFC3339Nano, token); err == nil { + return &ts + } + if ts, err := time.Parse(time.RFC3339, token); err == nil { + return &ts + } + return nil +} diff --git a/agent/log_test.go b/agent/log_test.go new file mode 100644 index 00000000..a6164a23 --- /dev/null +++ b/agent/log_test.go @@ -0,0 +1,374 @@ +// Copyright 2024 The argocd-agent Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package agent + +import ( + "context" + "io" + "strings" + "sync" + "testing" + "time" + + "github.com/argoproj-labs/argocd-agent/internal/event" + "github.com/argoproj-labs/argocd-agent/pkg/api/grpc/logstreamapi" + "github.com/argoproj-labs/argocd-agent/principal/apis/logstream/mock" + "github.com/argoproj-labs/argocd-agent/test/fake/kube" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// MockLogStreamClient wraps the existing MockLogStreamServer for client-side testing +type MockLogStreamClient struct { + *mock.MockLogStreamServer + sentData []*logstreamapi.LogStreamData + sendFunc func(data *logstreamapi.LogStreamData) error + mu sync.RWMutex +} + +func NewMockLogStreamClient(ctx context.Context) *MockLogStreamClient { + return &MockLogStreamClient{ + MockLogStreamServer: mock.NewMockLogStreamServer(ctx), + sentData: make([]*logstreamapi.LogStreamData, 0), + sendFunc: func(data *logstreamapi.LogStreamData) error { + return nil + }, + } +} + +func (m *MockLogStreamClient) Send(data *logstreamapi.LogStreamData) error { + // For client-side testing, we track sent data + m.mu.Lock() + m.sentData = append(m.sentData, data) + m.mu.Unlock() + return m.sendFunc(data) +} + +func (m *MockLogStreamClient) CloseSend() error { + return nil +} + +func (m *MockLogStreamClient) Header() (metadata.MD, error) { + return metadata.New(nil), nil +} + +func (m *MockLogStreamClient) Trailer() metadata.MD { + return metadata.New(nil) +} + +func (m *MockLogStreamClient) SendMsg(msg interface{}) error { + return nil +} + +func (m *MockLogStreamClient) RecvMsg(msg interface{}) error { + return nil +} + +func (m *MockLogStreamClient) CloseAndRecv() (*logstreamapi.LogStreamResponse, error) { + // Simulate closing and receiving a response + m.mu.RLock() + linesReceived := len(m.sentData) + m.mu.RUnlock() + return &logstreamapi.LogStreamResponse{ + RequestUuid: "test-uuid", + Status: 200, + LinesReceived: int32(linesReceived), + }, nil +} + +func (m *MockLogStreamClient) GetSentData() []*logstreamapi.LogStreamData { + m.mu.RLock() + defer m.mu.RUnlock() + // Return a copy to avoid race conditions + result := make([]*logstreamapi.LogStreamData, len(m.sentData)) + copy(result, m.sentData) + return result +} + +func (m *MockLogStreamClient) Reset() { + m.mu.Lock() + m.sentData = make([]*logstreamapi.LogStreamData, 0) + m.mu.Unlock() +} + +func (m *MockLogStreamClient) SetSendFunc(fn func(data *logstreamapi.LogStreamData) error) { + m.sendFunc = fn +} + +// MockReadCloser implements io.ReadCloser for testing +type MockReadCloser struct { + io.Reader + closed bool +} + +func (m *MockReadCloser) Close() error { + m.closed = true + return nil +} + +func (m *MockReadCloser) IsClosed() bool { + return m.closed +} + +// Test helper functions +func createTestAgent() *Agent { + ctx, cancel := context.WithCancel(context.Background()) + agent := &Agent{ + context: ctx, + cancelFn: cancel, + inflightLogs: make(map[string]context.CancelFunc), + inflightMu: sync.Mutex{}, + } + return agent +} + +func createTestAgentWithKubeClient() *Agent { + ctx, cancel := context.WithCancel(context.Background()) + kubeClient := kube.NewKubernetesFakeClientWithResources() + agent := &Agent{ + context: ctx, + cancelFn: cancel, + kubeClient: kubeClient, + inflightLogs: make(map[string]context.CancelFunc), + inflightMu: sync.Mutex{}, + } + return agent +} + +func 
createTestLogRequest(follow bool) *event.ContainerLogRequest { + return &event.ContainerLogRequest{ + UUID: "test-uuid-123", + Namespace: "test-namespace", + PodName: "test-pod", + Container: "test-container", + Follow: follow, + Timestamps: true, + } +} + +// Test extractTimestamp function +func TestExtractTimestamp(t *testing.T) { + tests := []struct { + name string + input string + expected *time.Time + }{ + { + name: "RFC3339 with nanoseconds", + input: "2023-12-07T10:30:45.123456789Z some log message", + expected: timePtr(time.Date(2023, 12, 7, 10, 30, 45, 123456789, time.UTC)), + }, + { + name: "RFC3339 without nanoseconds", + input: "2023-12-07T10:30:45Z some log message", + expected: timePtr(time.Date(2023, 12, 7, 10, 30, 45, 0, time.UTC)), + }, + { + name: "RFC3339 with milliseconds", + input: "2023-12-07T10:30:45.123Z some log message", + expected: timePtr(time.Date(2023, 12, 7, 10, 30, 45, 123000000, time.UTC)), + }, + { + name: "no timestamp", + input: "some log message without timestamp", + expected: nil, + }, + { + name: "timestamp with tab separator", + input: "2023-12-07T10:30:45Z\tsome log message", + expected: timePtr(time.Date(2023, 12, 7, 10, 30, 45, 0, time.UTC)), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractTimestamp(tt.input) + if tt.expected == nil { + assert.Nil(t, result) + } else { + require.NotNil(t, result) + assert.Equal(t, *tt.expected, *result) + } + }) + } +} + +// Test createKubernetesLogStream +func TestCreateKubernetesLogStream(t *testing.T) { + ctx := context.Background() + logReq := createTestLogRequest(false) + + t.Run("Test createKubernetesLogStream with fake kube client and pod", func(t *testing.T) { + agent := createTestAgentWithKubeClient() + // Create a test pod + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: logReq.PodName, + Namespace: logReq.Namespace, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: logReq.Container, + }, + }, + }, + } + // Create the pod in the fake client + _, err := agent.kubeClient.Clientset.CoreV1().Pods(logReq.Namespace).Create(ctx, pod, metav1.CreateOptions{}) + require.NoError(t, err) + + // Now test the log stream creation + rc, err := agent.createKubernetesLogStream(ctx, logReq) + // The fake client actually supports streaming and returns a valid ReadCloser + assert.NoError(t, err) + assert.NotNil(t, rc) + + // Clean up + if rc != nil { + rc.Close() + } + }) + + t.Run("Test createKubernetesLogStream with non-existent pod", func(t *testing.T) { + agent := createTestAgentWithKubeClient() + // Test with a non-existent pod + logReqNotFound := createTestLogRequest(false) + logReqNotFound.PodName = "non-existent-pod" + _, err := agent.createKubernetesLogStream(ctx, logReqNotFound) + assert.NoError(t, err) + }) +} + +// Test startLogStreamIfNew +func TestStartLogStreamIfNew(t *testing.T) { + logCtx := logrus.NewEntry(logrus.New()) + + t.Run("duplicate request", func(t *testing.T) { + logReq := createTestLogRequest(false) + agent := createTestAgent() + // Add a duplicate request + agent.inflightMu.Lock() + agent.inflightLogs[logReq.UUID] = func() {} + agent.inflightMu.Unlock() + + err := agent.startLogStreamIfNew(logReq, logCtx) + assert.NoError(t, err) // Should return early for duplicate + }) + + t.Run("new request", func(t *testing.T) { + logReq := createTestLogRequest(false) + agent := createTestAgent() + // This will panic due to missing dependencies, but we can check if new request is processed + assert.Panics(t, func() { + 
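+			// The test agent has no remote connection wired up, so the
+			// attempt to create the gRPC log stream panics on the nil remote.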
agent.startLogStreamIfNew(logReq, logCtx) + }) + + }) +} + +// Test streamLogsToCompletion +func TestStreamLogsToCompletion(t *testing.T) { + agent := createTestAgentWithKubeClient() + ctx := context.Background() + logReq := createTestLogRequest(false) + logCtx := logrus.NewEntry(logrus.New()) + + // Create a test reader with some log data + testData := "2023-12-07T10:30:45Z line 1\n2023-12-07T10:30:46Z line 2\n" + reader := &MockReadCloser{Reader: strings.NewReader(testData)} + + t.Run("successful streaming", func(t *testing.T) { + mockStream := NewMockLogStreamClient(ctx) + + err := agent.streamLogsToCompletion(ctx, mockStream, reader, logReq, logCtx) + assert.NoError(t, err) + + // Verify that data was sent + sentData := mockStream.GetSentData() + assert.GreaterOrEqual(t, len(sentData), 2) // At least 2 data messages + EOF + + // Check that the last message is EOF + lastMessage := sentData[len(sentData)-1] + assert.True(t, lastMessage.Eof) + }) + + t.Run("context cancellation", func(t *testing.T) { + cancelCtx, cancel := context.WithCancel(ctx) + cancel() // Cancel immediately + + mockStream := NewMockLogStreamClient(cancelCtx) + reader := &MockReadCloser{Reader: strings.NewReader(testData)} + err := agent.streamLogsToCompletion(cancelCtx, mockStream, reader, logReq, logCtx) + assert.Error(t, err) + assert.Equal(t, context.Canceled, err) + }) +} + +// Test streamLogs +func TestStreamLogs(t *testing.T) { + agent := createTestAgentWithKubeClient() + ctx := context.Background() + logReq := createTestLogRequest(true) + logCtx := logrus.NewEntry(logrus.New()) + + t.Run("successful streaming with data verification", func(t *testing.T) { + // Create a context that we can cancel + testCtx, cancel := context.WithCancel(ctx) + defer cancel() + + mockStream := NewMockLogStreamClient(testCtx) + testData := "2025-12-07T10:30:45Z line 1\n2025-12-07T10:30:46Z line 2\n" + reader := &MockReadCloser{Reader: strings.NewReader(testData)} + logReq.Timestamps = true + + // Start streaming in a goroutine + var lastTimestamp *time.Time + var streamErr error + done := make(chan struct{}) + + go func() { + defer close(done) + lastTimestamp, streamErr = agent.streamLogs(testCtx, mockStream, reader, logReq, logCtx) + }() + + // Wait for timer to fire (50ms flush interval + buffer) + time.Sleep(100 * time.Millisecond) + // Check if data was sent before cancelling + sentData := mockStream.GetSentData() + // Cancel context to stop the streaming (streamLogs will wait forever on EOF) + cancel() + <-done // Wait for goroutine to finish + + // Verify data was sent due to timer flush + assert.Greater(t, len(sentData), 0, "Data should be sent due to timer flush") + + // Verify the function ended with EOF (expected for streamLogs with finite data) + assert.Error(t, streamErr, "streamLogs should end with error") + assert.Equal(t, io.EOF, streamErr, "Should end with EOF, not timeout") + + // Verify timestamp extraction worked + assert.NotNil(t, lastTimestamp, "Timestamp should be extracted") + assert.Equal(t, 2025, lastTimestamp.Year()) + }) +} + +// Helper function to create time pointer +func timePtr(t time.Time) *time.Time { + return &t +} diff --git a/go.mod b/go.mod index 39aed2cf..81ca69c9 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/alicebob/miniredis/v2 v2.35.0 github.com/argoproj/argo-cd/v3 v3.1.8 github.com/argoproj/gitops-engine v0.7.1-0.20250905160054-e48120133eec + github.com/cenkalti/backoff/v4 v4.3.0 github.com/cloudevents/sdk-go/binding/format/protobuf/v2 v2.16.2 
github.com/cloudevents/sdk-go/v2 v2.16.2 github.com/go-redis/cache/v9 v9.0.0 diff --git a/go.sum b/go.sum index 03cc951d..8c89eacb 100644 --- a/go.sum +++ b/go.sum @@ -65,6 +65,8 @@ github.com/casbin/casbin/v2 v2.107.0/go.mod h1:Ee33aqGrmES+GNL17L0h9X28wXuo829wn github.com/casbin/govaluate v1.3.0/go.mod h1:G/UnbIjZk/0uMNaLwZZmFQrR72tYRZWQkO70si/iR7A= github.com/casbin/govaluate v1.7.0 h1:Es2j2K2jv7br+QHJhxKcdoOa4vND0g0TqsO6rJeqJbA= github.com/casbin/govaluate v1.7.0/go.mod h1:G/UnbIjZk/0uMNaLwZZmFQrR72tYRZWQkO70si/iR7A= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= diff --git a/internal/event/event.go b/internal/event/event.go index e0c9d3bd..c8d5767d 100644 --- a/internal/event/event.go +++ b/internal/event/event.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "net/http" + "strconv" "sync" "time" @@ -80,6 +81,7 @@ const ( TargetResourceResync EventTarget = "resourceResync" TargetClusterCacheInfoUpdate EventTarget = "clusterCacheInfoUpdate" TargetRepository EventTarget = "repository" + TargetContainerLog EventTarget = "containerlog" ) const ( @@ -406,9 +408,9 @@ func (evs EventSource) NewResourceRequestEvent(gvr v1.GroupVersionResource, name cev.SetSource(evs.source) cev.SetSpecVersion(cloudEventSpecVersion) cev.SetType(method) - cev.SetDataSchema(TargetResource.String()) cev.SetExtension(resourceID, reqUUID) cev.SetExtension(eventID, reqUUID) + cev.SetDataSchema(TargetResource.String()) err := cev.SetData(cloudevents.ApplicationJSON, rr) return &cev, err } @@ -596,6 +598,8 @@ func Target(raw *cloudevents.Event) EventTarget { return TargetRedis case TargetClusterCacheInfoUpdate.String(): return TargetClusterCacheInfoUpdate + case TargetContainerLog.String(): + return TargetContainerLog } return "" } @@ -950,3 +954,91 @@ func (ewm *EventWritersMap) Remove(agentName string) { delete(ewm.eventWriters, agentName) } + +type ContainerLogRequest struct { + // UUID for request/response correlation + UUID string `json:"uuid"` + Namespace string `json:"namespace"` + PodName string `json:"podName"` + Container string `json:"container,omitempty"` + Follow bool `json:"follow,omitempty"` + TailLines *int64 `json:"tailLines,omitempty"` + SinceSeconds *int64 `json:"sinceSeconds,omitempty"` + SinceTime string `json:"sinceTime,omitempty"` + Timestamps bool `json:"timestamps,omitempty"` + Previous bool `json:"previous,omitempty"` + InsecureSkipTLSVerifyBackend bool `json:"insecureSkipTLSVerifyBackend,omitempty"` + LimitBytes *int64 `json:"limitBytes,omitempty"` +} + +// NewLogRequestEvent creates a cloud event for requesting logs +func (evs EventSource) NewLogRequestEvent(namespace, podName, method string, params map[string]string) (*cloudevents.Event, error) { + reqUUID := uuid.NewString() + + // Parse log-specific parameters + logReq := &ContainerLogRequest{ + UUID: reqUUID, + Namespace: namespace, + PodName: podName, + } + + if container, ok := params["container"]; ok { + logReq.Container = container + } + + // Parse query parameters + if follow := params["follow"]; follow == "true" { + logReq.Follow = true + } + + if tailLines := params["tailLines"]; tailLines != "" { + if lines, err 
:= strconv.ParseInt(tailLines, 10, 64); err == nil { + logReq.TailLines = &lines + } + } + + if sinceSeconds := params["sinceSeconds"]; sinceSeconds != "" { + if seconds, err := strconv.ParseInt(sinceSeconds, 10, 64); err == nil { + logReq.SinceSeconds = &seconds + } + } + + if sinceTime := params["sinceTime"]; sinceTime != "" { + logReq.SinceTime = sinceTime + } + + if timestamps := params["timestamps"]; timestamps == "true" { + logReq.Timestamps = true + } + + if previous := params["previous"]; previous == "true" { + logReq.Previous = true + } + + // Parse additional K8s logs API parameters + if insecureSkipTLS := params["insecureSkipTLSVerifyBackend"]; insecureSkipTLS == "true" { + logReq.InsecureSkipTLSVerifyBackend = true + } + + if limitBytes := params["limitBytes"]; limitBytes != "" { + if bytes, err := strconv.ParseInt(limitBytes, 10, 64); err == nil { + logReq.LimitBytes = &bytes + } + } + cev := cloudevents.NewEvent() + cev.SetSource(evs.source) + cev.SetSpecVersion(cloudEventSpecVersion) + cev.SetType(method) // HTTP method + cev.SetDataSchema(TargetContainerLog.String()) + cev.SetExtension(resourceID, reqUUID) + cev.SetExtension(eventID, reqUUID) + err := cev.SetData(cloudevents.ApplicationJSON, logReq) + return &cev, err +} + +// ContainerLogRequest extracts ContainerLogRequest data from event +func (ev *Event) ContainerLogRequest() (*ContainerLogRequest, error) { + logReq := &ContainerLogRequest{} + err := ev.event.DataAs(logReq) + return logReq, err +} diff --git a/pkg/api/grpc/logstreamapi/logstream.pb.go b/pkg/api/grpc/logstreamapi/logstream.pb.go new file mode 100644 index 00000000..328a183a --- /dev/null +++ b/pkg/api/grpc/logstreamapi/logstream.pb.go @@ -0,0 +1,281 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.28.1 +// protoc v4.25.3 +// source: logstream.proto + +package logstreamapi + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// LogStreamData represents a line (or chunk) of log data sent from the agent +type LogStreamData struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Unique identifier matching the original request/event UUID + RequestUuid string `protobuf:"bytes,1,opt,name=request_uuid,json=requestUuid,proto3" json:"request_uuid,omitempty"` + // Log data content. 
+ Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + // End of stream indicator + Eof bool `protobuf:"varint,3,opt,name=eof,proto3" json:"eof,omitempty"` + // Optional error message + Error string `protobuf:"bytes,4,opt,name=error,proto3" json:"error,omitempty"` +} + +func (x *LogStreamData) Reset() { + *x = LogStreamData{} + if protoimpl.UnsafeEnabled { + mi := &file_logstream_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *LogStreamData) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LogStreamData) ProtoMessage() {} + +func (x *LogStreamData) ProtoReflect() protoreflect.Message { + mi := &file_logstream_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LogStreamData.ProtoReflect.Descriptor instead. +func (*LogStreamData) Descriptor() ([]byte, []int) { + return file_logstream_proto_rawDescGZIP(), []int{0} +} + +func (x *LogStreamData) GetRequestUuid() string { + if x != nil { + return x.RequestUuid + } + return "" +} + +func (x *LogStreamData) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +func (x *LogStreamData) GetEof() bool { + if x != nil { + return x.Eof + } + return false +} + +func (x *LogStreamData) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +// LogStreamResponse is returned by principal when the agent closes the stream +type LogStreamResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + RequestUuid string `protobuf:"bytes,1,opt,name=request_uuid,json=requestUuid,proto3" json:"request_uuid,omitempty"` + Status int32 `protobuf:"varint,2,opt,name=status,proto3" json:"status,omitempty"` // 200 on success + Error string `protobuf:"bytes,3,opt,name=error,proto3" json:"error,omitempty"` + LinesReceived int32 `protobuf:"varint,4,opt,name=lines_received,json=linesReceived,proto3" json:"lines_received,omitempty"` +} + +func (x *LogStreamResponse) Reset() { + *x = LogStreamResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_logstream_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *LogStreamResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LogStreamResponse) ProtoMessage() {} + +func (x *LogStreamResponse) ProtoReflect() protoreflect.Message { + mi := &file_logstream_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LogStreamResponse.ProtoReflect.Descriptor instead. 
+func (*LogStreamResponse) Descriptor() ([]byte, []int) { + return file_logstream_proto_rawDescGZIP(), []int{1} +} + +func (x *LogStreamResponse) GetRequestUuid() string { + if x != nil { + return x.RequestUuid + } + return "" +} + +func (x *LogStreamResponse) GetStatus() int32 { + if x != nil { + return x.Status + } + return 0 +} + +func (x *LogStreamResponse) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +func (x *LogStreamResponse) GetLinesReceived() int32 { + if x != nil { + return x.LinesReceived + } + return 0 +} + +var File_logstream_proto protoreflect.FileDescriptor + +var file_logstream_proto_rawDesc = []byte{ + 0x0a, 0x0f, 0x6c, 0x6f, 0x67, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x12, 0x0c, 0x6c, 0x6f, 0x67, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x61, 0x70, 0x69, 0x22, + 0x6e, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, + 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x75, 0x75, 0x69, 0x64, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x55, + 0x75, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6f, 0x66, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x65, 0x6f, 0x66, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, + 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, + 0x8b, 0x01, 0x0a, 0x11, 0x4c, 0x6f, 0x67, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x5f, 0x75, 0x75, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x55, 0x75, 0x69, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x25, 0x0a, 0x0e, 0x6c, 0x69, 0x6e, 0x65, 0x73, 0x5f, + 0x72, 0x65, 0x63, 0x65, 0x69, 0x76, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, + 0x6c, 0x69, 0x6e, 0x65, 0x73, 0x52, 0x65, 0x63, 0x65, 0x69, 0x76, 0x65, 0x64, 0x32, 0x60, 0x0a, + 0x10, 0x4c, 0x6f, 0x67, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x12, 0x4c, 0x0a, 0x0a, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x4c, 0x6f, 0x67, 0x73, 0x12, + 0x1b, 0x2e, 0x6c, 0x6f, 0x67, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x61, 0x70, 0x69, 0x2e, 0x4c, + 0x6f, 0x67, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x1a, 0x1f, 0x2e, 0x6c, + 0x6f, 0x67, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x61, 0x70, 0x69, 0x2e, 0x4c, 0x6f, 0x67, 0x53, + 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, 0x42, + 0x41, 0x5a, 0x3f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x72, + 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x6a, 0x2d, 0x6c, 0x61, 0x62, 0x73, 0x2f, 0x61, 0x72, 0x67, 0x6f, + 0x63, 0x64, 0x2d, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x61, 0x70, 0x69, + 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x6c, 0x6f, 0x67, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x61, + 0x70, 0x69, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_logstream_proto_rawDescOnce sync.Once + file_logstream_proto_rawDescData = 
file_logstream_proto_rawDesc +) + +func file_logstream_proto_rawDescGZIP() []byte { + file_logstream_proto_rawDescOnce.Do(func() { + file_logstream_proto_rawDescData = protoimpl.X.CompressGZIP(file_logstream_proto_rawDescData) + }) + return file_logstream_proto_rawDescData +} + +var file_logstream_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_logstream_proto_goTypes = []interface{}{ + (*LogStreamData)(nil), // 0: logstreamapi.LogStreamData + (*LogStreamResponse)(nil), // 1: logstreamapi.LogStreamResponse +} +var file_logstream_proto_depIdxs = []int32{ + 0, // 0: logstreamapi.LogStreamService.StreamLogs:input_type -> logstreamapi.LogStreamData + 1, // 1: logstreamapi.LogStreamService.StreamLogs:output_type -> logstreamapi.LogStreamResponse + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_logstream_proto_init() } +func file_logstream_proto_init() { + if File_logstream_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_logstream_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*LogStreamData); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_logstream_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*LogStreamResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_logstream_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_logstream_proto_goTypes, + DependencyIndexes: file_logstream_proto_depIdxs, + MessageInfos: file_logstream_proto_msgTypes, + }.Build() + File_logstream_proto = out.File + file_logstream_proto_rawDesc = nil + file_logstream_proto_goTypes = nil + file_logstream_proto_depIdxs = nil +} diff --git a/pkg/api/grpc/logstreamapi/logstream_grpc.pb.go b/pkg/api/grpc/logstreamapi/logstream_grpc.pb.go new file mode 100644 index 00000000..430ee1f5 --- /dev/null +++ b/pkg/api/grpc/logstreamapi/logstream_grpc.pb.go @@ -0,0 +1,141 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.2.0 +// - protoc v4.25.3 +// source: logstream.proto + +package logstreamapi + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// LogStreamServiceClient is the client API for LogStreamService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 
+type LogStreamServiceClient interface { + // Agent establishes a client-streaming RPC and sends log data to principal + StreamLogs(ctx context.Context, opts ...grpc.CallOption) (LogStreamService_StreamLogsClient, error) +} + +type logStreamServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewLogStreamServiceClient(cc grpc.ClientConnInterface) LogStreamServiceClient { + return &logStreamServiceClient{cc} +} + +func (c *logStreamServiceClient) StreamLogs(ctx context.Context, opts ...grpc.CallOption) (LogStreamService_StreamLogsClient, error) { + stream, err := c.cc.NewStream(ctx, &LogStreamService_ServiceDesc.Streams[0], "/logstreamapi.LogStreamService/StreamLogs", opts...) + if err != nil { + return nil, err + } + x := &logStreamServiceStreamLogsClient{stream} + return x, nil +} + +type LogStreamService_StreamLogsClient interface { + Send(*LogStreamData) error + CloseAndRecv() (*LogStreamResponse, error) + grpc.ClientStream +} + +type logStreamServiceStreamLogsClient struct { + grpc.ClientStream +} + +func (x *logStreamServiceStreamLogsClient) Send(m *LogStreamData) error { + return x.ClientStream.SendMsg(m) +} + +func (x *logStreamServiceStreamLogsClient) CloseAndRecv() (*LogStreamResponse, error) { + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + m := new(LogStreamResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// LogStreamServiceServer is the server API for LogStreamService service. +// All implementations must embed UnimplementedLogStreamServiceServer +// for forward compatibility +type LogStreamServiceServer interface { + // Agent establishes a client-streaming RPC and sends log data to principal + StreamLogs(LogStreamService_StreamLogsServer) error + mustEmbedUnimplementedLogStreamServiceServer() +} + +// UnimplementedLogStreamServiceServer must be embedded to have forward compatible implementations. +type UnimplementedLogStreamServiceServer struct { +} + +func (UnimplementedLogStreamServiceServer) StreamLogs(LogStreamService_StreamLogsServer) error { + return status.Errorf(codes.Unimplemented, "method StreamLogs not implemented") +} +func (UnimplementedLogStreamServiceServer) mustEmbedUnimplementedLogStreamServiceServer() {} + +// UnsafeLogStreamServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to LogStreamServiceServer will +// result in compilation errors. 
+type UnsafeLogStreamServiceServer interface { + mustEmbedUnimplementedLogStreamServiceServer() +} + +func RegisterLogStreamServiceServer(s grpc.ServiceRegistrar, srv LogStreamServiceServer) { + s.RegisterService(&LogStreamService_ServiceDesc, srv) +} + +func _LogStreamService_StreamLogs_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(LogStreamServiceServer).StreamLogs(&logStreamServiceStreamLogsServer{stream}) +} + +type LogStreamService_StreamLogsServer interface { + SendAndClose(*LogStreamResponse) error + Recv() (*LogStreamData, error) + grpc.ServerStream +} + +type logStreamServiceStreamLogsServer struct { + grpc.ServerStream +} + +func (x *logStreamServiceStreamLogsServer) SendAndClose(m *LogStreamResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *logStreamServiceStreamLogsServer) Recv() (*LogStreamData, error) { + m := new(LogStreamData) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// LogStreamService_ServiceDesc is the grpc.ServiceDesc for LogStreamService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var LogStreamService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "logstreamapi.LogStreamService", + HandlerType: (*LogStreamServiceServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "StreamLogs", + Handler: _LogStreamService_StreamLogs_Handler, + ClientStreams: true, + }, + }, + Metadata: "logstream.proto", +} diff --git a/principal/apis/logstream/logstream.go b/principal/apis/logstream/logstream.go new file mode 100644 index 00000000..116fd34b --- /dev/null +++ b/principal/apis/logstream/logstream.go @@ -0,0 +1,317 @@ +// Copyright 2024 The argocd-agent Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
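+// Package logstream implements the principal's side of log streaming: it
+// receives client-streamed gRPC frames from agents and forwards them to the
+// HTTP response writer registered under the same request UUID.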
+ +package logstream + +import ( + "context" + "fmt" + "io" + "net/http" + "sync" + "time" + + "github.com/argoproj-labs/argocd-agent/pkg/api/grpc/logstreamapi" + "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// Server implements logstreamapi.LogStreamServiceServer +type Server struct { + logstreamapi.UnimplementedLogStreamServiceServer + mu sync.RWMutex + sessions map[string]*session +} + +type session struct { + hw *httpWriter + completeCh chan bool // signaled on EOF (static logs) + cancelFn context.CancelFunc +} + +type httpWriter struct { + w http.ResponseWriter + flusher http.Flusher +} + +type logClient struct { + ctx context.Context + cancelFn context.CancelFunc + logCtx *logrus.Entry + requestID string + terminateErr error // returned as stream status when set +} + +func (s *Server) newLogClient(ctx context.Context) *logClient { + cctx, cancel := context.WithCancel(ctx) + return &logClient{ + ctx: cctx, + cancelFn: cancel, + logCtx: logrus.WithField("module", "LogStream"), + } +} + +func NewServer() *Server { + logrus.Info("Starting LogStream gRPC service") + return &Server{ + sessions: make(map[string]*session), + } +} + +// RegisterHTTP registers an HTTP writer for a given request UUID +func (s *Server) RegisterHTTP(requestUUID string, w http.ResponseWriter, r *http.Request) error { + s.mu.Lock() + defer s.mu.Unlock() + + flusher, ok := w.(http.Flusher) + if !ok { + return status.Error(codes.FailedPrecondition, "writer does not support flushing") + } + // streaming headers + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + + // upsert session + sess := s.sessions[requestUUID] + if sess == nil { + sess = &session{ + hw: &httpWriter{w: w, flusher: flusher}, + completeCh: make(chan bool, 1), + } + s.sessions[requestUUID] = sess + } else { + sess.hw = &httpWriter{w: w, flusher: flusher} + } + + //watchdog for client disconnection. When client disconnects, immediately cancel the stream + go func(reqID, ua, ra string, done <-chan struct{}) { + //wait for client disconnection + <-done + logrus.WithFields(logrus.Fields{ + "request_id": reqID, + "reason": "client_disconnected", + }).Info("Stream terminated due to client disconnection") + + s.mu.Lock() + sess := s.sessions[reqID] + if sess != nil { + sess.hw = nil + if sess.cancelFn != nil { + // Tag stream as canceled due to client detach. + sess.cancelFn() + } + } + s.mu.Unlock() + }(requestUUID, r.Header.Get("User-Agent"), r.RemoteAddr, r.Context().Done()) + + return nil +} + +// tryRecvWithCancel wraps stream.Recv() so we can abort via c.ctx.Done(). 
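+// Used below in StreamLogs as: msg, err := tryRecvWithCancel(c.ctx, stream.Recv).
+// The receive goroutine also selects on ctx.Done(), so it exits cleanly instead
+// of leaking a blocked channel send after the caller has returned.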
+func tryRecvWithCancel[T any](ctx context.Context, fn func() (T, error)) (T, error) { + type res struct { + m T + err error + } + ch := make(chan res, 1) + go func() { + m, err := fn() + select { + case ch <- res{m, err}: // send result if someone is still listening + case <-ctx.Done(): // otherwise just exit quietly + } + }() + select { + case <-ctx.Done(): + var zero T + return zero, status.Error(codes.Canceled, "client detached timeout") + case r := <-ch: + return r.m, r.err + } +} + +// StreamLogs receives log data from agent and forwards to HTTP writer +func (s *Server) StreamLogs(stream logstreamapi.LogStreamService_StreamLogsServer) error { + c := s.newLogClient(stream.Context()) + for { + msg, err := tryRecvWithCancel(c.ctx, stream.Recv) // ← cancelable + if err != nil { + // prefer a consistent reason for the agent + if status.Code(err) == codes.Canceled && c.terminateErr == nil { + c.terminateErr = status.Error(codes.Canceled, "client detached timeout") + } + break + } + if msg == nil { + c.terminateErr = status.Error(codes.InvalidArgument, "invalid log message") + break + } + + // First message, capture request UUID and expose cancelFn for detach handler + if c.requestID == "" { + c.requestID = msg.GetRequestUuid() + c.logCtx = c.logCtx.WithField("request_id", c.requestID) + c.logCtx.Info("LogStream started") + + s.mu.Lock() + if sess, ok := s.sessions[c.requestID]; ok { + sess.cancelFn = func() { + // tag this stream as terminated due to client detach TTL + c.terminateErr = status.Error(codes.Canceled, "client detached timeout") + c.cancelFn() + } + } + s.mu.Unlock() + } + + if err := s.processLogMessage(c, msg); err != nil { + // processLogMessage can return io.EOF or a terminal status + if err == io.EOF { + break + } + c.terminateErr = err + break + } + } + // Cleanup session + if c.requestID != "" { + s.finalizeSession(c.requestID) + } + + if c.terminateErr != nil { + return c.terminateErr + } + resp := &logstreamapi.LogStreamResponse{RequestUuid: c.requestID, Status: 200} + return stream.SendAndClose(resp) +} + +// safeFlush prevents process crash if ResponseWriter is gone. 
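+// http.Flusher.Flush has no error return, so a panic from a closed or
+// hijacked connection is recovered and converted into an ordinary error.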
+func safeFlush(f http.Flusher) (err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("flush panic: %v", r)
+		}
+	}()
+	f.Flush()
+	return nil
+}
+
+func (s *Server) processLogMessage(c *logClient, msg *logstreamapi.LogStreamData) error {
+	reqID := msg.GetRequestUuid()
+	logCtx := c.logCtx
+
+	// Look up the session
+	s.mu.RLock()
+	sess := s.sessions[reqID]
+	s.mu.RUnlock()
+	if sess == nil {
+		logCtx.Warn("received data for unknown request; terminating")
+		return status.Error(codes.NotFound, "unknown request id")
+	}
+
+	// Agent forwarded an error
+	if msg.GetError() != "" {
+		logCtx.WithField("error", msg.GetError()).Warn("log stream error from agent")
+		return status.Error(codes.Internal, msg.GetError())
+	}
+	// EOF
+	if msg.GetEof() {
+		logCtx.Info("LogStream EOF")
+		s.mu.Lock()
+		if sess, ok := s.sessions[reqID]; ok && sess.completeCh != nil {
+			select {
+			case sess.completeCh <- true:
+			default: // don't block if already signaled
+			}
+		}
+		s.mu.Unlock()
+		return io.EOF
+	}
+
+	data := msg.GetData()
+	// Agent sends an empty frame as a probe; no flush needed
+	if len(data) == 0 {
+		return nil
+	}
+	logCtx.WithField("data_length", len(data)).Trace("data received")
+
+	// Get current writer
+	s.mu.RLock()
+	hw := sess.hw
+	cancel := sess.cancelFn
+	s.mu.RUnlock()
+
+	// If the writer is gone, end the stream (vanilla semantics: a new request will be created)
+	if hw == nil {
+		logCtx.Info("HTTP writer missing; terminating stream")
+		return status.Error(codes.Canceled, "client disconnected")
+	}
+
+	// write + flush; on failure, nil out the writer and cancel
+	if _, err := hw.w.Write(data); err != nil {
+		logCtx.WithError(err).Warn("HTTP write failed; canceling stream")
+		s.mu.Lock()
+		if sess := s.sessions[reqID]; sess != nil {
+			sess.hw = nil
+		}
+		s.mu.Unlock()
+		if cancel != nil {
+			cancel()
+		}
+		return status.Error(codes.Canceled, "HTTP write failed")
+	}
+	if err := safeFlush(hw.flusher); err != nil {
+		logCtx.WithError(err).Warn("HTTP flush failed; canceling stream")
+		s.mu.Lock()
+		if sess := s.sessions[reqID]; sess != nil {
+			sess.hw = nil
+		}
+		s.mu.Unlock()
+		if cancel != nil {
+			cancel()
+		}
+		return status.Error(codes.Canceled, "HTTP flush failed")
+	}
+	logCtx.WithFields(logrus.Fields{
+		"data_length": len(data),
+		"request_id":  reqID,
+	}).Info("HTTP write and flush successful")
+	return nil
+}
+
+// WaitForCompletion waits for a LogStream to complete (static logs) or times out
+func (s *Server) WaitForCompletion(requestUUID string, timeout time.Duration) bool {
+	s.mu.RLock()
+	sess := s.sessions[requestUUID]
+	s.mu.RUnlock()
+
+	if sess == nil || sess.completeCh == nil {
+		return false
+	}
+	select {
+	case <-sess.completeCh:
+		return true
+	case <-time.After(timeout):
+		return false
+	}
+}
+
+func (s *Server) finalizeSession(requestUUID string) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	delete(s.sessions, requestUUID)
+}
diff --git a/principal/apis/logstream/logstream.proto b/principal/apis/logstream/logstream.proto
new file mode 100644
index 00000000..03daeb83
--- /dev/null
+++ b/principal/apis/logstream/logstream.proto
@@ -0,0 +1,32 @@
+syntax = "proto3";
+
+package logstreamapi;
+
+option go_package = "github.com/argoproj-labs/argocd-agent/pkg/api/grpc/logstreamapi";
+
+// LogStreamData represents a line (or chunk) of log data sent from the agent
+message LogStreamData {
+  // Unique identifier matching the original request/event UUID
+  string request_uuid = 1;
+  // Log data content. 
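+  // Frames carry raw chunks and may split lines; an empty data frame is a
+  // liveness probe and is ignored by the principal rather than treated as EOF.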
+ bytes data = 2; + // End of stream indicator + bool eof = 3; + // Optional error message + string error = 4; +} + +// LogStreamResponse is returned by principal when the agent closes the stream +message LogStreamResponse { + string request_uuid = 1; + int32 status = 2; // 200 on success + string error = 3; + int32 lines_received = 4; +} + +service LogStreamService { + // Agent establishes a client-streaming RPC and sends log data to principal + rpc StreamLogs(stream LogStreamData) returns (LogStreamResponse); +} + + diff --git a/principal/apis/logstream/logstream_test.go b/principal/apis/logstream/logstream_test.go new file mode 100644 index 00000000..ed8ce417 --- /dev/null +++ b/principal/apis/logstream/logstream_test.go @@ -0,0 +1,526 @@ +// Copyright 2024 The argocd-agent Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logstream + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/argoproj-labs/argocd-agent/pkg/api/grpc/logstreamapi" + "github.com/argoproj-labs/argocd-agent/principal/apis/logstream/mock" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewServer(t *testing.T) { + server := NewServer() + assert.NotNil(t, server) + assert.NotNil(t, server.sessions) + assert.Equal(t, 0, len(server.sessions)) +} + +func TestRegisterHTTP(t *testing.T) { + server := NewServer() + requestUUID := "test-request-123" + + t.Run("successful registration", func(t *testing.T) { + w := mock.NewMockHTTPResponseWriter() + r := httptest.NewRequest("GET", "/logs", nil) + + err := server.RegisterHTTP(requestUUID, w, r) + assert.NoError(t, err) + + // Check headers were set + assert.Equal(t, "text/plain; charset=utf-8", w.Header().Get("Content-Type")) + assert.Equal(t, "no-cache, no-transform", w.Header().Get("Cache-Control")) + assert.Equal(t, "keep-alive", w.Header().Get("Connection")) + assert.Equal(t, http.StatusOK, w.GetStatusCode()) + + // Check session was created + server.mu.RLock() + sess, exists := server.sessions[requestUUID] + server.mu.RUnlock() + assert.True(t, exists) + assert.NotNil(t, sess) + assert.NotNil(t, sess.hw) + assert.NotNil(t, sess.completeCh) + }) + + t.Run("writer without flusher", func(t *testing.T) { + // Create a writer that doesn't implement http.Flusher + w := &mock.MockWriterWithoutFlusher{} + r := httptest.NewRequest("GET", "/logs", nil) + + err := server.RegisterHTTP(requestUUID, w, r) + assert.Error(t, err) + assert.Contains(t, err.Error(), "writer does not support flushing") + }) + + t.Run("upsert existing session", func(t *testing.T) { + // First registration + w1 := mock.NewMockHTTPResponseWriter() + r1 := httptest.NewRequest("GET", "/logs", nil) + err := server.RegisterHTTP(requestUUID, w1, r1) + assert.NoError(t, err) + + // Second registration (upsert) + w2 := mock.NewMockHTTPResponseWriter() + r2 := httptest.NewRequest("GET", "/logs", nil) + err = server.RegisterHTTP(requestUUID, w2, r2) + 
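+		// Re-registering the same UUID must upsert the session in place,
+		// swapping in the new writer rather than erroring out.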
+		assert.NoError(t, err)
+
+		// Check session was updated
+		server.mu.RLock()
+		sess, exists := server.sessions[requestUUID]
+		server.mu.RUnlock()
+		assert.True(t, exists)
+		assert.Equal(t, w2, sess.hw.w) // Should be the new writer
+	})
+}
+
+func TestStreamLogs(t *testing.T) {
+	server := NewServer()
+	requestUUID := "test-request-123"
+
+	t.Run("successful log streaming", func(t *testing.T) {
+		// Register HTTP session first
+		w := mock.NewMockHTTPResponseWriter()
+		r := httptest.NewRequest("GET", "/logs", nil)
+		err := server.RegisterHTTP(requestUUID, w, r)
+		require.NoError(t, err)
+
+		// Create mock stream
+		ctx := context.Background()
+		mockStream := mock.NewMockLogStreamServer(ctx)
+
+		// Add test data
+		mockStream.AddRecvData(&logstreamapi.LogStreamData{
+			RequestUuid: requestUUID,
+			Data:        []byte("test log line 1\n"),
+		})
+		mockStream.AddRecvData(&logstreamapi.LogStreamData{
+			RequestUuid: requestUUID,
+			Data:        []byte("test log line 2\n"),
+		})
+		mockStream.AddRecvData(&logstreamapi.LogStreamData{
+			RequestUuid: requestUUID,
+			Eof:         true,
+		})
+
+		// Run StreamLogs
+		err = server.StreamLogs(mockStream)
+		assert.NoError(t, err)
+
+		// Check that data was written to HTTP response
+		body := w.GetBody()
+		assert.Contains(t, body, "test log line 1")
+		assert.Contains(t, body, "test log line 2")
+	})
+
+	t.Run("stream with error from agent", func(t *testing.T) {
+		// Register HTTP session first
+		w := mock.NewMockHTTPResponseWriter()
+		r := httptest.NewRequest("GET", "/logs", nil)
+		err := server.RegisterHTTP(requestUUID, w, r)
+		require.NoError(t, err)
+
+		// Create mock stream
+		ctx := context.Background()
+		mockStream := mock.NewMockLogStreamServer(ctx)
+
+		// Add error data
+		mockStream.AddRecvData(&logstreamapi.LogStreamData{
+			RequestUuid: requestUUID,
+			Error:       "agent error occurred",
+		})
+
+		// Run StreamLogs
+		err = server.StreamLogs(mockStream)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "agent error occurred")
+	})
+
+	t.Run("stream with unknown request ID", func(t *testing.T) {
+		// Don't register HTTP session
+		ctx := context.Background()
+		mockStream := mock.NewMockLogStreamServer(ctx)
+
+		// Add data with unknown request ID
+		mockStream.AddRecvData(&logstreamapi.LogStreamData{
+			RequestUuid: "unknown-request",
+			Data:        []byte("test log line\n"),
+		})
+
+		// Run StreamLogs
+		err := server.StreamLogs(mockStream)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "unknown request id")
+	})
+
+	t.Run("stream with nil message", func(t *testing.T) {
+		// Register HTTP session first
+		w := mock.NewMockHTTPResponseWriter()
+		r := httptest.NewRequest("GET", "/logs", nil)
+		err := server.RegisterHTTP(requestUUID, w, r)
+		require.NoError(t, err)
+
+		// Create mock stream
+		ctx := context.Background()
+		mockStream := mock.NewMockLogStreamServer(ctx)
+
+		// Add a nil message to the recv data
+		mockStream.AddRecvData(nil)
+
+		// Run StreamLogs
+		err = server.StreamLogs(mockStream)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "invalid log message")
+	})
+
+	t.Run("stream with context cancellation", func(t *testing.T) {
+		// Register HTTP session first
+		w := mock.NewMockHTTPResponseWriter()
+		r := httptest.NewRequest("GET", "/logs", nil)
+		err := server.RegisterHTTP(requestUUID, w, r)
+		require.NoError(t, err)
+
+		// Create mock stream with already canceled context
+		ctx, cancel := context.WithCancel(context.Background())
+		cancel() // Cancel immediately
+		mockStream := mock.NewMockLogStreamServer(ctx)
+
+		// Run StreamLogs
+		err = server.StreamLogs(mockStream)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "client detached timeout")
+	})
+}
+
+func TestProcessLogMessage(t *testing.T) {
+	server := NewServer()
+	requestUUID := "test-request-123"
+
+	// Register HTTP session first
+	w := mock.NewMockHTTPResponseWriter()
+	r := httptest.NewRequest("GET", "/logs", nil)
+	err := server.RegisterHTTP(requestUUID, w, r)
+	require.NoError(t, err)
+
+	// Create log client
+	ctx := context.Background()
+	client := server.newLogClient(ctx)
+	client.requestID = requestUUID
+
+	t.Run("successful log processing", func(t *testing.T) {
+		msg := &logstreamapi.LogStreamData{
+			RequestUuid: requestUUID,
+			Data:        []byte("test log line\n"),
+		}
+
+		err := server.processLogMessage(client, msg)
+		assert.NoError(t, err)
+
+		// Check that data was written
+		body := w.GetBody()
+		assert.Contains(t, body, "test log line")
+	})
+
+	t.Run("empty data (no-op)", func(t *testing.T) {
+		// Clear the body first
+		w.Reset()
+
+		msg := &logstreamapi.LogStreamData{
+			RequestUuid: requestUUID,
+			Data:        []byte{}, // Empty data
+		}
+
+		err := server.processLogMessage(client, msg)
+		assert.NoError(t, err)
+
+		// Body should remain empty since empty data is ignored
+		body := w.GetBody()
+		assert.Empty(t, body)
+	})
+
+	t.Run("EOF handling", func(t *testing.T) {
+		msg := &logstreamapi.LogStreamData{
+			RequestUuid: requestUUID,
+			Eof:         true,
+		}
+
+		err := server.processLogMessage(client, msg)
+		assert.Equal(t, io.EOF, err)
+
+		// Check that completion channel was signaled
+		server.mu.RLock()
+		sess := server.sessions[requestUUID]
+		server.mu.RUnlock()
+		assert.NotNil(t, sess)
+
+		// Wait for completion signal
+		select {
+		case <-sess.completeCh:
+			// Success
+		case <-time.After(100 * time.Millisecond):
+			t.Fatal("completion channel was not signaled")
+		}
+	})
+
+	t.Run("agent error", func(t *testing.T) {
+		msg := &logstreamapi.LogStreamData{
+			RequestUuid: requestUUID,
+			Error:       "agent error",
+		}
+
+		err := server.processLogMessage(client, msg)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "agent error")
+	})
+
+	t.Run("unknown request ID", func(t *testing.T) {
+		msg := &logstreamapi.LogStreamData{
+			RequestUuid: "unknown-request",
+			Data:        []byte("test log line\n"),
+		}
+
+		err := server.processLogMessage(client, msg)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "unknown request id")
+	})
+
+	t.Run("missing HTTP writer", func(t *testing.T) {
+		// Remove the HTTP writer
+		server.mu.Lock()
+		if sess, ok := server.sessions[requestUUID]; ok {
+			sess.hw = nil
+		}
+		server.mu.Unlock()
+
+		msg := &logstreamapi.LogStreamData{
+			RequestUuid: requestUUID,
+			Data:        []byte("test log line\n"),
+		}
+
+		err := server.processLogMessage(client, msg)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "client disconnected")
+	})
+}
+
+func TestWaitForCompletion(t *testing.T) {
+	server := NewServer()
+	requestUUID := "test-request-123"
+
+	t.Run("completion within timeout", func(t *testing.T) {
+		// Register session
+		w := mock.NewMockHTTPResponseWriter()
+		r := httptest.NewRequest("GET", "/logs", nil)
+		err := server.RegisterHTTP(requestUUID, w, r)
+		require.NoError(t, err)
+
+		// Signal completion
+		server.mu.RLock()
+		sess := server.sessions[requestUUID]
+		server.mu.RUnlock()
+		require.NotNil(t, sess)
+
+		go func() {
+			time.Sleep(10 * time.Millisecond)
+			select {
+			case sess.completeCh <- true:
+			default:
+			}
+		}()
+
+		// Wait for completion
+		completed := server.WaitForCompletion(requestUUID, 100*time.Millisecond)
+		assert.True(t, completed)
+	})
+
+	t.Run("timeout before completion", func(t *testing.T) {
+		// Register session
+		w := mock.NewMockHTTPResponseWriter()
+		r := httptest.NewRequest("GET", "/logs", nil)
+		err := server.RegisterHTTP(requestUUID, w, r)
+		require.NoError(t, err)
+
+		// Wait for completion (should timeout)
+		completed := server.WaitForCompletion(requestUUID, 10*time.Millisecond)
+		assert.False(t, completed)
+	})
+
+	t.Run("unknown request ID", func(t *testing.T) {
+		completed := server.WaitForCompletion("unknown-request", 10*time.Millisecond)
+		assert.False(t, completed)
+	})
+}
+
+func TestFinalizeSession(t *testing.T) {
+	server := NewServer()
+	requestUUID := "test-request-123"
+
+	// Register session
+	w := mock.NewMockHTTPResponseWriter()
+	r := httptest.NewRequest("GET", "/logs", nil)
+	err := server.RegisterHTTP(requestUUID, w, r)
+	require.NoError(t, err)
+
+	// Verify session exists
+	server.mu.RLock()
+	_, exists := server.sessions[requestUUID]
+	server.mu.RUnlock()
+	assert.True(t, exists)
+
+	// Finalize session
+	server.finalizeSession(requestUUID)
+
+	// Verify session was removed
+	server.mu.RLock()
+	_, exists = server.sessions[requestUUID]
+	server.mu.RUnlock()
+	assert.False(t, exists)
+}
+
+func TestSafeFlush(t *testing.T) {
+	t.Run("successful flush", func(t *testing.T) {
+		w := mock.NewMockHTTPResponseWriter()
+		err := safeFlush(w)
+		assert.NoError(t, err)
+		assert.True(t, w.IsFlushCalled())
+	})
+
+	t.Run("flush with panic", func(t *testing.T) {
+		// Create a flusher that panics
+		panicFlusher := &mock.PanicFlusher{}
+
+		// This should not panic due to the defer recover, but should return an error
+		err := safeFlush(panicFlusher)
+		assert.Error(t, err) // safeFlush catches panics and returns an error
+		assert.Contains(t, err.Error(), "flush panic")
+	})
+}
+
+func TestTryRecvWithCancel(t *testing.T) {
+	t.Run("successful receive", func(t *testing.T) {
+		ctx := context.Background()
+		called := false
+		fn := func() (string, error) {
+			called = true
+			return "test", nil
+		}
+
+		result, err := tryRecvWithCancel(ctx, fn)
+		assert.NoError(t, err)
+		assert.Equal(t, "test", result)
+		assert.True(t, called)
+	})
+
+	t.Run("context cancellation", func(t *testing.T) {
+		ctx, cancel := context.WithCancel(context.Background())
+		cancel() // Cancel immediately
+
+		fn := func() (string, error) {
+			time.Sleep(100 * time.Millisecond) // Simulate slow operation
+			return "test", nil
+		}
+
+		result, err := tryRecvWithCancel(ctx, fn)
+		assert.Error(t, err)
+		assert.Equal(t, "", result)
+		assert.Contains(t, err.Error(), "client detached timeout")
+	})
+
+	t.Run("function error", func(t *testing.T) {
+		ctx := context.Background()
+		expectedErr := fmt.Errorf("function error")
+		fn := func() (string, error) {
+			return "", expectedErr
+		}
+
+		result, err := tryRecvWithCancel(ctx, fn)
+		assert.Error(t, err)
+		assert.Equal(t, expectedErr, err)
+		assert.Equal(t, "", result)
+	})
+}
+
+func TestNewLogClient(t *testing.T) {
+	ctx := context.Background()
+	client := NewServer().newLogClient(ctx)
+
+	assert.NotNil(t, client)
+	assert.NotNil(t, client.ctx)
+	assert.NotNil(t, client.cancelFn)
+	assert.NotNil(t, client.logCtx)
+	assert.Equal(t, "", client.requestID)
+	assert.Nil(t, client.terminateErr)
+
+	// Test that context can be canceled
+	cancel := client.cancelFn
+	cancel()
+	select {
+	case <-client.ctx.Done():
+		// Success
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("context was not canceled")
+	}
+}
+
+func TestConcurrentAccess(t *testing.T) {
+	server := NewServer()
"test-request-123" + + // Register session + w := mock.NewMockHTTPResponseWriter() + r := httptest.NewRequest("GET", "/logs", nil) + err := server.RegisterHTTP(requestUUID, w, r) + require.NoError(t, err) + + // Test concurrent access to sessions map + var wg sync.WaitGroup + numGoroutines := 10 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + // Read from sessions + server.mu.RLock() + sess, exists := server.sessions[requestUUID] + server.mu.RUnlock() + + assert.True(t, exists) + assert.NotNil(t, sess) + + // Simulate some work + time.Sleep(1 * time.Millisecond) + }(i) + } + + wg.Wait() +} + +func init() { + // Set log level to reduce noise during testing + logrus.SetLevel(logrus.ErrorLevel) +} diff --git a/principal/apis/logstream/mock/mock.go b/principal/apis/logstream/mock/mock.go new file mode 100644 index 00000000..93b73935 --- /dev/null +++ b/principal/apis/logstream/mock/mock.go @@ -0,0 +1,187 @@ +// Copyright 2024 The argocd-agent Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "sync" + + "github.com/argoproj-labs/argocd-agent/pkg/api/grpc/logstreamapi" + "google.golang.org/grpc" +) + +// MockLogStreamServer implements grpc.ClientStreamingServer for testing +type MockLogStreamServer struct { + grpc.ServerStream + ctx context.Context + recvData []*logstreamapi.LogStreamData + recvIndex int + recvError error + sendError error + mu sync.Mutex + closed bool +} + +func NewMockLogStreamServer(ctx context.Context) *MockLogStreamServer { + return &MockLogStreamServer{ + ctx: ctx, + recvData: make([]*logstreamapi.LogStreamData, 0), + } +} + +func (m *MockLogStreamServer) Context() context.Context { + return m.ctx +} + +func (m *MockLogStreamServer) Recv() (*logstreamapi.LogStreamData, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed { + return nil, io.EOF + } + + if m.recvError != nil { + return nil, m.recvError + } + + if m.recvIndex >= len(m.recvData) { + return nil, io.EOF + } + + data := m.recvData[m.recvIndex] + m.recvIndex++ + + // Return the data as-is, even if it's nil + return data, nil +} + +func (m *MockLogStreamServer) SendAndClose(resp *logstreamapi.LogStreamResponse) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed { + return fmt.Errorf("stream already closed") + } + + m.closed = true + return m.sendError +} + +func (m *MockLogStreamServer) AddRecvData(data *logstreamapi.LogStreamData) { + m.mu.Lock() + defer m.mu.Unlock() + m.recvData = append(m.recvData, data) +} + +func (m *MockLogStreamServer) SetRecvError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.recvError = err +} + +func (m *MockLogStreamServer) SetSendError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.sendError = err +} + +func (m *MockLogStreamServer) Close() { + m.mu.Lock() + defer m.mu.Unlock() + m.closed = true +} + +// MockHTTPResponseWriter implements http.ResponseWriter and http.Flusher for testing +type MockHTTPResponseWriter 
+type MockHTTPResponseWriter struct {
+	headers     http.Header
+	body        strings.Builder
+	statusCode  int
+	flushCalled bool
+}
+
+func NewMockHTTPResponseWriter() *MockHTTPResponseWriter {
+	return &MockHTTPResponseWriter{
+		headers: make(http.Header),
+	}
+}
+
+func (m *MockHTTPResponseWriter) Header() http.Header {
+	return m.headers
+}
+
+func (m *MockHTTPResponseWriter) Write(data []byte) (int, error) {
+	return m.body.Write(data)
+}
+
+func (m *MockHTTPResponseWriter) WriteHeader(statusCode int) {
+	m.statusCode = statusCode
+}
+
+func (m *MockHTTPResponseWriter) Flush() {
+	m.flushCalled = true
+}
+
+func (m *MockHTTPResponseWriter) GetBody() string {
+	return m.body.String()
+}
+
+func (m *MockHTTPResponseWriter) GetStatusCode() int {
+	return m.statusCode
+}
+
+func (m *MockHTTPResponseWriter) IsFlushCalled() bool {
+	return m.flushCalled
+}
+
+func (m *MockHTTPResponseWriter) Reset() {
+	m.body.Reset()
+	m.flushCalled = false
+	m.statusCode = 0
+	m.headers = make(http.Header)
+}
+
+// PanicFlusher implements http.Flusher but panics on Flush()
+type PanicFlusher struct{}
+
+func (p *PanicFlusher) Flush() {
+	panic("simulated panic")
+}
+
+// MockWriterWithoutFlusher implements http.ResponseWriter but NOT http.Flusher
+type MockWriterWithoutFlusher struct {
+	headers    http.Header
+	body       strings.Builder
+	statusCode int
+}
+
+func (m *MockWriterWithoutFlusher) Header() http.Header {
+	if m.headers == nil {
+		m.headers = make(http.Header)
+	}
+	return m.headers
+}
+
+func (m *MockWriterWithoutFlusher) Write(data []byte) (int, error) {
+	return m.body.Write(data)
+}
+
+func (m *MockWriterWithoutFlusher) WriteHeader(statusCode int) {
+	m.statusCode = statusCode
+}
diff --git a/principal/listen.go b/principal/listen.go
index bb133cc5..10c8bd6d 100644
--- a/principal/listen.go
+++ b/principal/listen.go
@@ -33,6 +33,7 @@ import (
 	"github.com/argoproj-labs/argocd-agent/internal/metrics"
 	"github.com/argoproj-labs/argocd-agent/pkg/api/grpc/authapi"
 	"github.com/argoproj-labs/argocd-agent/pkg/api/grpc/eventstreamapi"
+	"github.com/argoproj-labs/argocd-agent/pkg/api/grpc/logstreamapi"
 	"github.com/argoproj-labs/argocd-agent/pkg/api/grpc/versionapi"
 	"github.com/argoproj-labs/argocd-agent/principal/apis/auth"
 	"github.com/argoproj-labs/argocd-agent/principal/apis/eventstream"
@@ -218,5 +219,7 @@ func (s *Server) registerGrpcServices(metrics *metrics.PrincipalMetrics) error {
 	authapi.RegisterAuthenticationServer(s.grpcServer, authSrv)
 	versionapi.RegisterVersionServer(s.grpcServer, version.NewServer(s.authenticate))
 	eventstreamapi.RegisterEventStreamServer(s.grpcServer, eventstream.NewServer(s.queues, s.eventWriters, metrics, s.clusterMgr, eventstream.WithNotifyOnConnect(s.notifyOnConnect)))
+	// Register the LogStream gRPC service for the log data plane (singleton instance)
+	logstreamapi.RegisterLogStreamServiceServer(s.grpcServer, s.logStream)
 	return nil
 }
diff --git a/principal/resource.go b/principal/resource.go
index 1ac282c2..3441d982 100644
--- a/principal/resource.go
+++ b/principal/resource.go
@@ -19,10 +19,13 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 	"time"
 
 	"github.com/argoproj-labs/argocd-agent/internal/event"
 	"github.com/argoproj-labs/argocd-agent/principal/resourceproxy"
+	cloudevents "github.com/cloudevents/sdk-go/v2"
+	"github.com/sirupsen/logrus"
 	"k8s.io/apimachinery/pkg/api/validation"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 )
@@ -97,24 +100,108 @@ func (s *Server) processResourceRequest(w http.ResponseWriter, r *http.Request,
 		w.WriteHeader(http.StatusInternalServerError)
 		return
 	}
+	defer func() {
+		if err := r.Body.Close(); err != nil {
+			logCtx.WithError(err).Error("Could not close request body")
+		}
+	}()
 
 	reqParams := map[string]string{}
 	for k, v := range r.URL.Query() {
 		reqParams[k] = v[0]
 	}
 
+	requestedName := params.Get("name")
+	requestedNamespace := params.Get("namespace")
+	requestedSubresource := params.Get("subresource")
+
 	// Create the event
-	sentEv, err := s.events.NewResourceRequestEvent(gvr, params.Get("namespace"), params.Get("name"), params.Get("subresource"), r.Method, reqBody, reqParams)
-	if err != nil {
-		logCtx.Errorf("Could not create event: %v", err)
-		w.WriteHeader(http.StatusInternalServerError)
-		return
+	var sentEv *cloudevents.Event
+	if requestedSubresource == "log" {
+		if requestedNamespace == "" || requestedName == "" {
+			logCtx.WithFields(logrus.Fields{
+				"namespace": requestedNamespace,
+				"pod":       requestedName,
+				"params":    reqParams,
+				"agent":     agentName,
+			}).Error("Missing required parameters: namespace and pod are required")
+			http.Error(w, "Missing required parameters: namespace and pod", http.StatusBadRequest)
+			return
+		}
+		sentEv, err = s.events.NewLogRequestEvent(requestedNamespace, requestedName, r.Method, reqParams)
+		if err != nil {
+			logCtx.WithFields(logrus.Fields{
+				"namespace": requestedNamespace,
+				"pod":       requestedName,
+				"params":    reqParams,
+				"agent":     agentName,
+			}).Errorf("Could not create container log event: %v", err)
+			http.Error(w, "Internal server error", http.StatusInternalServerError)
+			return
+		}
+	} else {
+		sentEv, err = s.events.NewResourceRequestEvent(gvr, requestedNamespace, requestedName, requestedSubresource, r.Method, reqBody, reqParams)
+		if err != nil {
+			logCtx.Errorf("Could not create event: %v", err)
+			w.WriteHeader(http.StatusInternalServerError)
+			return
+		}
+	}
 
 	// Remember the resource ID of the sent event
 	sentUUID := event.EventID(sentEv)
 
-	// Start tracking the event, so we can later get the response
+	if requestedSubresource != "" {
+		logCtx.Infof("Proxying request for subresource %s of resource %s named %s/%s", requestedSubresource, gvr.String(), requestedNamespace, requestedName)
+	} else if requestedName != "" {
+		logCtx.Infof("Proxying request for resource of type %s named %s/%s", gvr.String(), requestedNamespace, requestedName)
+	} else {
+		logCtx.Infof("Proxying request for resources of type %s in namespace %s", gvr.String(), requestedNamespace)
+	}
+
+	if requestedSubresource == "log" {
+		logCtx.WithFields(logrus.Fields{
+			"namespace": requestedNamespace,
+			"pod":       requestedName,
+			"params":    reqParams,
+			"agent":     agentName,
+			"uuid":      string(sentUUID),
+		}).Info("Proxying pod log request")
+		if err := s.logStream.RegisterHTTP(sentUUID, w, r); err != nil {
+			logCtx.Errorf("Could not register HTTP writer for log streaming: %v", err)
+			http.Error(w, "Internal server error", http.StatusInternalServerError)
+			return
+		}
+
+		// Submit the event to the queue
+		logCtx.Tracef("Submitting event: %v", sentEv)
+		q.Add(sentEv)
+
+		// Decide static vs streaming based on follow=true
+		isStreaming := strings.EqualFold(reqParams["follow"], "true")
+
+		if isStreaming {
+			// Keep handler alive until client disconnects
+			logCtx.WithField("uuid", string(sentUUID)).Info("Streaming logs: waiting for client disconnect")
+			<-r.Context().Done()
+			logCtx.WithField("uuid", string(sentUUID)).Info("Client disconnected; end streaming handler")
+		} else {
+			// Static logs: wait for completion signal from logStream
+			logCtx.WithField("uuid", string(sentUUID)).Info("Static logs: waiting for completion")
+			if ok := s.logStream.WaitForCompletion(sentUUID, requestTimeout); !ok {
+				logCtx.WithField("uuid", string(sentUUID)).Warn("Static logs timeout")
+				// Best-effort: the agent may have already streamed partial data.
+				// If response headers were already written, this 504 is ignored
+				// by the client; otherwise it reports the timeout.
+				http.Error(w, "Timeout fetching logs from agent", http.StatusGatewayTimeout)
+			}
+		}
+		// IMPORTANT: do not enter the standard eventCh loop for log requests.
+		return
+
+	}
+
+	// Start tracking the event for non-log requests, so we can later get the response
 	eventCh, err := s.resourceProxy.Track(sentUUID, agentName)
 	if err != nil {
 		logCtx.Errorf("Could not track event %s: %v", sentUUID, err)
@@ -130,26 +217,9 @@ func (s *Server) processResourceRequest(w http.ResponseWriter, r *http.Request,
 	logCtx.Tracef("Submitting event: %v", sentEv)
 	q.Add(sentEv)
 
-	requestedName := params.Get("name")
-	requestedNamespace := params.Get("namespace")
-	requestedSubresource := params.Get("subresource")
-
-	if requestedSubresource != "" {
-		logCtx.Infof("Proxying request for subresource %s of resource %s named %s/%s", requestedSubresource, gvr.String(), requestedNamespace, requestedName)
-	} else if requestedName != "" {
-		logCtx.Infof("Proxying request for resource of type %s named %s/%s", gvr.String(), requestedNamespace, requestedName)
-	} else {
-		logCtx.Infof("Proxying request for resources of type %s in namespace %s", gvr.String(), requestedNamespace)
-	}
-
 	// Wait for the event from the agent
 	ctx, cancel := context.WithTimeout(s.ctx, requestTimeout)
 	defer cancel()
-	defer func() {
-		if err := r.Body.Close(); err != nil {
-			logCtx.WithError(err).Error("Uh oh")
-		}
-	}()
 
 	// The response is being read through a channel that is kept open and
 	// written to by the resource proxy.
diff --git a/principal/server.go b/principal/server.go
index 5cb533e6..8bee1bd2 100644
--- a/principal/server.go
+++ b/principal/server.go
@@ -51,6 +51,7 @@ import (
 	"github.com/argoproj-labs/argocd-agent/internal/tlsutil"
 	"github.com/argoproj-labs/argocd-agent/internal/version"
 	"github.com/argoproj-labs/argocd-agent/pkg/types"
+	"github.com/argoproj-labs/argocd-agent/principal/apis/logstream"
 	"github.com/argoproj-labs/argocd-agent/principal/redisproxy"
 	"github.com/argoproj-labs/argocd-agent/principal/resourceproxy"
 	"github.com/argoproj/argo-cd/v3/common"
@@ -157,6 +158,7 @@ type Server struct {
 
 	handlersOnConnect []handlersOnConnect
 	eventWriters      *event.EventWritersMap
+	logStream         *logstream.Server
 }
 
 type handlersOnConnect func(agent types.Agent) error
@@ -371,6 +373,7 @@ func NewServer(ctx context.Context, kubeClient *kube.KubernetesClient, namespace
 	}
 
 	s.resources = resources.NewAgentResources()
+	s.logStream = logstream.NewServer()
 
 	return s, nil
 }
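
Note: the handlers in this patch reference a session type (fields hw, cancelFn, completeCh) and a logClient type (fields ctx, cancelFn, logCtx, requestID, terminateErr) that are defined earlier in logstream.go and are not visible in this excerpt. The sketch below shows only the shape those call sites imply; the field names come from usage in the diff, everything else is an assumption and may differ from the real definitions.

// Sketch only: inferred principal-side session bookkeeping, not the
// actual definitions from logstream.go.
package logstream

import (
	"context"
	"net/http"

	"github.com/sirupsen/logrus"
)

// httpWriter pairs the ResponseWriter with its Flusher so each chunk
// received from the agent can be pushed to the client immediately.
type httpWriter struct {
	w       http.ResponseWriter
	flusher http.Flusher
}

// session tracks one log request from RegisterHTTP until finalizeSession.
type session struct {
	hw         *httpWriter
	cancelFn   context.CancelFunc
	completeCh chan bool // assumed buffered (cap 1); processLogMessage signals it non-blockingly
}

// logClient carries per-stream state for a single StreamLogs RPC.
type logClient struct {
	ctx          context.Context
	cancelFn     context.CancelFunc
	logCtx       *logrus.Entry
	requestID    string
	terminateErr error
}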
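TestTryRecvWithCancel exercises tryRecvWithCancel, whose implementation also lies outside this excerpt. One plausible shape, consistent with the three test cases (value passthrough, error passthrough, and a "client detached timeout" when the context is canceled first), is a generic helper on Go 1.18+; treat this as a guess, not the committed code:

// Sketch only: races fn against ctx and surfaces a "client detached
// timeout" error if the context wins, matching TestTryRecvWithCancel.
package logstream

import (
	"context"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

func tryRecvWithCancel[T any](ctx context.Context, fn func() (T, error)) (T, error) {
	type result struct {
		v   T
		err error
	}
	// Buffered so the goroutine can finish its send and exit even if
	// the caller has already returned on ctx.Done().
	ch := make(chan result, 1)
	go func() {
		v, err := fn()
		ch <- result{v, err}
	}()
	select {
	case r := <-ch:
		return r.v, r.err
	case <-ctx.Done():
		var zero T
		return zero, status.Error(codes.Canceled, "client detached timeout")
	}
}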
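For orientation, the principal ultimately serves these chunks as an ordinary flushed HTTP response, which is why processLogMessage flushes after every write. A hypothetical consumer is sketched below; the host and URL path are placeholders and depend on the resource proxy routing in front of processResourceRequest, none of which is specified by this patch:

// Sketch only: reads a followed log stream line by line as it arrives.
package main

import (
	"bufio"
	"fmt"
	"net/http"
)

func main() {
	// Placeholder endpoint; the real path is determined by the resource proxy.
	url := "https://principal.example.com/api/v1/namespaces/demo/pods/my-pod/log?container=main&follow=true"

	resp, err := http.Get(url)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Because the principal flushes each LogStreamData chunk, lines arrive
	// incrementally instead of only after the stream ends.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		fmt.Println(scanner.Text())
	}
}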