
Commit 4f22449

Cosmos Spark transactional batch support

1 parent 4eb4645 · commit 4f22449

File tree: 7 files changed, +3429 −11 lines


sdk/cosmos/azure-cosmos-spark_3-5/src/main/scala/com/azure/cosmos/spark/ItemsWriterBuilder.scala

Lines changed: 93 additions & 2 deletions

```diff
@@ -2,11 +2,14 @@
 // Licensed under the MIT License.
 package com.azure.cosmos.spark
 
+import com.azure.cosmos.{CosmosAsyncClient, ReadConsistencyStrategy, SparkBridgeInternal}
 import com.azure.cosmos.spark.diagnostics.LoggerHelper
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
+import org.apache.spark.sql.connector.expressions.{Expression, Expressions, NullOrdering, SortDirection, SortOrder}
 import org.apache.spark.sql.connector.metric.CustomMetric
 import org.apache.spark.sql.connector.write.streaming.StreamingWrite
-import org.apache.spark.sql.connector.write.{BatchWrite, Write, WriteBuilder}
+import org.apache.spark.sql.connector.write.{BatchWrite, RequiresDistributionAndOrdering, Write, WriteBuilder}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
@@ -46,7 +49,7 @@ private class ItemsWriterBuilder
       diagnosticsConfig,
       sparkEnvironmentInfo)
 
-  private class CosmosWrite extends Write {
+  private class CosmosWrite extends Write with RequiresDistributionAndOrdering {
 
     private[this] val supportedCosmosMetrics: Array[CustomMetric] = {
       Array(
@@ -56,6 +59,15 @@ private class ItemsWriterBuilder
       )
     }
 
+    private[this] val writeConfig = CosmosWriteConfig.parseWriteConfig(
+      userConfig.asCaseSensitiveMap().asScala.toMap,
+      inputSchema
+    )
+
+    private[this] val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(
+      userConfig.asCaseSensitiveMap().asScala.toMap
+    )
+
     override def toBatch(): BatchWrite =
       new ItemsBatchWriter(
         userConfig.asCaseSensitiveMap().asScala.toMap,
@@ -73,5 +85,84 @@ private class ItemsWriterBuilder
         sparkEnvironmentInfo)
 
     override def supportedCustomMetrics(): Array[CustomMetric] = supportedCosmosMetrics
+
+    override def requiredDistribution(): Distribution = {
+      if (writeConfig.bulkEnabled && writeConfig.bulkEnableTransactions) {
+        // For transactional writes, partition by all partition key columns
+        val partitionKeyPaths = getPartitionKeyColumnNames()
+        if (partitionKeyPaths.nonEmpty) {
+          // Use public Expressions.column() factory - returns NamedReference
+          val clustering = partitionKeyPaths.map(path => Expressions.column(path): Expression).toArray
+          Distributions.clustered(clustering)
+        } else {
+          Distributions.unspecified()
+        }
+      } else {
+        Distributions.unspecified()
+      }
+    }
+
+    override def requiredOrdering(): Array[SortOrder] = {
+      if (writeConfig.bulkEnabled && writeConfig.bulkEnableTransactions) {
+        // For transactional writes, order by all partition key columns (ascending)
+        val partitionKeyPaths = getPartitionKeyColumnNames()
+        if (partitionKeyPaths.nonEmpty) {
+          partitionKeyPaths.map { path =>
+            // Use public Expressions.sort() factory for creating SortOrder
+            Expressions.sort(
+              Expressions.column(path),
+              SortDirection.ASCENDING,
+              NullOrdering.NULLS_FIRST
+            )
+          }.toArray
+        } else {
+          Array.empty[SortOrder]
+        }
+      } else {
+        Array.empty[SortOrder]
+      }
+    }
+
+    private def getPartitionKeyColumnNames(): Seq[String] = {
+      try {
+        // Need to create a temporary container client to get partition key definition
+        val clientCacheItem = CosmosClientCache(
+          CosmosClientConfiguration(
+            userConfig.asCaseSensitiveMap().asScala.toMap,
+            ReadConsistencyStrategy.EVENTUAL,
+            sparkEnvironmentInfo
+          ),
+          Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches),
+          "ItemsWriterBuilder-PKLookup"
+        )
+
+        val container = ThroughputControlHelper.getContainer(
+          userConfig.asCaseSensitiveMap().asScala.toMap,
+          containerConfig,
+          clientCacheItem,
+          None
+        )
+
+        val containerProperties = SparkBridgeInternal.getContainerPropertiesFromCollectionCache(container)
+        val partitionKeyDefinition = containerProperties.getPartitionKeyDefinition
+
+        // Release the client
+        clientCacheItem.close()
+
+        if (partitionKeyDefinition != null && partitionKeyDefinition.getPaths != null) {
+          val paths = partitionKeyDefinition.getPaths.asScala
+          paths.map(path => {
+            // Remove leading '/' from partition key path (e.g., "/pk" -> "pk")
+            if (path.startsWith("/")) path.substring(1) else path
+          }).toSeq
+        } else {
+          Seq.empty[String]
+        }
+      } catch {
+        case ex: Exception =>
+          log.logWarning(s"Failed to get partition key definition for transactional writes: ${ex.getMessage}")
+          Seq.empty[String]
+      }
+    }
   }
 }
```
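Implementing `RequiresDistributionAndOrdering` makes Spark shuffle the input so rows are clustered by the container's partition key columns and sorted within each task before the data source writer runs; that is what lets the transactional writer see all rows for a given logical partition key contiguously. A minimal usage sketch of opting in from a job (account endpoint, key, database, and container values are placeholders; only the two `spark.cosmos.write.bulk.*` options are tied to this commit):

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical job - endpoint, key, database, and container are placeholders.
val spark = SparkSession.builder().appName("cosmos-tx-write").getOrCreate()
val df = spark.createDataFrame(Seq(("id-1", "pk-1", 1), ("id-2", "pk-1", 2)))
  .toDF("id", "pk", "value")

df.write
  .format("cosmos.oltp")
  .option("spark.cosmos.accountEndpoint", "https://<account>.documents.azure.com:443/")
  .option("spark.cosmos.accountKey", "<account-key>")
  .option("spark.cosmos.database", "<database>")
  .option("spark.cosmos.container", "<container>")
  .option("spark.cosmos.write.bulk.enabled", "true")            // prerequisite for transactions
  .option("spark.cosmos.write.bulk.enableTransactions", "true") // flag added in this commit
  .mode("append")
  .save()
```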

sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala

Lines changed: 11 additions & 0 deletions

```diff
@@ -113,6 +113,7 @@ private[spark] object CosmosConfigNames {
   val ClientTelemetryEnabled = "spark.cosmos.clientTelemetry.enabled" // keep this to avoid breaking changes
   val ClientTelemetryEndpoint = "spark.cosmos.clientTelemetry.endpoint" // keep this to avoid breaking changes
   val WriteBulkEnabled = "spark.cosmos.write.bulk.enabled"
+  val WriteBulkEnableTransactions = "spark.cosmos.write.bulk.enableTransactions"
   val WriteBulkMaxPendingOperations = "spark.cosmos.write.bulk.maxPendingOperations"
   val WriteBulkMaxBatchSize = "spark.cosmos.write.bulk.maxBatchSize"
   val WriteBulkMinTargetBatchSize = "spark.cosmos.write.bulk.minTargetBatchSize"
@@ -242,6 +243,7 @@ private[spark] object CosmosConfigNames {
     ClientTelemetryEnabled,
     ClientTelemetryEndpoint,
     WriteBulkEnabled,
+    WriteBulkEnableTransactions,
     WriteBulkMaxPendingOperations,
     WriteBulkMaxConcurrentPartitions,
     WriteBulkPayloadSizeInBytes,
@@ -1462,6 +1464,7 @@ private case class CosmosPatchConfigs(columnConfigsMap: TrieMap[String, CosmosPa
 private case class CosmosWriteConfig(itemWriteStrategy: ItemWriteStrategy,
                                      maxRetryCount: Int,
                                      bulkEnabled: Boolean,
+                                     bulkEnableTransactions: Boolean = false,
                                      bulkMaxPendingOperations: Option[Int] = None,
                                      pointMaxConcurrency: Option[Int] = None,
                                      maxConcurrentCosmosPartitions: Option[Int] = None,
@@ -1486,6 +1489,12 @@ private object CosmosWriteConfig {
     parseFromStringFunction = bulkEnabledAsString => bulkEnabledAsString.toBoolean,
     helpMessage = "Cosmos DB Item Write bulk enabled")
 
+  private val bulkEnableTransactions = CosmosConfigEntry[Boolean](key = CosmosConfigNames.WriteBulkEnableTransactions,
+    defaultValue = Option.apply(false),
+    mandatory = false,
+    parseFromStringFunction = enableTransactionsAsString => enableTransactionsAsString.toBoolean,
+    helpMessage = "Cosmos DB Item Write enable transactional batch - requires bulk write to be enabled and Spark 3.5+")
+
   private val microBatchPayloadSizeInBytes = CosmosConfigEntry[Int](key = CosmosConfigNames.WriteBulkPayloadSizeInBytes,
     defaultValue = Option.apply(BatchRequestResponseConstants.DEFAULT_MAX_DIRECT_MODE_BATCH_REQUEST_BODY_SIZE_IN_BYTES),
     mandatory = false,
@@ -1758,6 +1767,7 @@ private object CosmosWriteConfig {
     val itemWriteStrategyOpt = CosmosConfigEntry.parse(cfg, itemWriteStrategy)
     val maxRetryCountOpt = CosmosConfigEntry.parse(cfg, maxRetryCount)
     val bulkEnabledOpt = CosmosConfigEntry.parse(cfg, bulkEnabled)
+    val bulkEnableTransactionsOpt = CosmosConfigEntry.parse(cfg, bulkEnableTransactions)
     var patchConfigsOpt = Option.empty[CosmosPatchConfigs]
     val throughputControlConfigOpt = CosmosThroughputControlConfig.parseThroughputControlConfig(cfg)
     val microBatchPayloadSizeInBytesOpt = CosmosConfigEntry.parse(cfg, microBatchPayloadSizeInBytes)
@@ -1788,6 +1798,7 @@ private object CosmosWriteConfig {
       itemWriteStrategyOpt.get,
       maxRetryCountOpt.get,
       bulkEnabled = bulkEnabledOpt.get,
+      bulkEnableTransactions = bulkEnableTransactionsOpt.getOrElse(false),
       bulkMaxPendingOperations = CosmosConfigEntry.parse(cfg, bulkMaxPendingOperations),
       pointMaxConcurrency = CosmosConfigEntry.parse(cfg, pointWriteConcurrency),
       maxConcurrentCosmosPartitions = CosmosConfigEntry.parse(cfg, bulkMaxConcurrentPartitions),
```
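The new entry reuses the connector's existing `CosmosConfigEntry` pattern: an optional key with a string parser and a default that applies when the key is absent, with `getOrElse(false)` as a final guard. A self-contained sketch of that pattern, using simplified stand-ins rather than the connector's actual internal classes:

```scala
// Simplified stand-ins for CosmosConfigEntry / CosmosConfigEntry.parse; the real
// classes also carry a mandatory flag and a help message.
final case class ConfigEntry[T](key: String, defaultValue: Option[T], parseFromString: String => T)

def parseEntry[T](cfg: Map[String, String], entry: ConfigEntry[T]): Option[T] =
  cfg.get(entry.key).map(entry.parseFromString).orElse(entry.defaultValue)

val bulkEnableTransactions = ConfigEntry[Boolean](
  key = "spark.cosmos.write.bulk.enableTransactions",
  defaultValue = Some(false),
  parseFromString = _.toBoolean)

// Mirrors bulkEnableTransactionsOpt.getOrElse(false) in parseWriteConfig above.
assert(!parseEntry(Map.empty, bulkEnableTransactions).getOrElse(false))
assert(parseEntry(Map(bulkEnableTransactions.key -> "true"), bulkEnableTransactions).getOrElse(false))
```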

sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosWriterBase.scala

Lines changed: 19 additions & 8 deletions

```diff
@@ -62,14 +62,25 @@ private abstract class CosmosWriterBase(
 
   private val writer: AtomicReference[AsyncItemWriter] = new AtomicReference(
     if (cosmosWriteConfig.bulkEnabled) {
-      new BulkWriter(
-        container,
-        cosmosTargetContainerConfig,
-        partitionKeyDefinition,
-        cosmosWriteConfig,
-        diagnosticsConfig,
-        getOutputMetricsPublisher(),
-        commitAttempt.getAndIncrement())
+      if (cosmosWriteConfig.bulkEnableTransactions) {
+        new TransactionalBulkWriter(
+          container,
+          cosmosTargetContainerConfig,
+          partitionKeyDefinition,
+          cosmosWriteConfig,
+          diagnosticsConfig,
+          getOutputMetricsPublisher(),
+          commitAttempt.getAndIncrement())
+      } else {
+        new BulkWriter(
+          container,
+          cosmosTargetContainerConfig,
+          partitionKeyDefinition,
+          cosmosWriteConfig,
+          diagnosticsConfig,
+          getOutputMetricsPublisher(),
+          commitAttempt.getAndIncrement())
+      }
     } else {
       new PointWriter(
         container,
```