Skip to content

Commit 39241c4

Browse files
committed
Refactor semaphore to limit batches not operations
1 parent 0ebfa98 commit 39241c4

File tree

2 files changed

+131
-40
lines changed

2 files changed

+131
-40
lines changed

sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransactionalBulkWriter.scala

Lines changed: 71 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,16 @@ private class TransactionalBulkWriter
6363
private val verboseLoggingAfterReEnqueueingRetriesEnabled = new AtomicBoolean(false)
6464

6565
private val cpuCount = SparkUtils.getNumberOfHostCPUCores
66-
// each bulk writer allows up to maxPendingOperations being buffered
67-
// there is one bulk writer per spark task/partition
68-
// and default config will create one executor per core on the executor host
69-
// so multiplying by cpuCount in the default config is too aggressive
66+
67+
// NOTE: The public API config property is "maxPendingOperations" for backward compatibility,
68+
// but internally TransactionalBulkWriter limits concurrent *batches* not individual operations.
69+
// This is semantically correct because transactional batches are atomic units.
70+
// We convert maxPendingOperations to maxPendingBatches by assuming ~50 operations per batch.
71+
private val maxPendingBatches = {
72+
val maxOps = writeConfig.bulkMaxPendingOperations.getOrElse(DefaultMaxPendingOperationPerCore)
73+
// Assume average batch size of 50 operations - limit concurrent batches accordingly
74+
Math.max(1, maxOps / 50)
75+
}
7076
private val maxPendingOperations = writeConfig.bulkMaxPendingOperations
7177
.getOrElse(DefaultMaxPendingOperationPerCore)
7278
private val maxConcurrentPartitions = writeConfig.maxConcurrentCosmosPartitions match {
@@ -85,8 +91,8 @@ private class TransactionalBulkWriter
8591
}
8692

8793
log.logInfo(
88-
s"BulkWriter instantiated (Host CPU count: $cpuCount, maxPendingOperations: $maxPendingOperations, " +
89-
s"maxConcurrentPartitions: $maxConcurrentPartitions ...")
94+
s"TransactionalBulkWriter instantiated (Host CPU count: $cpuCount, maxPendingBatches: $maxPendingBatches, " +
95+
s"maxPendingOperations: $maxPendingOperations, maxConcurrentPartitions: $maxConcurrentPartitions ...")
9096

9197

9298
private val closed = new AtomicBoolean(false)
@@ -106,7 +112,11 @@ private class TransactionalBulkWriter
106112

107113
private val activeBulkWriteOperations =java.util.concurrent.ConcurrentHashMap.newKeySet[CosmosItemOperation]().asScala
108114
private val operationContextMap = new java.util.concurrent.ConcurrentHashMap[CosmosItemOperation, OperationContext]().asScala
109-
private val semaphore = new Semaphore(maxPendingOperations)
115+
// Semaphore limits number of outstanding batches (not individual operations)
116+
private val semaphore = new Semaphore(maxPendingBatches)
117+
private val activeBatches = new AtomicInteger(0)
118+
// Map each batch's first operation to the batch size for semaphore release tracking
119+
private val batchSizeMap = new java.util.concurrent.ConcurrentHashMap[CosmosItemOperation, Integer]()
110120

111121
private val totalScheduledMetrics = new AtomicLong(0)
112122
private val totalSuccessfulIngestionMetrics = new AtomicLong(0)
@@ -288,10 +298,15 @@ private class TransactionalBulkWriter
288298
resp => {
289299
val isGettingRetried = new AtomicBoolean(false)
290300
val shouldSkipTaskCompletion = new AtomicBoolean(false)
301+
val isFirstOperationInBatch = new AtomicBoolean(false)
291302
try {
292303
val itemOperation = resp.getOperation
293304
val itemOperationFound = activeBulkWriteOperations.remove(itemOperation)
294305
val pendingRetriesFound = pendingBulkWriteRetries.remove(itemOperation)
306+
307+
// Check if this is the first operation in a batch (used for semaphore release)
308+
val batchSizeOption = Option(batchSizeMap.remove(itemOperation))
309+
isFirstOperationInBatch.set(batchSizeOption.isDefined)
295310

296311
if (pendingRetriesFound) {
297312
pendingRetries.decrementAndGet()
@@ -337,8 +352,12 @@ private class TransactionalBulkWriter
337352
}
338353
}
339354
finally {
340-
if (!isGettingRetried.get) {
355+
// Release semaphore when we process the first operation of a batch
356+
// This indicates the entire batch has been processed
357+
if (!isGettingRetried.get && isFirstOperationInBatch.get) {
358+
activeBatches.decrementAndGet()
341359
semaphore.release()
360+
log.logTrace(s"Released semaphore for completed batch, activeBatches: ${activeBatches.get} $getThreadInfo")
342361
}
343362
}
344363

@@ -368,42 +387,13 @@ private class TransactionalBulkWriter
368387
Preconditions.checkState(!closed.get())
369388
throwIfCapturedExceptionExists()
370389

371-
val activeTasksSemaphoreTimeout = 10
372390
val operationContext = new OperationContext(
373391
getId(objectNode),
374392
partitionKeyValue,
375393
getETag(objectNode),
376394
1,
377395
monotonicOperationCounter.incrementAndGet(),
378396
None)
379-
val numberOfIntervalsWithIdenticalActiveOperationSnapshots = new AtomicLong(0)
380-
// Don't clone the activeOperations for the first iteration
381-
// to reduce perf impact before the Semaphore has been acquired
382-
// this means if the semaphore can't be acquired within 10 minutes
383-
// the first attempt will always assume it wasn't stale - so effectively we
384-
// allow staleness for ten additional minutes - which is perfectly fine
385-
var activeBulkWriteOperationsSnapshot = mutable.Set.empty[CosmosItemOperation]
386-
var pendingBulkWriteRetriesSnapshot = mutable.Set.empty[CosmosItemOperation]
387-
388-
log.logTrace(
389-
s"Before TryAcquire ${totalScheduledMetrics.get}, Context: ${operationContext.toString} $getThreadInfo")
390-
while (!semaphore.tryAcquire(activeTasksSemaphoreTimeout, TimeUnit.MINUTES)) {
391-
log.logDebug(s"Not able to acquire semaphore, Context: ${operationContext.toString} $getThreadInfo")
392-
if (subscriptionDisposable.isDisposed) {
393-
captureIfFirstFailure(
394-
new IllegalStateException("Can't accept any new work - BulkWriter has been disposed already"))
395-
}
396-
397-
throwIfProgressStaled(
398-
"Semaphore acquisition",
399-
activeBulkWriteOperationsSnapshot,
400-
pendingBulkWriteRetriesSnapshot,
401-
numberOfIntervalsWithIdenticalActiveOperationSnapshots,
402-
allowRetryOnNewBulkWriterInstance = false)
403-
404-
activeBulkWriteOperationsSnapshot = activeBulkWriteOperations.clone()
405-
pendingBulkWriteRetriesSnapshot = pendingBulkWriteRetries.clone()
406-
}
407397

408398
val cnt = totalScheduledMetrics.getAndIncrement()
409399
log.logTrace(s"total scheduled $cnt, Context: ${operationContext.toString} $getThreadInfo")
@@ -464,6 +454,36 @@ private class TransactionalBulkWriter
464454
private[this] def flushCurrentBatch(): Unit = {
465455
// Must be called within batchConstructionLock.synchronized
466456
if (currentBatchOperations.nonEmpty && currentPartitionKey != null) {
457+
// Acquire semaphore before emitting batch - this limits concurrent batches
458+
val activeTasksSemaphoreTimeout = 10
459+
val numberOfIntervalsWithIdenticalActiveOperationSnapshots = new AtomicLong(0)
460+
var activeBulkWriteOperationsSnapshot = mutable.Set.empty[CosmosItemOperation]
461+
var pendingBulkWriteRetriesSnapshot = mutable.Set.empty[CosmosItemOperation]
462+
463+
val dummyContext = new OperationContext("batch-flush", currentPartitionKey, null, 1, 0, None)
464+
465+
log.logTrace(s"Before acquiring semaphore for batch emission, activeBatches: ${activeBatches.get} $getThreadInfo")
466+
while (!semaphore.tryAcquire(activeTasksSemaphoreTimeout, TimeUnit.MINUTES)) {
467+
log.logDebug(s"Not able to acquire semaphore for batch, activeBatches: ${activeBatches.get} $getThreadInfo")
468+
if (subscriptionDisposable.isDisposed) {
469+
captureIfFirstFailure(
470+
new IllegalStateException("Can't accept any new work - BulkWriter has been disposed already"))
471+
}
472+
473+
throwIfProgressStaled(
474+
"Batch semaphore acquisition",
475+
activeBulkWriteOperationsSnapshot,
476+
pendingBulkWriteRetriesSnapshot,
477+
numberOfIntervalsWithIdenticalActiveOperationSnapshots,
478+
allowRetryOnNewBulkWriterInstance = false)
479+
480+
activeBulkWriteOperationsSnapshot = activeBulkWriteOperations.clone()
481+
pendingBulkWriteRetriesSnapshot = pendingBulkWriteRetries.clone()
482+
}
483+
484+
activeBatches.incrementAndGet()
485+
log.logTrace(s"Acquired semaphore for batch emission, activeBatches: ${activeBatches.get} $getThreadInfo")
486+
467487
// Build the batch using the builder API
468488
val batch = CosmosBatch.createCosmosBatch(currentPartitionKey)
469489

@@ -478,14 +498,20 @@ private class TransactionalBulkWriter
478498

479499
// After building the batch, get the operations and track them with their contexts
480500
val batchOperations = batch.getOperations()
501+
val batchSize = batchOperations.size()
481502

482503
// Map each operation to its context and add to tracking
483-
for (i <- 0 until batchOperations.size()) {
504+
for (i <- 0 until batchSize) {
484505
val operation = batchOperations.get(i)
485506
val context = contextList(i)
486507
operationContextMap.put(operation, context)
487508
activeBulkWriteOperations.add(operation)
488509
}
510+
511+
// Track batch size using first operation as key - needed for semaphore release
512+
if (batchSize > 0) {
513+
batchSizeMap.put(batchOperations.get(0), batchSize)
514+
}
489515

490516
// Emit the batch
491517
bulkInputEmitter.emitNext(batch, emitFailureHandler)
@@ -800,15 +826,20 @@ private class TransactionalBulkWriter
800826
}
801827

802828
logInfoOrWarning(s"invoking bulkInputEmitter.onComplete(), Context: ${operationContext.toString} $getThreadInfo")
803-
semaphore.release(Math.max(0, activeTasks.get()))
829+
// Release any remaining batch permits
830+
val remainingBatches = activeBatches.get()
831+
if (remainingBatches > 0) {
832+
semaphore.release(remainingBatches)
833+
log.logDebug(s"Released $remainingBatches batch permits during cleanup")
834+
}
804835
bulkInputEmitter.emitComplete(TransactionalBulkWriter.emitFailureHandlerForComplete)
805836

806837
throwIfCapturedExceptionExists()
807838

808839
assume(activeTasks.get() <= 0)
809840

810841
assume(activeBulkWriteOperations.isEmpty)
811-
assume(semaphore.availablePermits() >= maxPendingOperations)
842+
assume(semaphore.availablePermits() >= maxPendingBatches)
812843

813844
if (totalScheduledMetrics.get() != totalSuccessfulIngestionMetrics.get) {
814845
log.logWarning(s"flushAndClose completed with no error but inconsistent total success and " +

sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransactionalBatchITest.scala

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,4 +746,64 @@ class TransactionalBatchITest extends IntegrationSpec
746746
}
747747
}
748748

749+
it should "enforce batch-level backpressure with small maxPendingOperations" in {
750+
val cosmosEndpoint = TestConfigurations.HOST
751+
val cosmosMasterKey = TestConfigurations.MASTER_KEY
752+
val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey)
753+
754+
// Use a small maxPendingOperations value to force batch-level limiting
755+
// With maxPendingOperations=50, maxPendingBatches = 50/50 = 1
756+
// This means only 1 batch should be in-flight at a time
757+
val maxPendingOperations = 50
758+
759+
// Create 200 operations across 4 batches (50 operations per partition key = 1 batch each)
760+
// This will test that the semaphore properly limits concurrent batches
761+
val partitionKeys = (1 to 4).map(_ => UUID.randomUUID().toString)
762+
763+
val schema = StructType(Seq(
764+
StructField("id", StringType, nullable = false),
765+
StructField("pk", StringType, nullable = false),
766+
StructField("value", IntegerType, nullable = false)
767+
))
768+
769+
// Create 50 operations per partition key (forms 1 batch per PK)
770+
val allOperations = partitionKeys.flatMap { pk =>
771+
(1 to 50).map { i =>
772+
Row(s"item-$i-${UUID.randomUUID()}", pk, i)
773+
}
774+
}
775+
776+
val operationsDf = spark.createDataFrame(allOperations.asJava, schema)
777+
778+
// Execute with very small maxPendingOperations to force batch limiting
779+
operationsDf.write
780+
.format("cosmos.oltp")
781+
.option("spark.cosmos.accountEndpoint", cosmosEndpoint)
782+
.option("spark.cosmos.accountKey", cosmosMasterKey)
783+
.option("spark.cosmos.database", cosmosDatabase)
784+
.option("spark.cosmos.container", cosmosContainersWithPkAsPartitionKey)
785+
.option("spark.cosmos.write.bulk.transactional", "true")
786+
.option("spark.cosmos.write.bulk.enabled", "true")
787+
.option("spark.cosmos.write.bulk.maxPendingOperations", maxPendingOperations.toString)
788+
.mode(SaveMode.Append)
789+
.save()
790+
791+
// Verify all items were created successfully
792+
partitionKeys.foreach { pk =>
793+
val queryResult = container
794+
.queryItems(s"SELECT VALUE COUNT(1) FROM c WHERE c.pk = '$pk'", classOf[Long])
795+
.collectList()
796+
.block()
797+
798+
val count = if (queryResult.isEmpty) 0L else queryResult.get(0)
799+
assert(count == 50, s"Expected 50 items for partition key $pk, but found $count")
800+
}
801+
802+
// If we get here without deadlock or timeout, batch-level backpressure is working
803+
// The test verifies:
804+
// 1. Operations complete successfully even with tight batch limit
805+
// 2. No deadlocks occur from semaphore management
806+
// 3. All batches are properly tracked and released
807+
}
808+
749809
}

0 commit comments

Comments
 (0)