Skip to content

Commit 5750079

Browse files
committed
[SPARK-52066] Support unpivot/melt/transpose in DataFrame
1 parent 7de35f7 commit 5750079

File tree

3 files changed

+190
-0
lines changed

3 files changed

+190
-0
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,6 +1202,106 @@ public actor DataFrame: Sendable {
12021202
return dropDuplicates()
12031203
}
12041204

1205+
/// Transposes this DataFrame, turning rows into columns. The values found in the
/// first column of the input become the column names of the result.
/// - Returns: A transposed ``DataFrame``.
public func transpose() -> DataFrame {
  // Empty index list lets the server default to the first column as the index.
  buildTranspose([])
}
1211+
1212+
/// Unpivots a DataFrame from wide format to long format, keeping the given identifier
/// columns fixed. This reverses `groupBy(...).pivot(...).agg(...)` except for the
/// aggregation step, which cannot be undone. Alias for ``unpivot(_:_:_:_:)``.
/// - Parameters:
///   - ids: ID column names.
///   - values: Value column names to unpivot.
///   - variableColumnName: Name of the variable column.
///   - valueColumnName: Name of the value column.
/// - Returns: A ``DataFrame``.
public func melt(
  _ ids: [String],
  _ values: [String],
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  unpivot(ids, values, variableColumnName, valueColumnName)
}
1229+
1230+
/// Unpivots a DataFrame from wide format to long format, keeping the given identifier
/// columns fixed and unpivoting all remaining columns. This reverses
/// `groupBy(...).pivot(...).agg(...)` except for the aggregation step, which cannot be
/// undone. Alias for ``unpivot(_:_:_:)``.
/// - Parameters:
///   - ids: ID column names.
///   - variableColumnName: Name of the variable column.
///   - valueColumnName: Name of the value column.
/// - Returns: A ``DataFrame``.
public func melt(
  _ ids: [String],
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  unpivot(ids, variableColumnName, valueColumnName)
}
1245+
1246+
/// Unpivots a DataFrame from wide format to long format, keeping the given identifier
/// columns fixed. This reverses `groupBy(...).pivot(...).agg(...)` except for the
/// aggregation step, which cannot be undone.
/// - Parameters:
///   - ids: ID column names.
///   - values: Value column names to unpivot.
///   - variableColumnName: Name of the variable column.
///   - valueColumnName: Name of the value column.
/// - Returns: A ``DataFrame``.
public func unpivot(
  _ ids: [String],
  _ values: [String],
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  buildUnpivot(ids, values, variableColumnName, valueColumnName)
}
1263+
1264+
/// Unpivots a DataFrame from wide format to long format, keeping the given identifier
/// columns fixed and unpivoting all remaining columns. This reverses
/// `groupBy(...).pivot(...).agg(...)` except for the aggregation step, which cannot be
/// undone.
/// - Parameters:
///   - ids: ID column names.
///   - variableColumnName: Name of the variable column.
///   - valueColumnName: Name of the value column.
/// - Returns: A ``DataFrame``.
public func unpivot(
  _ ids: [String],
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  // `nil` values means "unpivot every non-ID column" on the server side.
  buildUnpivot(ids, nil, variableColumnName, valueColumnName)
}
1279+
1280+
/// Builds an unpivoted ``DataFrame`` from the given columns.
/// - Parameters:
///   - ids: ID column names that remain as identifier columns.
///   - values: Value column names to unpivot; `nil` unpivots all non-ID columns.
///   - variableColumnName: Name of the variable column in the result.
///   - valueColumnName: Name of the value column in the result.
/// - Returns: A ``DataFrame`` backed by an `Unpivot` relation plan.
func buildUnpivot(
  _ ids: [String],
  _ values: [String]?,
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  let plan = SparkConnectClient.getUnpivot(self.plan.root, ids, values, variableColumnName, valueColumnName)
  return DataFrame(spark: self.spark, plan: plan)
}
1289+
1290+
/// Transposes a ``DataFrame`` so that the values of the specified index column become
/// the new columns of the result.
/// - Parameter indexColumn: The single column that serves as the index for the
///   transpose operation; its values become the column names of the transposed
///   ``DataFrame``.
/// - Returns: A transposed ``DataFrame``.
public func transpose(_ indexColumn: String) -> DataFrame {
  buildTranspose([indexColumn])
}
1299+
1300+
/// Builds a transposed ``DataFrame`` using the given index columns.
/// - Parameter indexColumns: Zero or one column names used as the transpose index;
///   empty lets the server pick the first column.
/// - Returns: A transposed ``DataFrame``.
func buildTranspose(_ indexColumns: [String]) -> DataFrame {
  DataFrame(spark: self.spark, plan: SparkConnectClient.getTranspose(self.plan.root, indexColumns))
}
1304+
12051305
/// Groups the DataFrame using the specified columns.
12061306
///
12071307
/// This method is used to perform aggregations on groups of data.

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,53 @@ public actor SparkConnectClient {
920920
return plan
921921
}
922922

923+
/// Creates a `Plan` whose root is an `Unpivot` relation over the given child relation.
/// - Parameters:
///   - child: The input relation to unpivot.
///   - ids: ID column names kept as identifier columns.
///   - values: Value column names to unpivot; `nil` omits the optional `values`
///     message so the server unpivots all non-ID columns.
///   - variableColumnName: Name of the variable column in the result.
///   - valueColumnName: Name of the value column in the result.
/// - Returns: A `Plan` rooted at the unpivot relation.
static func getUnpivot(
  _ child: Relation,
  _ ids: [String],
  _ values: [String]?,
  _ variableColumnName: String,
  _ valueColumnName: String
) -> Plan {
  var unpivot = Spark_Connect_Unpivot()
  unpivot.input = child
  unpivot.ids = ids.map {
    var expr = Spark_Connect_Expression()
    expr.expressionString = $0.toExpressionString
    return expr
  }
  // Only set the optional `values` field when the caller supplied explicit columns.
  if let values {
    var unpivotValues = Spark_Connect_Unpivot.Values()
    unpivotValues.values = values.map {
      var expr = Spark_Connect_Expression()
      expr.expressionString = $0.toExpressionString
      return expr
    }
    unpivot.values = unpivotValues
  }
  unpivot.variableColumnName = variableColumnName
  unpivot.valueColumnName = valueColumnName
  var relation = Relation()
  relation.unpivot = unpivot
  var plan = Plan()
  plan.opType = .root(relation)
  return plan
}
954+
955+
/// Creates a `Plan` whose root is a `Transpose` relation over the given child relation.
/// - Parameters:
///   - child: The input relation to transpose.
///   - indexColumn: Zero or one column names used as the transpose index.
/// - Returns: A `Plan` rooted at the transpose relation.
static func getTranspose(_ child: Relation, _ indexColumn: [String]) -> Plan {
  var transposeRelation = Spark_Connect_Transpose()
  transposeRelation.input = child
  transposeRelation.indexColumns = indexColumn.map { name in
    var expression = Spark_Connect_Expression()
    expression.expressionString = name.toExpressionString
    return expression
  }
  var root = Relation()
  root.transpose = transposeRelation
  var plan = Plan()
  plan.opType = .root(root)
  return plan
}
969+
923970
func createTempView(
924971
_ child: Relation, _ viewName: String, replace: Bool, isGlobal: Bool
925972
) async throws {

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,49 @@ struct DataFrameTests {
788788
#expect(try await spark.sql("SELECT 1 a, 2 b, 3 c").toJSON().collect() == expected)
789789
await spark.stop()
790790
}
791+
792+
/// Verifies that `unpivot` and its alias `melt` convert wide-format data to long format.
@Test
func unpivot() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let wide = try await spark.sql(
    """
    SELECT * FROM
    VALUES (1, 11, 12L),
           (2, 21, 22L)
    T(id, int, long)
    """)
  let longRows = [
    Row(1, "int", 11),
    Row(1, "long", 12),
    Row(2, "int", 21),
    Row(2, "long", 22),
  ]
  #expect(try await wide.unpivot(["id"], ["int", "long"], "variable", "value").collect() == longRows)
  #expect(try await wide.melt(["id"], ["int", "long"], "variable", "value").collect() == longRows)
  await spark.stop()
}
812+
813+
/// Verifies `transpose()` and `transpose(_:)`, including the degenerate single-column case.
@Test
func transpose() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  // A single-column input transposes to a header-only result.
  #expect(try await spark.range(1).transpose().columns == ["key", "0"])
  #expect(try await spark.range(1).transpose().count() == 0)

  let source = try await spark.sql(
    """
    SELECT * FROM
    VALUES ('A', 1, 2),
           ('B', 3, 4)
    T(id, val1, val2)
    """)
  let transposedRows = [
    Row("val1", 1, 3),
    Row("val2", 2, 4),
  ]
  // Default index (first column) and an explicit index column must agree here.
  #expect(try await source.transpose().collect() == transposedRows)
  #expect(try await source.transpose("id").collect() == transposedRows)
  await spark.stop()
}
791834
#endif
792835

793836
@Test

0 commit comments

Comments
 (0)