diff --git a/apmproxy/apmserver.go b/apmproxy/apmserver.go index 53db857b..9ea44862 100644 --- a/apmproxy/apmserver.go +++ b/apmproxy/apmserver.go @@ -63,7 +63,7 @@ func (c *Client) ForwardApmData(ctx context.Context) error { c.logger.Debugf("Received something from '%s' without APMData", data.AgentInfo) continue } - if err := c.forwardAgentData(ctx, data); err != nil { + if err := c.ForwardAgentData(ctx, data); err != nil { return err } if lambdaDataChan == nil { @@ -73,7 +73,7 @@ func (c *Client) ForwardApmData(ctx context.Context) error { c.logger.Debug("Assigned Lambda data channel") } case data := <-lambdaDataChan: - if err := c.forwardLambdaData(ctx, data); err != nil { + if err := c.ForwardLambdaData(ctx, data); err != nil { return err } } @@ -91,7 +91,7 @@ func (c *Client) FlushAPMData(ctx context.Context) { // Flush agent data first to make sure metadata is available if possible for i := len(c.AgentDataChannel); i > 0; i-- { data := <-c.AgentDataChannel - if err := c.forwardAgentData(ctx, data); err != nil { + if err := c.ForwardAgentData(ctx, data); err != nil { c.logger.Errorf("Error sending to APM Server, skipping: %v", err) } } @@ -107,7 +107,7 @@ func (c *Client) FlushAPMData(ctx context.Context) { for { select { case apmData := <-c.LambdaDataChannel: - if err := c.forwardLambdaData(ctx, apmData); err != nil { + if err := c.ForwardLambdaData(ctx, apmData); err != nil { c.logger.Errorf("Error sending to APM server, skipping: %v", err) } case <-ctx.Done(): @@ -349,7 +349,7 @@ func (c *Client) WaitForFlush() <-chan struct{} { return c.flushCh } -func (c *Client) forwardAgentData(ctx context.Context, apmData accumulator.APMData) error { +func (c *Client) ForwardAgentData(ctx context.Context, apmData accumulator.APMData) error { if err := c.batch.AddAgentData(apmData); err != nil { c.logger.Warnf("Dropping agent data due to error: %v", err) } @@ -359,7 +359,7 @@ func (c *Client) forwardAgentData(ctx context.Context, apmData accumulator.APMDa return nil } -func (c *Client) forwardLambdaData(ctx context.Context, data []byte) error { +func (c *Client) ForwardLambdaData(ctx context.Context, data []byte) error { if err := c.batch.AddLambdaData(data); err != nil { c.logger.Warnf("Dropping lambda data due to error: %v", err) } diff --git a/app/run.go b/app/run.go index 17936025..c20db2f2 100644 --- a/app/run.go +++ b/app/run.go @@ -98,7 +98,7 @@ func (app *App) Run(ctx context.Context) error { app.logger.Infof("Exiting due to shutdown event with reason %s", event.ShutdownReason) if app.logsClient != nil { // Flush buffered logs if any - app.logsClient.FlushData(ctx, event.RequestID, event.InvokedFunctionArn, app.apmClient.LambdaDataChannel, true) + app.logsClient.FlushData(ctx, event.RequestID, event.InvokedFunctionArn, app.apmClient.ForwardLambdaData, true) } // Since we have waited for the processEvent loop to finish we // already have received all the data we can from the agent. So, we @@ -131,7 +131,7 @@ func (app *App) Run(ctx context.Context) error { flushCtx, cancel := context.WithCancel(ctx) if app.logsClient != nil { // Flush buffered logs if any - app.logsClient.FlushData(ctx, event.RequestID, event.InvokedFunctionArn, app.apmClient.LambdaDataChannel, false) + app.logsClient.FlushData(ctx, event.RequestID, event.InvokedFunctionArn, app.apmClient.ForwardLambdaData, false) } // Flush APM data now that the function invocation has completed app.apmClient.FlushAPMData(flushCtx) @@ -213,7 +213,13 @@ func (app *App) processEvent( invocationCtx, event.RequestID, event.InvokedFunctionArn, - app.apmClient.LambdaDataChannel, + func(ctx context.Context, b []byte) error { + select { + case app.apmClient.LambdaDataChannel <- b: + case <-ctx.Done(): + } + return nil + }, event.EventType == extension.Shutdown, ) }() diff --git a/logsapi/event.go b/logsapi/event.go index 2dc79f60..ca3b9c7b 100644 --- a/logsapi/event.go +++ b/logsapi/event.go @@ -25,6 +25,8 @@ import ( // LogEventType represents the log type that is received in the log messages type LogEventType string +type Forwarder func(context.Context, []byte) error + const ( // PlatformRuntimeDone event is sent when lambda function is finished it's execution PlatformRuntimeDone LogEventType = "platform.runtimeDone" @@ -60,13 +62,13 @@ func (lc *Client) ProcessLogs( ctx context.Context, requestID string, invokedFnArn string, - dataChan chan []byte, + forwardFn Forwarder, isShutdown bool, ) { for { select { case logEvent := <-lc.logsChannel: - if shouldExit := lc.handleEvent(ctx, logEvent, requestID, invokedFnArn, dataChan, isShutdown); shouldExit { + if shouldExit := lc.handleEvent(ctx, logEvent, requestID, invokedFnArn, forwardFn, isShutdown); shouldExit { return } case <-ctx.Done(): @@ -80,14 +82,14 @@ func (lc *Client) FlushData( ctx context.Context, requestID string, invokedFnArn string, - dataChan chan []byte, + forwardFn Forwarder, isShutdown bool, ) { lc.logger.Infof("flushing %d buffered logs", len(lc.logsChannel)) for { select { case logEvent := <-lc.logsChannel: - if shouldExit := lc.handleEvent(ctx, logEvent, requestID, invokedFnArn, dataChan, isShutdown); shouldExit { + if shouldExit := lc.handleEvent(ctx, logEvent, requestID, invokedFnArn, forwardFn, isShutdown); shouldExit { return } case <-ctx.Done(): @@ -106,7 +108,7 @@ func (lc *Client) handleEvent(ctx context.Context, logEvent LogEvent, requestID string, invokedFnArn string, - dataChan chan []byte, + forwardFn Forwarder, isShutdown bool, ) bool { lc.logger.Debugf("Received log event %v for request ID %s", logEvent.Type, logEvent.Record.RequestID) @@ -139,9 +141,8 @@ func (lc *Client) handleEvent(ctx context.Context, if err != nil { lc.logger.Errorf("Error processing Lambda platform metrics: %v", err) } else { - select { - case dataChan <- processedMetrics: - case <-ctx.Done(): + if err := forwardFn(ctx, processedMetrics); err != nil { + lc.logger.Errorf("Error forwarding Lambda platform metrics: %v", err) } } } @@ -166,9 +167,8 @@ func (lc *Client) handleEvent(ctx context.Context, if err != nil { lc.logger.Warnf("Error processing function log : %v", err) } else { - select { - case dataChan <- processedLog: - case <-ctx.Done(): + if err := forwardFn(ctx, processedLog); err != nil { + lc.logger.Warnf("Error forwarding function log : %v", err) } } }