Commit 7e5c99b

Add per-row operation type support for transactional batch
1 parent 4f22449 commit 7e5c99b

7 files changed, +374 −28 lines changed

sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/AsyncItemWriter.scala

Lines changed: 8 additions & 0 deletions
@@ -14,6 +14,14 @@ private trait AsyncItemWriter {
    */
   def scheduleWrite(partitionKeyValue: PartitionKey, objectNode: ObjectNode): Unit
 
+  /**
+   * Schedule a write to happen in async and return immediately with per-row operation type
+   * @param partitionKeyValue the partition key value
+   * @param objectNode the json object node
+   * @param operationType optional operation type (create, upsert, replace, delete) for this specific row
+   */
+  def scheduleWrite(partitionKeyValue: PartitionKey, objectNode: ObjectNode, operationType: Option[String]): Unit
+
   /**
    * Wait for all remaining work
    * Throws if any of the work resulted in failure
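Every implementation in this commit keeps the original two-argument overload and delegates to the new three-argument one with None, so existing call sites are unaffected (see the BulkWriter, PointWriter, and TransactionalBulkWriter hunks below). A standalone sketch of that delegation pattern, using simplified stand-in types rather than the real trait:

// Illustration only: simplified stand-ins, not the real AsyncItemWriter trait.
object OverloadDelegationSketch {
  // Old signature delegates with None, preserving existing behavior.
  def scheduleWrite(id: String): Unit =
    scheduleWrite(id, None)

  // New signature: a per-row operation type, when present, overrides the
  // writer's global ItemWriteStrategy.
  def scheduleWrite(id: String, operationType: Option[String]): Unit =
    println(s"${operationType.getOrElse("<global strategy>")} -> $id")

  def main(args: Array[String]): Unit = {
    scheduleWrite("item-1")                 // falls back to the global strategy
    scheduleWrite("item-2", Some("delete")) // per-row override
  }
}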

sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/BulkWriter.scala

Lines changed: 6 additions & 0 deletions
@@ -646,6 +646,12 @@ private class BulkWriter
   }
 
   override def scheduleWrite(partitionKeyValue: PartitionKey, objectNode: ObjectNode): Unit = {
+    scheduleWrite(partitionKeyValue, objectNode, None)
+  }
+
+  override def scheduleWrite(partitionKeyValue: PartitionKey, objectNode: ObjectNode, operationType: Option[String]): Unit = {
+    // BulkWriter doesn't support per-row operation types - it uses global ItemWriteStrategy
+    // The operationType parameter is ignored here for interface compatibility
     Preconditions.checkState(!closed.get())
     throwIfCapturedExceptionExists()

sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala

Lines changed: 1 addition & 1 deletion
@@ -1439,7 +1439,7 @@ private[spark] object DiagnosticsConfig {
 
 private object ItemWriteStrategy extends Enumeration {
   type ItemWriteStrategy = Value
-  val ItemOverwrite, ItemAppend, ItemDelete, ItemDeleteIfNotModified, ItemOverwriteIfNotModified, ItemPatch, ItemPatchIfExists, ItemBulkUpdate = Value
+  val ItemOverwrite, ItemAppend, ItemDelete, ItemDeleteIfNotModified, ItemOverwriteIfNotModified, ItemPatch, ItemPatchIfExists, ItemBulkUpdate, ItemTransactionalBatch = Value
 }
 
 private object CosmosPatchOperationTypes extends Enumeration {
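Assuming the new value is picked up by the connector's existing spark.cosmos.write.strategy option the same way the other enum values are (this hunk only adds the value itself), opting in from a Spark job would look roughly like this, given a DataFrame df:

// Sketch, not part of this commit; endpoint/key/database values are placeholders.
df.write
  .format("cosmos.oltp")
  .option("spark.cosmos.accountEndpoint", "https://<account>.documents.azure.com:443/")
  .option("spark.cosmos.accountKey", "<account-key>")
  .option("spark.cosmos.database", "<database>")
  .option("spark.cosmos.container", "<container>")
  .option("spark.cosmos.write.strategy", "ItemTransactionalBatch") // new value
  .mode("Append")
  .save()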

sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosWriterBase.scala

Lines changed: 25 additions & 2 deletions
@@ -62,7 +62,13 @@ private abstract class CosmosWriterBase(
 
   private val writer: AtomicReference[AsyncItemWriter] = new AtomicReference(
     if (cosmosWriteConfig.bulkEnabled) {
-      if (cosmosWriteConfig.bulkEnableTransactions) {
+      // Use TransactionalBulkWriter if either:
+      // 1. bulkEnableTransactions is explicitly set, OR
+      // 2. ItemWriteStrategy is ItemTransactionalBatch
+      val useTransactional = cosmosWriteConfig.bulkEnableTransactions ||
+        cosmosWriteConfig.itemWriteStrategy == ItemWriteStrategy.ItemTransactionalBatch
+
+      if (useTransactional) {
         new TransactionalBulkWriter(
           container,
           cosmosTargetContainerConfig,
@@ -94,6 +100,23 @@ private abstract class CosmosWriterBase(
   override def write(internalRow: InternalRow): Unit = {
     val objectNode = cosmosRowConverter.fromInternalRowToObjectNode(internalRow, inputSchema)
 
+    // Extract operationType if column exists (for per-row operation support)
+    val operationType: Option[String] = if (inputSchema.fieldNames.contains("operationType")) {
+      val opTypeIndex = inputSchema.fieldIndex("operationType")
+      if (!internalRow.isNullAt(opTypeIndex)) {
+        Some(internalRow.getString(opTypeIndex))
+      } else {
+        None
+      }
+    } else {
+      None
+    }
+
+    // Remove operationType from objectNode if present (don't persist to Cosmos)
+    if (objectNode.has("operationType")) {
+      objectNode.remove("operationType")
+    }
+
     require(objectNode.has(CosmosConstants.Properties.Id) &&
       objectNode.get(CosmosConstants.Properties.Id).isTextual,
       s"${CosmosConstants.Properties.Id} is a mandatory field. " +
@@ -107,7 +130,7 @@ private abstract class CosmosWriterBase(
     }
 
     val partitionKeyValue = PartitionKeyHelper.getPartitionKeyPath(objectNode, partitionKeyDefinition)
-    writer.get.scheduleWrite(partitionKeyValue, objectNode)
+    writer.get.scheduleWrite(partitionKeyValue, objectNode, operationType)
   }
 
   override def commit(): WriterCommitMessage = {
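With the extraction above, a single write can mix operations per row: a non-null operationType value overrides the global strategy for that row, a null falls back to it, and the field is stripped from the JSON before the document is persisted. An illustrative input DataFrame (the pk column name and values are arbitrary):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("per-row-ops").master("local[*]").getOrCreate()
import spark.implicits._

// "operationType" is consumed by the writer and never written to Cosmos DB.
val df = Seq(
  ("id-1", "pk-1", "create"),     // insert; fails if the item already exists
  ("id-2", "pk-1", "upsert"),     // create or overwrite
  ("id-3", "pk-1", "delete"),     // remove by id + partition key
  ("id-4", "pk-1", null: String)  // null -> writer's configured ItemWriteStrategy
).toDF("id", "pk", "operationType")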

sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/PointWriter.scala

Lines changed: 4 additions & 0 deletions
@@ -76,6 +76,10 @@ private class PointWriter(container: CosmosAsyncContainer,
   }
 
   override def scheduleWrite(partitionKeyValue: PartitionKey, objectNode: ObjectNode): Unit = {
+    scheduleWrite(partitionKeyValue, objectNode, None)
+  }
+
+  override def scheduleWrite(partitionKeyValue: PartitionKey, objectNode: ObjectNode, operationType: Option[String]): Unit = {
     checkState(!closed.get())
 
     val etag = getETag(objectNode)

sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransactionalBulkWriter.scala

Lines changed: 65 additions & 25 deletions
@@ -655,6 +655,10 @@ private class TransactionalBulkWriter
   }
 
   override def scheduleWrite(partitionKeyValue: PartitionKey, objectNode: ObjectNode): Unit = {
+    scheduleWrite(partitionKeyValue, objectNode, None)
+  }
+
+  override def scheduleWrite(partitionKeyValue: PartitionKey, objectNode: ObjectNode, operationType: Option[String]): Unit = {
     Preconditions.checkState(!closed.get())
     throwIfCapturedExceptionExists()
@@ -664,7 +668,9 @@ private class TransactionalBulkWriter
       partitionKeyValue,
       getETag(objectNode),
       1,
-      monotonicOperationCounter.incrementAndGet())
+      monotonicOperationCounter.incrementAndGet(),
+      None,
+      operationType)
     val numberOfIntervalsWithIdenticalActiveOperationSnapshots = new AtomicLong(0)
     // Don't clone the activeOperations for the first iteration
     // to reduce perf impact before the Semaphore has been acquired
@@ -739,10 +745,28 @@ private class TransactionalBulkWriter
       objectNode: ObjectNode,
       operationContext: OperationContext): Unit = {
 
-    val bulkItemOperation = writeConfig.itemWriteStrategy match {
-      case ItemWriteStrategy.ItemOverwrite =>
+    // Determine the effective operation type from per-row operationType or global strategy
+    val effectiveOperationType = operationContext.operationType match {
+      case Some(opType) => opType.toLowerCase
+      case None => writeConfig.itemWriteStrategy match {
+        case ItemWriteStrategy.ItemOverwrite => "upsert"
+        case ItemWriteStrategy.ItemAppend => "create"
+        case ItemWriteStrategy.ItemDelete | ItemWriteStrategy.ItemDeleteIfNotModified => "delete"
+        case ItemWriteStrategy.ItemOverwriteIfNotModified => "replace"
+        case ItemWriteStrategy.ItemTransactionalBatch => "upsert" // Default to upsert for transactional batch
+        case ItemWriteStrategy.ItemPatch | ItemWriteStrategy.ItemPatchIfExists => "patch"
+        case _ => throw new RuntimeException(s"${writeConfig.itemWriteStrategy} not supported for transactional batch")
+      }
+    }
+
+    val bulkItemOperation = effectiveOperationType match {
+      case "create" =>
+        CosmosBulkOperations.getCreateItemOperation(objectNode, partitionKeyValue, operationContext)
+
+      case "upsert" =>
         CosmosBulkOperations.getUpsertItemOperation(objectNode, partitionKeyValue, operationContext)
-      case ItemWriteStrategy.ItemOverwriteIfNotModified =>
+
+      case "replace" =>
         operationContext.eTag match {
           case Some(eTag) =>
             CosmosBulkOperations.getReplaceItemOperation(
@@ -751,29 +775,41 @@ private class TransactionalBulkWriter
               partitionKeyValue,
               new CosmosBulkItemRequestOptions().setIfMatchETag(eTag),
               operationContext)
-          case _ => CosmosBulkOperations.getCreateItemOperation(objectNode, partitionKeyValue, operationContext)
+          case _ =>
+            CosmosBulkOperations.getReplaceItemOperation(
+              operationContext.itemId,
+              objectNode,
+              partitionKeyValue,
+              operationContext)
         }
-      case ItemWriteStrategy.ItemAppend =>
-        CosmosBulkOperations.getCreateItemOperation(objectNode, partitionKeyValue, operationContext)
-      case ItemWriteStrategy.ItemDelete =>
-        CosmosBulkOperations.getDeleteItemOperation(operationContext.itemId, partitionKeyValue, operationContext)
-      case ItemWriteStrategy.ItemDeleteIfNotModified =>
-        CosmosBulkOperations.getDeleteItemOperation(
+
+      case "delete" =>
+        operationContext.eTag match {
+          case Some(eTag) =>
+            CosmosBulkOperations.getDeleteItemOperation(
+              operationContext.itemId,
+              partitionKeyValue,
+              new CosmosBulkItemRequestOptions().setIfMatchETag(eTag),
+              operationContext)
+          case _ =>
+            CosmosBulkOperations.getDeleteItemOperation(
+              operationContext.itemId,
+              partitionKeyValue,
+              operationContext)
+        }
+
+      case "patch" =>
+        getPatchItemOperation(
           operationContext.itemId,
           partitionKeyValue,
-          operationContext.eTag match {
-            case Some(eTag) => new CosmosBulkItemRequestOptions().setIfMatchETag(eTag)
-            case _ => new CosmosBulkItemRequestOptions()
-          },
+          partitionKeyDefinition,
+          objectNode,
           operationContext)
-      case ItemWriteStrategy.ItemPatch | ItemWriteStrategy.ItemPatchIfExists => getPatchItemOperation(
-        operationContext.itemId,
-        partitionKeyValue,
-        partitionKeyDefinition,
-        objectNode,
-        operationContext)
+
       case _ =>
-        throw new RuntimeException(s"${writeConfig.itemWriteStrategy} not supported")
+        throw new IllegalArgumentException(
+          s"Unsupported operationType '$effectiveOperationType'. " +
+          s"Supported types for transactional batch: create, upsert, replace, delete")
     }
 
     this.emitBulkInput(bulkItemOperation)
@@ -1363,9 +1399,10 @@ private class TransactionalBulkWriter
     val attemptNumber: Int,
     val sequenceNumber: Long,
     /** starts from 1 * */
-    sourceItemInput: Option[ObjectNode] = None) // for patchBulkUpdate: source item refers to the original objectNode from which SDK constructs the final bulk item operation
+    sourceItemInput: Option[ObjectNode] = None, // for patchBulkUpdate: source item refers to the original objectNode from which SDK constructs the final bulk item operation
+    operationTypeInput: Option[String] = None) // per-row operation type (create, upsert, replace, delete)
   {
-    private val ctxCore: OperationContextCore = OperationContextCore(itemIdInput, partitionKeyValueInput, eTagInput, sourceItemInput)
+    private val ctxCore: OperationContextCore = OperationContextCore(itemIdInput, partitionKeyValueInput, eTagInput, sourceItemInput, operationTypeInput)
 
     override def equals(obj: Any): Boolean = ctxCore.equals(obj)
 
@@ -1382,14 +1419,17 @@ private class TransactionalBulkWriter
     def eTag: Option[String] = ctxCore.eTag
 
     def sourceItem: Option[ObjectNode] = ctxCore.sourceItem
+
+    def operationType: Option[String] = ctxCore.operationType
   }
 
   private case class OperationContextCore
   (
     itemId: String,
     partitionKeyValue: PartitionKey,
    eTag: Option[String],
-    sourceItem: Option[ObjectNode] = None) // for patchBulkUpdate: source item refers to the original objectNode from which SDK constructs the final bulk item operation
+    sourceItem: Option[ObjectNode] = None, // for patchBulkUpdate: source item refers to the original objectNode from which SDK constructs the final bulk item operation
+    operationType: Option[String] = None) // per-row operation type (create, upsert, replace, delete)
   {
     override def productPrefix: String = "OperationContext"
   }
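Net effect of the dispatch above: the per-row value, lower-cased, always wins; otherwise the global ItemWriteStrategy is translated to a default verb. A standalone mirror of that resolution order, for illustration only (the helper function is hypothetical; the strategy names are the real enum values):

def resolveOperationType(perRow: Option[String], strategy: String): String =
  perRow.map(_.toLowerCase).getOrElse(strategy match {
    case "ItemOverwrite" | "ItemTransactionalBatch" => "upsert"
    case "ItemAppend"                               => "create"
    case "ItemDelete" | "ItemDeleteIfNotModified"   => "delete"
    case "ItemOverwriteIfNotModified"               => "replace"
    case "ItemPatch" | "ItemPatchIfExists"          => "patch"
    case other =>
      throw new RuntimeException(s"$other not supported for transactional batch")
  })

assert(resolveOperationType(Some("Delete"), "ItemOverwrite") == "delete") // per-row wins
assert(resolveOperationType(None, "ItemAppend") == "create")              // global fallback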
