Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions Sources/SparkConnect/DataFrame.swift
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ import Synchronization
/// - ``dropDuplicatesWithinWatermark(_:)``
/// - ``distinct()``
/// - ``withColumnRenamed(_:_:)``
/// - ``unpivot(_:_:_:)``
/// - ``unpivot(_:_:_:_:)``
/// - ``melt(_:_:_:)``
/// - ``melt(_:_:_:_:)``
///
/// ### Join Operations
/// - ``join(_:)``
Expand Down Expand Up @@ -1202,6 +1206,106 @@ public actor DataFrame: Sendable {
return dropDuplicates()
}

/// Returns a transposed copy of this ``DataFrame`` in which rows and columns are switched.
///
/// No index column is passed, so the values in the first column become the new columns of
/// the result.
/// - Returns: A transposed ``DataFrame``.
public func transpose() -> DataFrame {
  let noIndexColumns: [String] = []
  return buildTranspose(noIndexColumns)
}

/// Alias for `unpivot` with explicit value columns.
///
/// Converts this DataFrame from wide format to long format, optionally leaving identifier
/// columns set. It is the reverse of `groupBy(...).pivot(...).agg(...)`, except for the
/// aggregation step, which cannot be reversed.
/// - Parameters:
///   - ids: ID column names
///   - values: Value column names to unpivot
///   - variableColumnName: Name of the variable column
///   - valueColumnName: Name of the value column
/// - Returns: A ``DataFrame``.
public func melt(
  _ ids: [String],
  _ values: [String],
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  // Pure forwarding: `melt` exists only as another name for `unpivot`.
  unpivot(ids, values, variableColumnName, valueColumnName)
}

/// Alias for `unpivot` without an explicit list of value columns.
///
/// Converts this DataFrame from wide format to long format, optionally leaving identifier
/// columns set. It is the reverse of `groupBy(...).pivot(...).agg(...)`, except for the
/// aggregation step, which cannot be reversed.
/// - Parameters:
///   - ids: ID column names
///   - variableColumnName: Name of the variable column
///   - valueColumnName: Name of the value column
/// - Returns: A ``DataFrame``.
public func melt(
  _ ids: [String],
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  // Pure forwarding to the three-argument `unpivot` overload.
  unpivot(ids, variableColumnName, valueColumnName)
}

/// Converts this DataFrame from wide format to long format using the given value columns,
/// optionally leaving identifier columns set.
///
/// This is the reverse of `groupBy(...).pivot(...).agg(...)`, except for the aggregation
/// step, which cannot be reversed.
/// - Parameters:
///   - ids: ID column names
///   - values: Value column names to unpivot
///   - variableColumnName: Name of the variable column
///   - valueColumnName: Name of the value column
/// - Returns: A ``DataFrame``.
public func unpivot(
  _ ids: [String],
  _ values: [String],
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  // The non-optional `values` array is promoted to the optional the shared builder expects.
  let explicitValues: [String]? = values
  return buildUnpivot(ids, explicitValues, variableColumnName, valueColumnName)
}

/// Converts this DataFrame from wide format to long format without naming the value columns,
/// optionally leaving identifier columns set.
///
/// This is the reverse of `groupBy(...).pivot(...).agg(...)`, except for the aggregation
/// step, which cannot be reversed.
/// - Parameters:
///   - ids: ID column names
///   - variableColumnName: Name of the variable column
///   - valueColumnName: Name of the value column
/// - Returns: A ``DataFrame``.
public func unpivot(
  _ ids: [String],
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  // `nil` tells the shared builder that the caller supplied no value-column list.
  let unspecifiedValues: [String]? = nil
  return buildUnpivot(ids, unspecifiedValues, variableColumnName, valueColumnName)
}

/// Builds the ``DataFrame`` returned by the `unpivot` overloads (and, through them, `melt`).
///
/// Every argument is forwarded unchanged to `SparkConnectClient.getUnpivot` together with the
/// root relation of this DataFrame's plan. The plan that comes back is wrapped in a new
/// ``DataFrame`` that shares this DataFrame's `spark`.
/// - Parameters:
///   - ids: ID column names
///   - values: Value column names to unpivot, or `nil` when the caller supplied none
///   - variableColumnName: Name of the variable column
///   - valueColumnName: Name of the value column
/// - Returns: A ``DataFrame`` backed by the unpivot plan.
func buildUnpivot(
  _ ids: [String],
  _ values: [String]?,
  _ variableColumnName: String,
  _ valueColumnName: String
) -> DataFrame {
  // `self.` is required here: the local `plan` shadows the stored property of the same name.
  let plan = SparkConnectClient.getUnpivot(
    self.plan.root, ids, values, variableColumnName, valueColumnName)
  return DataFrame(spark: self.spark, plan: plan)
}

/// Returns a transposed copy of this ``DataFrame``, using the values of the given index column
/// as the new columns of the result.
/// - Parameter indexColumn: The single column treated as the index for the transpose
///   operation. The data is pivoted on it, so that the values of `indexColumn` become the
///   new columns in the transposed DataFrame.
/// - Returns: A transposed ``DataFrame``.
public func transpose(_ indexColumn: String) -> DataFrame {
  // The shared builder takes a list, so the single name is wrapped in a one-element array.
  let indexColumns = [indexColumn]
  return buildTranspose(indexColumns)
}

/// Builds the ``DataFrame`` returned by both `transpose` overloads.
///
/// The root relation of this DataFrame's plan and the index column list are handed to
/// `SparkConnectClient.getTranspose`; the resulting plan is wrapped in a new ``DataFrame``
/// that shares this DataFrame's `spark`.
/// - Parameter indexColumn: Index column names; empty when the caller specified none.
/// - Returns: A ``DataFrame`` backed by the transpose plan.
func buildTranspose(_ indexColumn: [String]) -> DataFrame {
  let transposePlan = SparkConnectClient.getTranspose(plan.root, indexColumn)
  return DataFrame(spark: spark, plan: transposePlan)
}

/// Groups the DataFrame using the specified columns.
///
/// This method is used to perform aggregations on groups of data.
Expand Down
47 changes: 47 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,53 @@ public actor SparkConnectClient {
return plan
}

/// Creates a plan whose root relation is an unpivot of `child`.
///
/// Each name in `ids` (and in `values`, when present) is converted with `toExpressionString`
/// and wrapped in a `Spark_Connect_Expression`. The `values` field of the unpivot message is
/// only assigned when `values` is non-`nil`; otherwise it is left unset.
/// - Parameters:
///   - child: The input relation to unpivot.
///   - ids: ID column names
///   - values: Value column names to unpivot, or `nil` to leave the field unset
///   - variableColumnName: Name of the variable column
///   - valueColumnName: Name of the value column
/// - Returns: A `Plan` rooted at the unpivot relation.
static func getUnpivot(
  _ child: Relation,
  _ ids: [String],
  _ values: [String]?,
  _ variableColumnName: String,
  _ valueColumnName: String
) -> Plan {
  // Single place that turns a column name into an expression-string expression, shared by
  // the id and value columns.
  func makeExpression(_ name: String) -> Spark_Connect_Expression {
    var expr = Spark_Connect_Expression()
    expr.expressionString = name.toExpressionString
    return expr
  }

  var unpivot = Spark_Connect_Unpivot()
  unpivot.input = child
  unpivot.ids = ids.map(makeExpression)
  if let values {
    var unpivotValues = Spark_Connect_Unpivot.Values()
    unpivotValues.values = values.map(makeExpression)
    unpivot.values = unpivotValues
  }
  unpivot.variableColumnName = variableColumnName
  unpivot.valueColumnName = valueColumnName

  var relation = Relation()
  relation.unpivot = unpivot
  var plan = Plan()
  plan.opType = .root(relation)
  return plan
}

/// Creates a plan whose root relation is a transpose of `child`.
///
/// Each name in `indexColumn` is converted with `toExpressionString` and wrapped in a
/// `Spark_Connect_Expression` before being stored as the message's index columns.
/// - Parameters:
///   - child: The input relation to transpose.
///   - indexColumn: Index column names; may be empty.
/// - Returns: A `Plan` rooted at the transpose relation.
static func getTranspose(_ child: Relation, _ indexColumn: [String]) -> Plan {
  var indexExpressions: [Spark_Connect_Expression] = []
  indexExpressions.reserveCapacity(indexColumn.count)
  for name in indexColumn {
    var expression = Spark_Connect_Expression()
    expression.expressionString = name.toExpressionString
    indexExpressions.append(expression)
  }

  var transposeMessage = Spark_Connect_Transpose()
  transposeMessage.input = child
  transposeMessage.indexColumns = indexExpressions

  var rootRelation = Relation()
  rootRelation.transpose = transposeMessage

  var result = Plan()
  result.opType = .root(rootRelation)
  return result
}

func createTempView(
_ child: Relation, _ viewName: String, replace: Bool, isGlobal: Bool
) async throws {
Expand Down
45 changes: 45 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,51 @@ struct DataFrameTests {
#expect(try await spark.sql("SELECT 1 a, 2 b, 3 c").toJSON().collect() == expected)
await spark.stop()
}

@Test
func unpivot() async throws {
  let spark = try await SparkSession.builder.getOrCreate()

  // Wide input: one id column plus two value columns (`int` and `long`).
  let wideQuery = """
    SELECT * FROM
    VALUES (1, 11, 12L),
    (2, 21, 22L)
    T(id, int, long)
    """
  let df = try await spark.sql(wideQuery)

  // Long output: one row per (id, value column) pair.
  let expectedRows = [
    Row(1, "int", 11),
    Row(1, "long", 12),
    Row(2, "int", 21),
    Row(2, "long", 22),
  ]

  // `melt` is an alias of `unpivot`, so both must produce the same rows.
  #expect(
    try await df.unpivot(["id"], ["int", "long"], "variable", "value").collect() == expectedRows)
  #expect(
    try await df.melt(["id"], ["int", "long"], "variable", "value").collect() == expectedRows)

  await spark.stop()
}

@Test
func transpose() async throws {
  let spark = try await SparkSession.builder.getOrCreate()

  // The assertions below are only exercised against servers reporting a 4.x version.
  let isSpark4 = await spark.version.starts(with: "4.")
  if isSpark4 {
    // Transposing a single-column, single-row range yields only headers and no rows.
    #expect(try await spark.range(1).transpose().columns == ["key", "0"])
    #expect(try await spark.range(1).transpose().count() == 0)

    let sourceQuery = """
      SELECT * FROM
      VALUES ('A', 1, 2),
      ('B', 3, 4)
      T(id, val1, val2)
      """
    let df = try await spark.sql(sourceQuery)

    let expectedRows = [
      Row("val1", 1, 3),
      Row("val2", 2, 4),
    ]

    // `id` is the first column, so the implicit and explicit index must agree.
    #expect(try await df.transpose().collect() == expectedRows)
    #expect(try await df.transpose("id").collect() == expectedRows)
  }

  await spark.stop()
}
#endif

@Test
Expand Down
Loading