diff --git a/packages/kafka/lib/AbstractKafkaConsumer.ts b/packages/kafka/lib/AbstractKafkaConsumer.ts
index a64d5e0d..cc69a009 100644
--- a/packages/kafka/lib/AbstractKafkaConsumer.ts
+++ b/packages/kafka/lib/AbstractKafkaConsumer.ts
@@ -149,11 +149,12 @@ export abstract class AbstractKafkaConsumer<
     if (!this.consumerStream && !this.messageBatchStream) return false
     try {
       return this.consumer.isConnected()
+      /* v8 ignore start */
     } catch (_) {
       // this should not happen, but if so it means the consumer is not healthy
-      /* v8 ignore next */
       return false
     }
+    /* v8 ignore stop */
   }
 
   /**
@@ -165,11 +166,12 @@ export abstract class AbstractKafkaConsumer<
     if (!this.consumerStream && !this.messageBatchStream) return false
     try {
       return this.consumer.isActive()
+      /* v8 ignore start */
     } catch (_) {
       // this should not happen, but if so it means the consumer is not healthy
-      /* v8 ignore next */
       return false
     }
+    /* v8 ignore stop */
   }
 
   async init(): Promise<void> {
@@ -188,14 +190,19 @@ export abstract class AbstractKafkaConsumer<
       })
 
       this.consumerStream = await this.consumer.consume({ ...consumeOptions, topics })
+      this.consumerStream.on('error', (error) => this.handlerError(error))
+
       if (this.options.batchProcessingEnabled && this.options.batchProcessingOptions) {
         this.messageBatchStream = new KafkaMessageBatchStream<
           DeserializedMessage<SupportedMessageValues<TopicsConfig>>
-        >({
-          batchSize: this.options.batchProcessingOptions.batchSize,
-          timeoutMilliseconds: this.options.batchProcessingOptions.timeoutMilliseconds,
-        })
+        >(
+          (batch) =>
+            this.consume(batch.topic, batch.messages).catch((error) => this.handlerError(error)),
+          this.options.batchProcessingOptions,
+        )
         this.consumerStream.pipe(this.messageBatchStream)
+      } else {
+        this.handleSyncStream(this.consumerStream).catch((error) => this.handlerError(error))
       }
     } catch (error) {
       throw new InternalError({
@@ -204,14 +211,6 @@ export abstract class AbstractKafkaConsumer<
         cause: error,
       })
     }
-
-    if (this.options.batchProcessingEnabled && this.messageBatchStream) {
-      this.handleSyncStreamBatch(this.messageBatchStream).catch((error) => this.handlerError(error))
-    } else {
-      this.handleSyncStream(this.consumerStream).catch((error) => this.handlerError(error))
-    }
-
-    this.consumerStream.on('error', (error) => this.handlerError(error))
   }
 
   private async handleSyncStream(
@@ -224,16 +223,6 @@ export abstract class AbstractKafkaConsumer<
       )
     }
   }
-  private async handleSyncStreamBatch(
-    stream: KafkaMessageBatchStream<DeserializedMessage<SupportedMessageValues<TopicsConfig>>>,
-  ): Promise<void> {
-    for await (const messageBatch of stream) {
-      await this.consume(
-        messageBatch.topic,
-        messageBatch.messages as DeserializedMessage<SupportedMessageValues<TopicsConfig>>,
-      )
-    }
-  }
 
   async close(): Promise<void> {
     if (!this.consumerStream && !this.messageBatchStream) {
@@ -291,7 +280,6 @@ export abstract class AbstractKafkaConsumer<
     const firstMessage = validMessages[0]!
 
     const requestContext = this.getRequestContext(firstMessage)
-    /* v8 ignore next */
     const transactionId = randomUUID()
     this.transactionObservabilityManager?.start(this.buildTransactionName(topic), transactionId)
 
diff --git a/packages/kafka/lib/utils/KafkaMessageBatchStream.spec.ts b/packages/kafka/lib/utils/KafkaMessageBatchStream.spec.ts
index 8e604e5a..ccc88e03 100644
--- a/packages/kafka/lib/utils/KafkaMessageBatchStream.spec.ts
+++ b/packages/kafka/lib/utils/KafkaMessageBatchStream.spec.ts
@@ -1,3 +1,5 @@
+import { setTimeout } from 'node:timers/promises'
+import { waitAndRetry } from '@lokalise/universal-ts-utils/node'
 import { KafkaMessageBatchStream, type MessageBatch } from './KafkaMessageBatchStream.ts'
 
 describe('KafkaMessageBatchStream', () => {
@@ -12,22 +14,27 @@ describe('KafkaMessageBatchStream', () => {
     }))
 
     // When
-    const batchStream = new KafkaMessageBatchStream({
-      batchSize: 3,
-      timeoutMilliseconds: 10000,
-    }) // Setting big timeout to check batch size only
     const receivedBatches: MessageBatch[] = []
-    const dataFetchingPromise = new Promise((resolve) => {
-      batchStream.on('data', (batch) => {
+    let resolvePromise: () => void
+    const dataFetchingPromise = new Promise<void>((resolve) => {
+      resolvePromise = resolve
+    })
+
+    const batchStream = new KafkaMessageBatchStream(
+      (batch) => {
         receivedBatches.push(batch)
-        // We expect 3 batches and last message waiting in the stream
+        // We expect 3 batches and the last message waiting in the stream
         if (receivedBatches.length >= 3) {
-          resolve(null)
+          resolvePromise()
         }
-      })
-    })
+        return Promise.resolve()
+      },
+      {
+        batchSize: 3,
+        timeoutMilliseconds: 10000,
+      },
+    ) // Setting big timeout to check batch size only
 
     for (const message of messages) {
       batchStream.write(message)
@@ -54,24 +61,25 @@ describe('KafkaMessageBatchStream', () => {
     }))
 
     // When
-    const batchStream = new KafkaMessageBatchStream({
-      batchSize: 1000,
-      timeoutMilliseconds: 500,
-    }) // Setting big batch size to check timeout only
     const receivedBatches: MessageBatch[] = []
-    batchStream.on('data', (batch) => {
-      receivedBatches.push(batch)
-    })
+
+    const batchStream = new KafkaMessageBatchStream(
+      (batch) => {
+        receivedBatches.push(batch)
+        return Promise.resolve()
+      },
+      {
+        batchSize: 1000,
+        timeoutMilliseconds: 100,
+      },
+    ) // Setting big batch size to check timeout only
 
     for (const message of messages) {
       batchStream.write(message)
     }
 
-    // Sleep 1 seconds to let the timeout trigger
-    await new Promise((resolve) => {
-      setTimeout(resolve, 1000)
-    })
+    // Sleep to let the timeout trigger
+    await setTimeout(150)
 
     // Then
     expect(receivedBatches).toEqual([{ topic, partition: 0, messages }])
@@ -104,16 +112,16 @@ describe('KafkaMessageBatchStream', () => {
     ]
 
     // When
-    const batchStream = new KafkaMessageBatchStream<{ topic: string; partition: number }>({
-      batchSize: 2,
-      timeoutMilliseconds: 10000,
-    }) // Setting big timeout to check batch size only
     const receivedBatchesByTopicPartition: Record<string, MessageBatch[]> = {}
     let receivedMessagesCounter = 0
-    const dataFetchingPromise = new Promise((resolve) => {
-      batchStream.on('data', (batch) => {
+
+    let resolvePromise: () => void
+    const dataFetchingPromise = new Promise<void>((resolve) => {
+      resolvePromise = resolve
+    })
+
+    const batchStream = new KafkaMessageBatchStream<{ topic: string; partition: number }>(
+      (batch) => {
         const key = `${batch.topic}:${batch.partition}`
         if (!receivedBatchesByTopicPartition[key]) {
           receivedBatchesByTopicPartition[key] = []
         }
@@ -123,10 +131,16 @@ describe('KafkaMessageBatchStream', () => {
         receivedBatchesByTopicPartition[key].push(batch)
         // We expect 5 batches and last message waiting in the stream
         receivedMessagesCounter++
         if (receivedMessagesCounter >= 5) {
-          resolve(null)
+          resolvePromise()
         }
-      })
-    })
+
+        return Promise.resolve()
+      },
+      {
+        batchSize: 2,
+        timeoutMilliseconds: 10000,
+      },
+    ) // Setting big timeout to check batch size only
 
     for (const message of messages) {
       batchStream.write(message)
@@ -177,25 +191,31 @@ describe('KafkaMessageBatchStream', () => {
     ]
 
     // When
-    const batchStream = new KafkaMessageBatchStream<{ topic: string; partition: number }>({
-      batchSize: 2,
-      timeoutMilliseconds: 10000,
-    }) // Setting big timeout to check batch size only
     const receivedBatches: any[] = []
     let receivedBatchesCounter = 0
-    const dataFetchingPromise = new Promise((resolve) => {
-      batchStream.on('data', (batch) => {
+
+    let resolvePromise: () => void
+    const dataFetchingPromise = new Promise<void>((resolve) => {
+      resolvePromise = resolve
+    })
+
+    const batchStream = new KafkaMessageBatchStream<{ topic: string; partition: number }>(
+      (batch) => {
         receivedBatches.push(batch)
         // We expect 4 batches (2 per partition)
         receivedBatchesCounter++
         if (receivedBatchesCounter >= 4) {
-          resolve(null)
+          resolvePromise()
         }
-      })
-    })
+
+        return Promise.resolve()
+      },
+      {
+        batchSize: 2,
+        timeoutMilliseconds: 10000,
+      },
+    ) // Setting big timeout to check batch size only
 
     for (const message of messages) {
       batchStream.write(message)
@@ -211,4 +231,69 @@ describe('KafkaMessageBatchStream', () => {
       { topic, partition: 1, messages: [messages[5], messages[7]] },
     ])
   })
+
+  it('should handle backpressure correctly when timeout flush is slow', async () => {
+    // Given
+    const topic = 'test-topic'
+    const messages = Array.from({ length: 6 }, (_, i) => ({
+      id: i + 1,
+      content: `Message ${i + 1}`,
+      topic,
+      partition: 0,
+    }))
+
+    const batchStartTimes: number[] = [] // Track start times of batch processing
+    const batchEndTimes: number[] = [] // Track end times of batch processing
+    const batchMessageCounts: number[] = [] // Track number of messages per batch
+    let maxConcurrentBatches = 0 // Track max concurrent batches
+
+    let batchesProcessing = 0
+    const batchStream = new KafkaMessageBatchStream(
+      async (batch) => {
+        batchStartTimes.push(Date.now())
+        batchMessageCounts.push(batch.messages.length)
+
+        batchesProcessing++
+        maxConcurrentBatches = Math.max(maxConcurrentBatches, batchesProcessing)
+
+        // Simulate batch processing (50ms per batch)
+        await setTimeout(50)
+
+        batchEndTimes.push(Date.now())
+        batchesProcessing--
+      },
+      {
+        batchSize: 1000, // Large batch size to never trigger size-based flushing
+        timeoutMilliseconds: 10, // Short timeout to trigger flush after each message
+      },
+    )
+
+    // When: Write messages with a 20ms delay between them
+    // Since processing (50ms) is slower than message arrival + timeout, backpressure causes accumulation
+    for (const message of messages) {
+      batchStream.write(message)
+      await setTimeout(20)
+    }
+
+    // Then
+    // Wait until all 3 batches have been processed
+    await waitAndRetry(() => batchMessageCounts.length >= 3, 500, 20)
+
+    // Backpressure causes messages to accumulate while the previous batch processes:
+    // - Batch 1: Message 1 (flushed at 10ms timeout)
+    // - Batch 2: Messages 2-4 (accumulated during Batch 1 processing, including Message 4 arriving at ~60ms)
+    // - Batch 3: Messages 5-6 (accumulated during Batch 2 processing)
+    expect(batchMessageCounts).toEqual([1, 3, 2])
+
+    // Verify that batches were never processed in parallel (backpressure working)
+    expect(maxConcurrentBatches).toBe(1) // Should never process more than 1 batch at a time
+
+    // Verify that batches were processed sequentially (each starts after the previous ends)
+    for (let i = 1; i < batchStartTimes.length; i++) {
+      const previousEndTime = batchEndTimes[i - 1]
+      const currentStartTime = batchStartTimes[i]
+      // The current batch must start after the previous batch finished
+      expect(currentStartTime).toBeGreaterThanOrEqual(previousEndTime ?? 0)
+    }
+  })
 })
diff --git a/packages/kafka/lib/utils/KafkaMessageBatchStream.ts b/packages/kafka/lib/utils/KafkaMessageBatchStream.ts
index 5ae29df8..a6c43d0a 100644
--- a/packages/kafka/lib/utils/KafkaMessageBatchStream.ts
+++ b/packages/kafka/lib/utils/KafkaMessageBatchStream.ts
@@ -1,6 +1,4 @@
-import { Duplex } from 'node:stream'
-
-type CallbackFunction = (error?: Error | null) => void
+import { Transform } from 'node:stream'
 
 // Topic and partition are required for the stream to work properly
 type MessageWithTopicAndPartition = { topic: string; partition: number }
@@ -11,103 +9,114 @@ export type KafkaMessageBatchOptions = {
 }
 
 export type MessageBatch<TMessage = unknown> = { topic: string; partition: number; messages: TMessage[] }
-
-export interface KafkaMessageBatchStream<TMessage extends MessageWithTopicAndPartition>
-  extends Duplex {
-  // biome-ignore lint/suspicious/noExplicitAny: compatible with Duplex definition
-  on(event: string | symbol, listener: (...args: any[]) => void): this
-  on(event: 'data', listener: (chunk: MessageBatch<TMessage>) => void): this
-
-  push(chunk: MessageBatch<TMessage> | null): boolean
-}
+export type OnMessageBatchCallback<TMessage> = (batch: MessageBatch<TMessage>) => Promise<void>
 
 /**
  * Collects messages in batches based on provided batchSize and flushes them when messages amount or timeout is reached.
+ *
+ * This implementation uses a Transform stream, which handles backpressure by design.
+ * When the downstream consumer is slow, the stream automatically pauses accepting new messages
+ * until the consumer catches up, preventing memory leaks and OOM errors.
  */
-// biome-ignore lint/suspicious/noUnsafeDeclarationMerging: merging interface with class to add strong typing for 'data' event
-export class KafkaMessageBatchStream<TMessage extends MessageWithTopicAndPartition> extends Duplex {
+export class KafkaMessageBatchStream<
+  TMessage extends MessageWithTopicAndPartition,
+> extends Transform {
+  private readonly onBatch: OnMessageBatchCallback<TMessage>
   private readonly batchSize: number
   private readonly timeout: number
   private readonly currentBatchPerTopicPartition: Record<string, TMessage[]>
   private readonly batchTimeoutPerTopicPartition: Record<string, NodeJS.Timeout | undefined>
 
-  constructor(options: { batchSize: number; timeoutMilliseconds: number }) {
+  private readonly timeoutProcessingPromises: Map<string, Promise<void>> = new Map()
+
+  constructor(
+    onBatch: OnMessageBatchCallback<TMessage>,
+    options: { batchSize: number; timeoutMilliseconds: number },
+  ) {
     super({ objectMode: true })
+    this.onBatch = onBatch
     this.batchSize = options.batchSize
     this.timeout = options.timeoutMilliseconds
    this.currentBatchPerTopicPartition = {}
     this.batchTimeoutPerTopicPartition = {}
   }
 
-  override _read() {
-    // No-op, as we push data when we have a full batch or timeout
-  }
-
-  override _write(message: TMessage, _encoding: BufferEncoding, callback: CallbackFunction) {
-    const key = this.getTopicPartitionKey(message.topic, message.partition)
-
-    if (!this.currentBatchPerTopicPartition[key]) {
-      this.currentBatchPerTopicPartition[key] = [message]
-    } else {
-      // biome-ignore lint/style/noNonNullAssertion: non-existing entry is handled above
-      this.currentBatchPerTopicPartition[key]!.push(message)
+  override async _transform(message: TMessage, _encoding: BufferEncoding, callback: () => void) {
+    const key = getTopicPartitionKey(message.topic, message.partition)
+
+    // Wait for all pending timeout flushes to complete to maintain backpressure
+    if (this.timeoutProcessingPromises.size > 0) {
+      // Capture a snapshot of current promises to avoid race conditions with new timeouts
+      const promiseEntries = Array.from(this.timeoutProcessingPromises.entries())
+      // Wait for all to complete and then clean up from the map
+      await Promise.all(
+        promiseEntries.map(([k, p]) => p.finally(() => this.timeoutProcessingPromises.delete(k))),
+      )
     }
 
-    // biome-ignore lint/style/noNonNullAssertion: we ensure above that the array is defined
-    if (this.currentBatchPerTopicPartition[key]!.length >= this.batchSize) {
-      this.flushCurrentBatchMessages(message.topic, message.partition)
-      return callback(null)
+    // Accumulate the message
+    if (!this.currentBatchPerTopicPartition[key]) this.currentBatchPerTopicPartition[key] = []
+    this.currentBatchPerTopicPartition[key].push(message)
+
+    // Check if the batch is complete by size
+    if (this.currentBatchPerTopicPartition[key].length >= this.batchSize) {
+      await this.flushCurrentBatchMessages(message.topic, message.partition)
+      callback()
+      return
     }
 
+    // Start timeout for this partition if not already started
     if (!this.batchTimeoutPerTopicPartition[key]) {
-      this.batchTimeoutPerTopicPartition[key] = setTimeout(() => {
-        this.flushCurrentBatchMessages(message.topic, message.partition)
-      }, this.timeout)
+      this.batchTimeoutPerTopicPartition[key] = setTimeout(
+        () =>
+          this.timeoutProcessingPromises.set(
+            key,
+            this.flushCurrentBatchMessages(message.topic, message.partition),
+          ),
+        this.timeout,
+      )
     }
 
-    callback(null)
+    callback()
   }
 
-  // Write side is closed, flush the remaining messages
-  override _final(callback: CallbackFunction) {
-    this.flushAllBatches()
-    this.push(null) // End readable side
+  // Flush all remaining batches when the stream is closing
+  override async _flush(callback: () => void) {
+    await this.flushAllBatches()
     callback()
   }
 
-  private flushAllBatches() {
+  private async flushAllBatches() {
     for (const key of Object.keys(this.currentBatchPerTopicPartition)) {
-      const { topic, partition } = this.splitTopicPartitionKey(key)
-      this.flushCurrentBatchMessages(topic, partition)
+      const { topic, partition } = splitTopicPartitionKey(key)
+      await this.flushCurrentBatchMessages(topic, partition)
     }
   }
 
-  private flushCurrentBatchMessages(topic: string, partition: number) {
-    const key = this.getTopicPartitionKey(topic, partition)
+  private async flushCurrentBatchMessages(topic: string, partition: number) {
+    const key = getTopicPartitionKey(topic, partition)
 
+    // Clear timeout
     if (this.batchTimeoutPerTopicPartition[key]) {
       clearTimeout(this.batchTimeoutPerTopicPartition[key])
       this.batchTimeoutPerTopicPartition[key] = undefined
     }
 
-    if (!this.currentBatchPerTopicPartition[key]?.length) {
-      return
-    }
+    const messages = this.currentBatchPerTopicPartition[key] ?? []
 
-    this.push({ topic, partition, messages: this.currentBatchPerTopicPartition[key] })
+    // Push the batch downstream
+    await this.onBatch({ topic, partition, messages })
     this.currentBatchPerTopicPartition[key] = []
   }
+}
 
-  private getTopicPartitionKey(topic: string, partition: number): string {
-    return `${topic}:${partition}`
-  }
+const getTopicPartitionKey = (topic: string, partition: number): string => `${topic}:${partition}`
 
+const splitTopicPartitionKey = (key: string): { topic: string; partition: number } => {
+  const [topic, partition] = key.split(':')
+  /* v8 ignore start */
+  if (!topic || !partition) throw new Error('Invalid topic-partition key format')
+  /* v8 ignore stop */
 
-  private splitTopicPartitionKey(key: string): { topic: string; partition: number } {
-    const [topic, partition] = key.split(':')
-    if (!topic || !partition) {
-      throw new Error('Invalid topic-partition key format')
-    }
-    return { topic, partition: Number.parseInt(partition, 10) }
-  }
+  return { topic, partition: Number.parseInt(partition, 10) }
 }
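For reviewers unfamiliar with the reworked API: below is a minimal usage sketch of the new callback-based `KafkaMessageBatchStream`, based only on the constructor signature and stream behavior shown in this diff. The `DemoMessage` shape, topic name, and logging are illustrative assumptions, not part of the package.

```ts
import { setTimeout } from 'node:timers/promises'
import { KafkaMessageBatchStream } from './KafkaMessageBatchStream.ts'

// Hypothetical message shape; only `topic` and `partition` are required by the stream
type DemoMessage = { topic: string; partition: number; payload: string }

const batchStream = new KafkaMessageBatchStream<DemoMessage>(
  async (batch) => {
    // Simulate slow downstream processing; because _transform awaits pending
    // timeout flushes, writes queue up instead of flooding this callback
    await setTimeout(50)
    console.log(`processed ${batch.messages.length} message(s) from ${batch.topic}:${batch.partition}`)
  },
  { batchSize: 100, timeoutMilliseconds: 250 },
)

for (let i = 0; i < 10; i++) {
  batchStream.write({ topic: 'demo-topic', partition: 0, payload: `message ${i}` })
}

// Ending the writable side triggers _flush, which drains any partially filled batches
batchStream.end()
```

Passing the batch handler into the constructor (rather than listening to `'data'` events) is what lets the Transform stream await the handler and propagate backpressure to writers, as exercised by the new backpressure test above.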