Skip to content

Commit 5ded181

Browse files
committed
[SPARK-52066] Support unpivot/melt/transpose in DataFrame
### What changes were proposed in this pull request? This PR aims to add `unpivot`, `melt`, `transpose` API of `DataFrame`. ### Why are the changes needed? For feature parity. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#123 from dongjoon-hyun/SPARK-52066. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 7de35f7 commit 5ded181

File tree

3 files changed

+196
-0
lines changed

3 files changed

+196
-0
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ import Synchronization
109109
/// - ``dropDuplicatesWithinWatermark(_:)``
110110
/// - ``distinct()``
111111
/// - ``withColumnRenamed(_:_:)``
112+
/// - ``unpivot(_:_:_:)``
/// - ``unpivot(_:_:_:_:)``
/// - ``melt(_:_:_:)``
/// - ``melt(_:_:_:_:)``
/// - ``transpose()``
/// - ``transpose(_:)``
112116
///
113117
/// ### Join Operations
114118
/// - ``join(_:)``
@@ -1202,6 +1206,106 @@ public actor DataFrame: Sendable {
12021206
return dropDuplicates()
12031207
}
12041208

1209+
/// Transposes a DataFrame, switching rows to columns. This function transforms the DataFrame
/// such that the values in the first column become the new columns of the DataFrame.
/// - Returns: A transposed ``DataFrame``.
public func transpose() -> DataFrame {
  // No index column: delegate to the helper with an empty list.
  buildTranspose([])
}
1215+
1216+
/// Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
/// set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
/// which cannot be reversed. This is an alias for `unpivot`.
/// - Parameters:
///   - ids: ID column names
///   - values: Value column names to unpivot
///   - variableColumnName: Name of the variable column
///   - valueColumnName: Name of the value column
/// - Returns: A ``DataFrame``.
public func melt(
  _ ids: [String],
  _ values: [String],
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  // Pure alias: forward everything to `unpivot`.
  unpivot(ids, values, variableColumnName, valueColumnName)
}
1233+
1234+
/// Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
/// set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
/// which cannot be reversed. This is an alias for `unpivot`.
/// - Parameters:
///   - ids: ID column names
///   - variableColumnName: Name of the variable column
///   - valueColumnName: Name of the value column
/// - Returns: A ``DataFrame``.
public func melt(
  _ ids: [String],
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  // Pure alias: forward to the `unpivot` overload without an explicit value list.
  unpivot(ids, variableColumnName, valueColumnName)
}
1249+
1250+
/// Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
/// set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
/// which cannot be reversed.
/// - Parameters:
///   - ids: ID column names
///   - values: Value column names to unpivot
///   - variableColumnName: Name of the variable column
///   - valueColumnName: Name of the value column
/// - Returns: A ``DataFrame``.
public func unpivot(
  _ ids: [String],
  _ values: [String],
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  // Explicit value columns are passed straight through to the plan builder.
  buildUnpivot(ids, values, variableColumnName, valueColumnName)
}
1267+
1268+
/// Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
/// set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
/// which cannot be reversed.
/// - Parameters:
///   - ids: ID column names
///   - variableColumnName: Name of the variable column
///   - valueColumnName: Name of the value column
/// - Returns: A ``DataFrame``.
public func unpivot(
  _ ids: [String],
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  // `nil` values: no explicit value-column list is attached to the plan.
  buildUnpivot(ids, nil, variableColumnName, valueColumnName)
}
1283+
1284+
/// Builds a ``DataFrame`` backed by an `Unpivot` plan node over the current plan.
/// - Parameters:
///   - ids: ID column names kept as identifiers.
///   - values: Value column names to unpivot, or `nil` to omit the explicit value list
///     (the server-side default then applies — presumably all non-id columns; confirm
///     against the Spark Connect `Unpivot` relation spec).
///   - variableColumnName: Name of the variable column in the result.
///   - valueColumnName: Name of the value column in the result.
/// - Returns: A new ``DataFrame`` for the unpivoted relation.
func buildUnpivot(
  _ ids: [String],
  _ values: [String]?,
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  // Note: the trailing comma after the last parameter was removed; it only
  // compiles on Swift 6.1+ (SE-0439) and breaks older toolchains.
  let plan = SparkConnectClient.getUnpivot(
    self.plan.root, ids, values, variableColumnName, valueColumnName)
  return DataFrame(spark: self.spark, plan: plan)
}
1293+
1294+
/// Transposes a ``DataFrame`` such that the values in the specified index column become the new
/// columns of the ``DataFrame``.
/// - Parameter indexColumn: The single column that will be treated as the index for the transpose operation.
/// This column will be used to pivot the data, transforming the DataFrame such that the values of
/// the indexColumn become the new columns in the transposed DataFrame.
/// - Returns: A transposed ``DataFrame``.
public func transpose(_ indexColumn: String) -> DataFrame {
  // Wrap the single index column in a list for the shared helper.
  buildTranspose([indexColumn])
}
1303+
1304+
/// Builds a ``DataFrame`` backed by a `Transpose` plan node over the current plan.
/// - Parameter indexColumns: Index column names to transpose around; an empty list
///   means no explicit index column is sent (the public `transpose()` overload).
/// - Returns: A new ``DataFrame`` for the transposed relation.
func buildTranspose(_ indexColumns: [String]) -> DataFrame {
  DataFrame(
    spark: self.spark,
    plan: SparkConnectClient.getTranspose(self.plan.root, indexColumns))
}
1308+
12051309
/// Groups the DataFrame using the specified columns.
12061310
///
12071311
/// This method is used to perform aggregations on groups of data.

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,53 @@ public actor SparkConnectClient {
920920
return plan
921921
}
922922

923+
/// Creates a `Plan` whose root is an `Unpivot` relation over `child`.
/// - Parameters:
///   - child: The input relation to unpivot.
///   - ids: ID column names, converted to expression strings.
///   - values: Optional value column names; when `nil`, the `values` field of the
///     protobuf message is left unset.
///   - variableColumnName: Name of the variable column in the result.
///   - valueColumnName: Name of the value column in the result.
/// - Returns: A `Plan` ready to be sent to the Spark Connect server.
static func getUnpivot(
  _ child: Relation,
  _ ids: [String],
  _ values: [String]?,
  _ variableColumnName: String,
  _ valueColumnName: String
) -> Plan {
  // Note: the trailing comma after the last parameter was removed; it only
  // compiles on Swift 6.1+ (SE-0439) and breaks older toolchains.
  var unpivot = Spark_Connect_Unpivot()
  unpivot.input = child
  unpivot.ids = ids.map {
    var expr = Spark_Connect_Expression()
    expr.expressionString = $0.toExpressionString
    return expr
  }
  if let values {
    var unpivotValues = Spark_Connect_Unpivot.Values()
    unpivotValues.values = values.map {
      var expr = Spark_Connect_Expression()
      expr.expressionString = $0.toExpressionString
      return expr
    }
    unpivot.values = unpivotValues
  }
  unpivot.variableColumnName = variableColumnName
  unpivot.valueColumnName = valueColumnName
  var relation = Relation()
  relation.unpivot = unpivot
  var plan = Plan()
  plan.opType = .root(relation)
  return plan
}
954+
955+
/// Creates a `Plan` whose root is a `Transpose` relation over `child`.
/// - Parameters:
///   - child: The input relation to transpose.
///   - indexColumns: Index column names, converted to expression strings
///     (may be empty, in which case no index columns are attached).
/// - Returns: A `Plan` ready to be sent to the Spark Connect server.
static func getTranspose(_ child: Relation, _ indexColumns: [String]) -> Plan {
  var transpose = Spark_Connect_Transpose()
  transpose.input = child
  transpose.indexColumns = indexColumns.map { name in
    var expression = Spark_Connect_Expression()
    expression.expressionString = name.toExpressionString
    return expression
  }
  var relation = Relation()
  relation.transpose = transpose
  var plan = Plan()
  plan.opType = .root(relation)
  return plan
}
969+
923970
func createTempView(
924971
_ child: Relation, _ viewName: String, replace: Bool, isGlobal: Bool
925972
) async throws {

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,51 @@ struct DataFrameTests {
788788
#expect(try await spark.sql("SELECT 1 a, 2 b, 3 c").toJSON().collect() == expected)
789789
await spark.stop()
790790
}
791+
792+
@Test
793+
func unpivot() async throws {
794+
let spark = try await SparkSession.builder.getOrCreate()
795+
let df = try await spark.sql(
796+
"""
797+
SELECT * FROM
798+
VALUES (1, 11, 12L),
799+
(2, 21, 22L)
800+
T(id, int, long)
801+
""")
802+
let expected = [
803+
Row(1, "int", 11),
804+
Row(1, "long", 12),
805+
Row(2, "int", 21),
806+
Row(2, "long", 22),
807+
]
808+
#expect(try await df.unpivot(["id"], ["int", "long"], "variable", "value").collect() == expected)
809+
#expect(try await df.melt(["id"], ["int", "long"], "variable", "value").collect() == expected)
810+
await spark.stop()
811+
}
812+
813+
@Test
func transpose() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  // `transpose` is exercised only against Spark 4.x servers.
  guard await spark.version.starts(with: "4.") else {
    await spark.stop()
    return
  }
  #expect(try await spark.range(1).transpose().columns == ["key", "0"])
  #expect(try await spark.range(1).transpose().count() == 0)

  let df = try await spark.sql(
    """
    SELECT * FROM
    VALUES ('A', 1, 2),
    ('B', 3, 4)
    T(id, val1, val2)
    """)
  // Values of the index column become the new column headers.
  let expected = [Row("val1", 1, 3), Row("val2", 2, 4)]
  // Default transpose (first column as index) and explicit index column agree.
  #expect(try await df.transpose().collect() == expected)
  #expect(try await df.transpose("id").collect() == expected)
  await spark.stop()
}
791836
#endif
792837

793838
@Test

0 commit comments

Comments
 (0)