diff --git a/lambda-extensions/lambdaapi/extensionapiclient.go b/lambda-extensions/lambdaapi/extensionapiclient.go
index ee63938..47828b3 100644
--- a/lambda-extensions/lambdaapi/extensionapiclient.go
+++ b/lambda-extensions/lambdaapi/extensionapiclient.go
@@ -46,15 +46,20 @@ const (
 )
 
 var (
-	lambdaEvents = []EventType{"INVOKE", "SHUTDOWN"}
+	lambdaEvents         = []EventType{"INVOKE", "SHUTDOWN"}
+	elevatorLambdaEvents = []EventType{"SHUTDOWN"}
 )
 
 // RegisterExtension is to register extension to Run Time API client. Call the following method on initialization as early as possible,
 // otherwise you may get a timeout error. Runtime initialization will start after all extensions are registered.
-func (client *Client) RegisterExtension(ctx context.Context) (*RegisterResponse, error) {
+func (client *Client) RegisterExtension(ctx context.Context, isElevator bool) (*RegisterResponse, error) {
 	URL := client.baseURL + extensionURL + "register"
+	events := lambdaEvents
+	if isElevator {
+		events = elevatorLambdaEvents
+	}
 	reqBody, err := json.Marshal(map[string]interface{}{
-		"events": lambdaEvents,
+		"events": events,
 	})
 	if err != nil {
 		return nil, err
diff --git a/lambda-extensions/lambdaapi/extensionapiclient_test.go b/lambda-extensions/lambdaapi/extensionapiclient_test.go
index a004dc8..421d3ee 100644
--- a/lambda-extensions/lambdaapi/extensionapiclient_test.go
+++ b/lambda-extensions/lambdaapi/extensionapiclient_test.go
@@ -34,11 +34,11 @@ func TestRegisterExtension(t *testing.T) {
 	client := NewClient(srv.URL[7:], extensionName)
 
 	// Without Context
-	response, err := client.RegisterExtension(context.TODO())
+	response, err := client.RegisterExtension(context.TODO(), false)
 	commonAsserts(t, client, response, err)
 
 	// With Context
-	response, err = client.RegisterExtension(context.Background())
+	response, err = client.RegisterExtension(context.Background(), false)
 	commonAsserts(t, client, response, err)
 }
 
diff --git a/lambda-extensions/lambdaapi/telemetryapiclient.go b/lambda-extensions/lambdaapi/telemetryapiclient.go
index 65c4336..e35f0f6 100644
--- a/lambda-extensions/lambdaapi/telemetryapiclient.go
+++ b/lambda-extensions/lambdaapi/telemetryapiclient.go
@@ -15,14 +15,17 @@ const (
 )
 
 // SubscribeToLogsAPI is - Subscribe to Logs API to receive the Lambda Logs.
-func (client *Client) SubscribeToTelemetryAPI(ctx context.Context, logEvents []string, telemetryTimeoutMs int, telemetryMaxBytes int64, telemetryMaxItems int) ([]byte, error) {
+func (client *Client) SubscribeToTelemetryAPI(ctx context.Context, logEvents []string, telemetryTimeoutMs int, telemetryMaxBytes int64, telemetryMaxItems int, isElevator bool) ([]byte, error) {
 	URL := client.baseURL + telemetryURL
-
+	schemaVersion := "2022-07-01"
+	if isElevator {
+		schemaVersion = "2025-01-29"
+	}
 	reqBody, err := json.Marshal(map[string]interface{}{
 		"destination":   map[string]interface{}{"protocol": "HTTP", "URI": fmt.Sprintf("http://sandbox:%v", receiverPort)},
 		"types":         logEvents,
 		"buffering":     map[string]interface{}{"timeoutMs": telemetryTimeoutMs, "maxBytes": telemetryMaxBytes, "maxItems": telemetryMaxItems},
-		"schemaVersion": "2022-07-01",
+		"schemaVersion": schemaVersion,
 	})
 	if err != nil {
 		return nil, err
diff --git a/lambda-extensions/lambdaapi/telemetryapiclient_test.go b/lambda-extensions/lambdaapi/telemetryapiclient_test.go
index 58a7ce4..62c4e1b 100644
--- a/lambda-extensions/lambdaapi/telemetryapiclient_test.go
+++ b/lambda-extensions/lambdaapi/telemetryapiclient_test.go
@@ -31,10 +31,10 @@ func TestSubscribeToTelemetryAPI(t *testing.T) {
 	client := NewClient(srv.URL[7:], extensionName)
 
 	// Without Context
-	response, err := client.SubscribeToTelemetryAPI(context.TODO(), []string{"platform", "function", "extension"}, 1000, 262144, 10000)
+	response, err := client.SubscribeToTelemetryAPI(context.TODO(), []string{"platform", "function", "extension"}, 1000, 262144, 10000, false)
 	commonAsserts(t, client, response, err)
 
 	// With Context
-	response, err = client.SubscribeToTelemetryAPI(context.Background(), []string{"platform", "function", "extension"}, 1000, 262144, 10000)
+	response, err = client.SubscribeToTelemetryAPI(context.Background(), []string{"platform", "function", "extension"}, 1000, 262144, 10000, false)
 	commonAsserts(t, client, response, err)
 }
diff --git a/lambda-extensions/sumologic-extension.go b/lambda-extensions/sumologic-extension.go
index 84fa399..9442a3b 100644
--- a/lambda-extensions/sumologic-extension.go
+++ b/lambda-extensions/sumologic-extension.go
@@ -26,8 +26,12 @@ var (
 )
 
 var producer workers.TaskProducer
 var consumer workers.TaskConsumer
+var elevatorProducer workers.ElevatorTaskProducer
+var elevatorConsumer workers.ElevatorTaskConsumer
 var config *cfg.LambdaExtensionConfig
 var dataQueue chan []byte
+var flushSignal chan string
+var isElevator bool
 
 func init() {
@@ -47,22 +51,52 @@
 	logger.Logger.SetLevel(config.LogLevel)
 	dataQueue = make(chan []byte, config.MaxDataQueueLength)
 
-	// Start HTTP Server before subscription in a goRoutine
-	producer = workers.NewTaskProducer(dataQueue, logger)
-	go func() {
-		if err := producer.Start(); err != nil {
-			logger.Errorf("producer Start failed: %v", err)
-		}
-	}()
+	// Check initialization type to determine if elevator mode should be used
+	initializationType := os.Getenv("AWS_LAMBDA_INITIALIZATION_TYPE")
+	if initializationType == "lambda-managed-instances" {
+		isElevator = true
+		logger.Debug("Initializing in Elevator mode")
+
+		// Initialize flushSignal channel for elevator mode communication
+		flushSignal = make(chan string, 10) // Buffered channel to prevent blocking
+
+		// Initialize Elevator Producer and start it in a goroutine
+		elevatorProducer = workers.NewElevatorTaskProducer(dataQueue, flushSignal, logger)
+		go func() {
+			if err := elevatorProducer.Start(); err != nil {
+				logger.Errorf("elevatorProducer Start failed: %v", err)
+			}
+		}()
+
+		// Initialize Elevator Consumer and start it
+		elevatorConsumer = workers.NewElevatorTaskConsumer(dataQueue, flushSignal, config, logger)
+		// Start the consumer's independent processing loop
+		ctx := context.Background()
+		elevatorConsumer.Start(ctx)
+
+		logger.Debug("Elevator mode initialization complete")
+	} else {
+		logger.Debug("Initializing in standard mode")
+		// Start HTTP Server before subscription in a goRoutine
+		producer = workers.NewTaskProducer(dataQueue, logger)
+		go func() {
+			if err := producer.Start(); err != nil {
+				logger.Errorf("producer Start failed: %v", err)
+			}
+		}()
 
-	// Creating SumoTaskConsumer
-	consumer = workers.NewTaskConsumer(dataQueue, config, logger)
+		// Creating SumoTaskConsumer
+		consumer = workers.NewTaskConsumer(dataQueue, config, logger)
+		logger.Debug("Standard mode initialization complete")
+	}
+
+	logger.Debug("Is Elevator value: ", isElevator)
 }
 
 func runTimeAPIInit() (int64, error) {
 	// Register early so Runtime could start in parallel
 	logger.Debug("Registering Extension to Run Time API Client..........")
-	registerResponse, err := extensionClient.RegisterExtension(context.TODO())
+	registerResponse, err := extensionClient.RegisterExtension(context.TODO(), isElevator)
 	if err != nil {
 		return 0, err
 	}
@@ -70,7 +104,7 @@
 
 	// Subscribe to Telemetry API
 	logger.Debug("Subscribing Extension to Telemetry API........")
-	subscribeResponse, err := extensionClient.SubscribeToTelemetryAPI(context.TODO(), config.LogTypes, config.TelemetryTimeoutMs, config.TelemetryMaxBytes, config.TelemetryMaxItems)
+	subscribeResponse, err := extensionClient.SubscribeToTelemetryAPI(context.TODO(), config.LogTypes, config.TelemetryTimeoutMs, config.TelemetryMaxBytes, config.TelemetryMaxItems, isElevator)
 	if err != nil {
 		return 0, err
 	}
@@ -78,11 +112,14 @@
 	logger.Debug("Successfully subscribed to Telemetry API: ", utils.PrettyPrint(string(subscribeResponse)))
 
 	// Call next to say registration is successful and get the deadtimems
-	nextResponse, err := nextEvent(context.TODO())
-	if err != nil {
-		return 0, err
+	if !isElevator {
+		nextResponse, err := nextEvent(context.TODO())
+		if err != nil {
+			return 0, err
+		}
+		return nextResponse.DeadlineMs, nil
 	}
-	return nextResponse.DeadlineMs, nil
+	return 0, nil
 }
 
 func nextEvent(ctx context.Context) (*lambdaapi.NextEventResponse, error) {
@@ -109,15 +146,17 @@ func processEvents(ctx context.Context) {
 			consumer.FlushDataQueue(ctx)
 			return
 
 		default:
-			logger.Debugf("switching to other go routine")
-			runtime.Gosched()
-			logger.Infof("Calling DrainQueue from processEvents")
-			// for {
-			runtime_done := consumer.DrainQueue(ctx)
-
-			if runtime_done == 1 {
-				logger.Infof("Exiting DrainQueueLoop: Runtime is Done")
+			if !isElevator {
+				logger.Debugf("switching to other go routine")
+				runtime.Gosched()
+				logger.Infof("Calling DrainQueue from processEvents")
+				// for {
+				runtime_done := consumer.DrainQueue(ctx)
+				if runtime_done == 1 {
+					logger.Infof("Exiting DrainQueueLoop: Runtime is Done")
+				}
 			}
+			// }
 			// This statement will freeze lambda
 
diff --git a/lambda-extensions/workers/elevatorConsumer.go b/lambda-extensions/workers/elevatorConsumer.go
new file mode 100644
index 0000000..13abb53
--- /dev/null
+++ b/lambda-extensions/workers/elevatorConsumer.go
@@ -0,0 +1,164 @@
+package workers
+
+import (
+	"context"
+
+	cfg "github.com/SumoLogic/sumologic-lambda-extensions/lambda-extensions/config"
+	sumocli "github.com/SumoLogic/sumologic-lambda-extensions/lambda-extensions/sumoclient"
+
+	"github.com/sirupsen/logrus"
+)
+
+// ElevatorTaskConsumer exposes methods for consuming tasks in elevator mode
+type ElevatorTaskConsumer interface {
+	Start(context.Context)
+	FlushDataQueue(context.Context)
+	DrainQueue(context.Context) int
+}
+
+// elevatorSumoConsumer drains log from dataQueue in elevator mode
+type elevatorSumoConsumer struct {
+	dataQueue   chan []byte
+	flushSignal chan string
+	logger      *logrus.Entry
+	config      *cfg.LambdaExtensionConfig
+	sumoclient  sumocli.LogSender
+}
+
+// NewElevatorTaskConsumer returns a new elevator consumer
+// flushSignal channel is used to receive signals from producer to trigger flushing
+func NewElevatorTaskConsumer(consumerQueue chan []byte, flushSignal chan string, config *cfg.LambdaExtensionConfig, logger *logrus.Entry) ElevatorTaskConsumer {
+	return &elevatorSumoConsumer{
+		dataQueue:   consumerQueue,
+		flushSignal: flushSignal,
+		logger:      logger,
+		sumoclient:  sumocli.NewLogSenderClient(logger, config),
+		config:      config,
+	}
+}
+
+// Start starts the elevator consumer in a goroutine to listen for flush signals independently
+func (esc *elevatorSumoConsumer) Start(ctx context.Context) {
+	esc.logger.Info("Starting Elevator Consumer")
+	go esc.processFlushSignals(ctx)
+}
+
+// processFlushSignals continuously listens for flush signals and triggers queue draining
+// This runs independently without needing callbacks from main thread
+func (esc *elevatorSumoConsumer) processFlushSignals(ctx context.Context) {
+	esc.logger.Info("Elevator Consumer: Started listening for flush signals")
+
+	for {
+		select {
+		case <-ctx.Done():
+			esc.logger.Info("Elevator Consumer: Context cancelled, flushing remaining data")
+			esc.FlushDataQueue(ctx)
+			return
+
+		case signal := <-esc.flushSignal:
+			esc.logger.Infof("Elevator Consumer: Received flush signal: %s", signal)
+
+			switch signal {
+			case "queue_threshold":
+				esc.logger.Info("Elevator Consumer: Draining queue due to 80% threshold")
+				esc.DrainQueue(ctx)
+
+			case "platform.report":
+				esc.logger.Info("Elevator Consumer: Draining queue due to platform.report event")
+				esc.DrainQueue(ctx)
+
+			default:
+				esc.logger.Warnf("Elevator Consumer: Unknown flush signal received: %s", signal)
+			}
+		}
+	}
+}
+
+// FlushDataQueue drains the dataqueue completely (called during shutdown)
+func (esc *elevatorSumoConsumer) FlushDataQueue(ctx context.Context) {
+	esc.logger.Info("Elevator Consumer: Flushing DataQueue")
+
+	if esc.config.EnableFailover {
+		var rawMsgArr [][]byte
+	Loop:
+		for {
+			select {
+			case rawmsg := <-esc.dataQueue:
+				rawMsgArr = append(rawMsgArr, rawmsg)
+			default:
+				if len(rawMsgArr) > 0 {
+					err := esc.sumoclient.FlushAll(rawMsgArr)
+					if err != nil {
+						esc.logger.Errorln("Elevator Consumer: Unable to flush DataQueue", err.Error())
+						// putting back all the msg to the queue in case of failure
+						for _, msg := range rawMsgArr {
+							select {
+							case esc.dataQueue <- msg:
+							default:
+								esc.logger.Warnf("Elevator Consumer: Failed to requeue message, queue full")
+							}
+						}
+					} else {
+						esc.logger.Infof("Elevator Consumer: Successfully flushed %d messages", len(rawMsgArr))
+					}
+				}
+				close(esc.dataQueue)
+				esc.logger.Debugf("Elevator Consumer: DataQueue completely drained and closed")
+				break Loop
+			}
+		}
+	} else {
+		// calling drainqueue (during shutdown) if failover is not enabled
+		maxCallsNeededForCompleteDraining := (len(esc.dataQueue) / esc.config.MaxConcurrentRequests) + 1
+		for i := 0; i < maxCallsNeededForCompleteDraining; i++ {
+			esc.DrainQueue(ctx)
+		}
+		esc.logger.Info("Elevator Consumer: DataQueue drained without failover")
+	}
+}
+
+// DrainQueue drains the current contents of the queue
+func (esc *elevatorSumoConsumer) DrainQueue(ctx context.Context) int {
+	esc.logger.Debug("Elevator Consumer: Draining data from dataQueue")
+
+	var rawMsgArr [][]byte
+	var logsStr string
+	var runtime_done = 0
+
+	// Collect all available messages from the queue
+Loop:
+	for {
+		select {
+		case rawmsg := <-esc.dataQueue:
+			rawMsgArr = append(rawMsgArr, rawmsg)
+			logsStr = string(rawmsg)
+			esc.logger.Debugf("Elevator Consumer: DrainQueue: logsStr length: %d", len(logsStr))
+
+		default:
+			// No more messages in queue, send what we have
+			if len(rawMsgArr) > 0 {
+				esc.logger.Infof("Elevator Consumer: Sending %d messages to Sumo Logic", len(rawMsgArr))
+				err := esc.sumoclient.SendAllLogs(ctx, rawMsgArr)
+				if err != nil {
+					esc.logger.Errorln("Elevator Consumer: Unable to send logs to Sumo Logic", err.Error())
+					// putting back all the msg to the queue in case of failure
+					for _, msg := range rawMsgArr {
+						select {
+						case esc.dataQueue <- msg:
+						default:
+							esc.logger.Warn("Elevator Consumer: Failed to requeue message, queue full")
+						}
+					}
+				} else {
+					esc.logger.Infof("Elevator Consumer: Successfully sent %d messages", len(rawMsgArr))
+				}
+			} else {
+				esc.logger.Debug("Elevator Consumer: No messages to drain")
+			}
+			break Loop
+		}
+	}
+
+	esc.logger.Debugf("Elevator Consumer: DrainQueue complete. Runtime done: %d", runtime_done)
+	return runtime_done
+}
diff --git a/lambda-extensions/workers/elevatorProducer.go b/lambda-extensions/workers/elevatorProducer.go
new file mode 100644
index 0000000..8d97058
--- /dev/null
+++ b/lambda-extensions/workers/elevatorProducer.go
@@ -0,0 +1,144 @@
+package workers
+
+import (
+	"encoding/json"
+	"fmt"
+	ioutil "io"
+	"log"
+	"net/http"
+
+	"github.com/sirupsen/logrus"
+)
+
+const (
+	// elevatorReceiverIP is Web Server Constants for elevator mode
+	elevatorReceiverIP = "0.0.0.0"
+	// elevatorReceiverPort is Web Server Constants for elevator mode
+	elevatorReceiverPort = 4243
+	// queueThresholdPercent is the threshold percentage for triggering flush
+	queueThresholdPercent = 0.8
+)
+
+// ElevatorTaskProducer exposes methods for producing tasks in elevator mode
+type ElevatorTaskProducer interface {
+	Start() error
+}
+
+type elevatorHttpServer struct {
+	dataQueue   chan []byte
+	logger      *logrus.Entry
+	flushSignal chan string // Signal channel to notify consumer to flush
+}
+
+type Event struct {
+	Time   string          `json:"time"`
+	Type   string          `json:"type"`
+	Record json.RawMessage `json:"record"`
+}
+
+// NewElevatorTaskProducer returns a new elevator producer object
+// flushSignal channel is used to signal consumer when queue is 80% full or platform.report is received
+func NewElevatorTaskProducer(consumerQueue chan []byte, flushSignal chan string, logger *logrus.Entry) ElevatorTaskProducer {
+	return &elevatorHttpServer{
+		dataQueue:   consumerQueue,
+		logger:      logger,
+		flushSignal: flushSignal,
+	}
+}
+
+// Start starts the HTTP Server for elevator mode
+func (ehs *elevatorHttpServer) Start() error {
+	http.HandleFunc("/", ehs.logsHandler)
+	ehs.logger.Info("Starting Elevator HTTP Server on port ", elevatorReceiverPort)
+	err := http.ListenAndServe(fmt.Sprintf("%s:%d", elevatorReceiverIP, elevatorReceiverPort), nil)
+	if err != nil {
+		ehs.logger.Errorf("Elevator HTTP server failed to start: %v", err)
+		panic(err)
+	}
+	return err
+}
+
+// checkQueueThreshold checks if dataQueue has reached 80% capacity and signals consumer
+func (ehs *elevatorHttpServer) checkQueueThreshold() {
+	queueLen := len(ehs.dataQueue)
+	queueCap := cap(ehs.dataQueue)
+	threshold := int(float64(queueCap) * queueThresholdPercent)
+
+	ehs.logger.Debugf("Elevator Producer: Queue status - Length: %d, Capacity: %d, Threshold: %d", queueLen, queueCap, threshold)
+
+	if queueLen >= threshold {
+		ehs.logger.Infof("Elevator Producer: Queue reached %d%% capacity (%d/%d), signaling consumer to flush",
+			int(queueThresholdPercent*100), queueLen, queueCap)
+		// Send flush signal to consumer (non-blocking)
+		select {
+		case ehs.flushSignal <- "queue_threshold":
+			ehs.logger.Debugf("Elevator Producer: Sent queue_threshold signal to consumer")
+		default:
+			ehs.logger.Warnf("Elevator Producer: Flush signal channel full, signal dropped")
+		}
+	}
+}
+
+// logsHandler is Server Implementation to get Logs from logs API for elevator mode
+func (ehs *elevatorHttpServer) logsHandler(writer http.ResponseWriter, request *http.Request) {
+	if request.URL.Path != "/" {
+		http.NotFound(writer, request)
+		return
+	}
+	switch request.Method {
+	case "POST":
+		defer func() {
+			if err := request.Body.Close(); err != nil {
+				ehs.logger.Errorf("failed to close body: %v", err)
+			}
+		}()
+
+		reqBody, err := ioutil.ReadAll(request.Body)
+		if err != nil {
+			ehs.logger.Error("Read from Logs API failed: ", err.Error())
+			writer.WriteHeader(http.StatusInternalServerError)
+			return
+		}
+
+		ehs.logger.Debugf("Elevator Producer: Producing data into dataQueue - %d bytes\n", len(reqBody))
+		payload := []byte(reqBody)
+
+		// Send payload to dataQueue (non-blocking to prevent deadlock)
+		select {
+		case ehs.dataQueue <- payload:
+			ehs.logger.Debugf("Elevator Producer: Successfully queued data")
+		default:
+			ehs.logger.Warnf("Elevator Producer: dataQueue is full, dropping message")
+		}
+
+		// Check if queue has reached 80% capacity after adding data
+		ehs.checkQueueThreshold()
+
+		// Parse events and check for platform.report
+		var events []Event
+		err = json.Unmarshal(reqBody, &events)
+		if err != nil {
+			log.Printf("Elevator Producer: Error parsing JSON: %v", err)
+		} else {
+			ehs.logger.Debugf("Elevator Producer: Parsed %d events from telemetry payload\n", len(events))
+
+			// Check for platform.report type
+			for _, event := range events {
+				if event.Type == "platform.report" {
+					ehs.logger.Infof("Elevator Producer: Found platform.report event at time: %s\n", event.Time)
+					// Send platform.report signal to consumer (non-blocking)
+					select {
+					case ehs.flushSignal <- "platform.report":
+						ehs.logger.Debugf("Elevator Producer: Sent platform.report signal to consumer")
+					default:
+						ehs.logger.Warnf("Elevator Producer: Flush signal channel full, signal dropped")
+					}
+				}
+			}
+		}
+
+		writer.WriteHeader(http.StatusOK)
+	default:
+		http.Error(writer, "Method not allowed", http.StatusMethodNotAllowed)
+	}
+}