Skip to content

Commit 03c2f45

Browse files
committed
[SPARK-52167] Support hint for DataFrame
### What changes were proposed in this pull request?

This PR aims to support the `hint` API for `DataFrame`.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #155 from dongjoon-hyun/SPARK-52167.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 2c2a5f1 commit 03c2f45

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ import Synchronization
115115
/// - ``melt(_:_:_:_:)``
116116
/// - ``transpose()``
117117
/// - ``transpose(_:)``
118+
/// - ``hint(_:_:)``
118119
///
119120
/// ### Join Operations
120121
/// - ``join(_:)``
@@ -1349,6 +1350,17 @@ public actor DataFrame: Sendable {
13491350
return GroupedData(self, GroupType.cube, cols)
13501351
}
13511352

1353+
/// Specifies some hint on the current Dataset.
1354+
/// - Parameters:
1355+
/// - name: The hint name.
1356+
/// - parameters: The parameters of the hint
1357+
/// - Returns: A ``DataFrame``.
1358+
@discardableResult
1359+
public func hint(_ name: String, _ parameters: Sendable...) -> DataFrame {
1360+
let plan = SparkConnectClient.getHint(self.plan.root, name, parameters)
1361+
return DataFrame(spark: self.spark, plan: plan)
1362+
}
1363+
13521364
/// Creates a local temporary view using the given name. The lifetime of this temporary view is
13531365
/// tied to the `SparkSession` that was used to create this ``DataFrame``.
13541366
/// - Parameter viewName: A view name.

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -950,6 +950,41 @@ public actor SparkConnectClient {
950950
return plan
951951
}
952952

953+
static func getHint(_ child: Relation, _ name: String, _ parameters: [Sendable]) -> Plan {
954+
var hint = Spark_Connect_Hint()
955+
hint.input = child
956+
hint.name = name
957+
hint.parameters = parameters.map {
958+
var literal = ExpressionLiteral()
959+
switch $0 {
960+
case let value as Bool:
961+
literal.boolean = value
962+
case let value as Int8:
963+
literal.byte = Int32(value)
964+
case let value as Int16:
965+
literal.short = Int32(value)
966+
case let value as Int32:
967+
literal.integer = value
968+
case let value as Int64: // Hint parameter raises exceptions for Int64
969+
literal.integer = Int32(value)
970+
case let value as Int:
971+
literal.integer = Int32(value)
972+
case let value as String:
973+
literal.string = value
974+
default:
975+
literal.string = $0 as! String
976+
}
977+
var expr = Spark_Connect_Expression()
978+
expr.literal = literal
979+
return expr
980+
}
981+
var relation = Relation()
982+
relation.hint = hint
983+
var plan = Plan()
984+
plan.opType = .root(relation)
985+
return plan
986+
}
987+
953988
func createTempView(
954989
_ child: Relation, _ viewName: String, replace: Bool, isGlobal: Bool
955990
) async throws {

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,4 +852,26 @@ struct DataFrameTests {
852852

853853
await spark.stop()
854854
}
855+
856+
@Test
857+
func hint() async throws {
858+
let spark = try await SparkSession.builder.getOrCreate()
859+
let df1 = try await spark.range(1)
860+
let df2 = try await spark.range(1)
861+
862+
try await df1.join(df2.hint("broadcast")).count()
863+
try await df1.join(df2.hint("coalesce", 10)).count()
864+
try await df1.join(df2.hint("rebalance", 10)).count()
865+
try await df1.join(df2.hint("rebalance", 10, "id")).count()
866+
try await df1.join(df2.hint("repartition", 10)).count()
867+
try await df1.join(df2.hint("repartition", 10, "id")).count()
868+
try await df1.join(df2.hint("repartition", "id")).count()
869+
try await df1.join(df2.hint("repartition_by_range")).count()
870+
try await df1.join(df2.hint("merge")).count()
871+
try await df1.join(df2.hint("shuffle_hash")).count()
872+
try await df1.join(df2.hint("shuffle_replicate_nl")).count()
873+
try await df1.join(df2.hint("shuffle_merge")).count()
874+
875+
await spark.stop()
876+
}
855877
}

0 commit comments

Comments (0)