@@ -2,11 +2,14 @@
 // Licensed under the MIT License.
 package com.azure.cosmos.spark
 
+import com.azure.cosmos.{CosmosAsyncClient, ReadConsistencyStrategy, SparkBridgeInternal}
 import com.azure.cosmos.spark.diagnostics.LoggerHelper
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
+import org.apache.spark.sql.connector.expressions.{Expression, Expressions, NullOrdering, SortDirection, SortOrder}
 import org.apache.spark.sql.connector.metric.CustomMetric
 import org.apache.spark.sql.connector.write.streaming.StreamingWrite
-import org.apache.spark.sql.connector.write.{BatchWrite, Write, WriteBuilder}
+import org.apache.spark.sql.connector.write.{BatchWrite, RequiresDistributionAndOrdering, Write, WriteBuilder}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
@@ -46,7 +49,7 @@ private class ItemsWriterBuilder
       diagnosticsConfig,
       sparkEnvironmentInfo)
 
-  private class CosmosWrite extends Write {
+  private class CosmosWrite extends Write with RequiresDistributionAndOrdering {
 
     private[this] val supportedCosmosMetrics: Array[CustomMetric] = {
       Array(
@@ -56,6 +59,15 @@ private class ItemsWriterBuilder
       )
     }
 
+    private[this] val writeConfig = CosmosWriteConfig.parseWriteConfig(
+      userConfig.asCaseSensitiveMap().asScala.toMap,
+      inputSchema
+    )
+
+    private[this] val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(
+      userConfig.asCaseSensitiveMap().asScala.toMap
+    )
+
     override def toBatch(): BatchWrite =
       new ItemsBatchWriter(
         userConfig.asCaseSensitiveMap().asScala.toMap,
@@ -73,5 +85,84 @@ private class ItemsWriterBuilder
         sparkEnvironmentInfo)
 
     override def supportedCustomMetrics(): Array[CustomMetric] = supportedCosmosMetrics
+
+    override def requiredDistribution(): Distribution = {
+      if (writeConfig.bulkEnabled && writeConfig.bulkEnableTransactions) {
+        // For transactional writes, partition by all partition key columns
+        val partitionKeyPaths = getPartitionKeyColumnNames()
+        if (partitionKeyPaths.nonEmpty) {
+          // Use public Expressions.column() factory - returns NamedReference
+          val clustering = partitionKeyPaths.map(path => Expressions.column(path): Expression).toArray
+          Distributions.clustered(clustering)
+        } else {
+          Distributions.unspecified()
+        }
+      } else {
+        Distributions.unspecified()
+      }
+    }
+
+    override def requiredOrdering(): Array[SortOrder] = {
+      if (writeConfig.bulkEnabled && writeConfig.bulkEnableTransactions) {
+        // For transactional writes, order by all partition key columns (ascending)
+        val partitionKeyPaths = getPartitionKeyColumnNames()
+        if (partitionKeyPaths.nonEmpty) {
+          partitionKeyPaths.map { path =>
+            // Use public Expressions.sort() factory for creating SortOrder
+            Expressions.sort(
+              Expressions.column(path),
+              SortDirection.ASCENDING,
+              NullOrdering.NULLS_FIRST
+            )
+          }.toArray
+        } else {
+          Array.empty[SortOrder]
+        }
+      } else {
+        Array.empty[SortOrder]
+      }
+    }
+
+    private def getPartitionKeyColumnNames(): Seq[String] = {
+      try {
+        // Need to create a temporary container client to get partition key definition
+        val clientCacheItem = CosmosClientCache(
+          CosmosClientConfiguration(
+            userConfig.asCaseSensitiveMap().asScala.toMap,
+            ReadConsistencyStrategy.EVENTUAL,
+            sparkEnvironmentInfo
+          ),
+          Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches),
+          "ItemsWriterBuilder-PKLookup"
+        )
+
+        val container = ThroughputControlHelper.getContainer(
+          userConfig.asCaseSensitiveMap().asScala.toMap,
+          containerConfig,
+          clientCacheItem,
+          None
+        )
+
+        val containerProperties = SparkBridgeInternal.getContainerPropertiesFromCollectionCache(container)
+        val partitionKeyDefinition = containerProperties.getPartitionKeyDefinition
+
+        // Release the client
+        clientCacheItem.close()
+
+        if (partitionKeyDefinition != null && partitionKeyDefinition.getPaths != null) {
+          val paths = partitionKeyDefinition.getPaths.asScala
+          paths.map(path => {
+            // Remove leading '/' from partition key path (e.g., "/pk" -> "pk")
+            if (path.startsWith("/")) path.substring(1) else path
+          }).toSeq
+        } else {
+          Seq.empty[String]
+        }
+      } catch {
+        case ex: Exception =>
+          log.logWarning(s"Failed to get partition key definition for transactional writes: ${ex.getMessage}")
+          Seq.empty[String]
+      }
+    }
   }
 }
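
For context on what implementing RequiresDistributionAndOrdering buys here: Spark's DataSource V2 write planner repartitions and sorts the incoming DataFrame to satisfy whatever distribution and ordering the Write declares, before any writer task runs. A minimal sketch of the rough DataFrame-side equivalent, assuming a hypothetical container whose partition key paths are /tenantId and /userId (names invented for illustration):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[2]").appName("pk-clustering-sketch").getOrCreate()
import spark.implicits._

// Hypothetical input rows, initially spread arbitrarily across Spark partitions.
val df = Seq(
  ("tenant-a", "user-1", 10),
  ("tenant-b", "user-2", 20),
  ("tenant-a", "user-1", 30)
).toDF("tenantId", "userId", "value")

// Rough equivalent of the plan Spark injects for Distributions.clustered(tenantId, userId)
// plus the ascending SortOrder array returned above: rows sharing a partition key land in
// the same task, in contiguous sorted runs, which is what lets transactional bulk batches
// target a single physical partition.
val prepared = df
  .repartition(col("tenantId"), col("userId"))
  .sortWithinPartitions(col("tenantId"), col("userId"))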
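
Similarly, a standalone illustration of the normalization inside getPartitionKeyColumnNames(): Cosmos DB reports partition key paths with a leading slash (e.g. /pk), while Expressions.column() expects bare column names, so the slash is stripped. The paths below are invented; the third is contrived to exercise the fallback branch:

// Sketch of the path-to-column normalization (hypothetical paths).
val paths = Seq("/tenantId", "/userId", "id")
val columnNames = paths.map(p => if (p.startsWith("/")) p.substring(1) else p)
assert(columnNames == Seq("tenantId", "userId", "id"))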