Skip to content

Commit 623c2ec

Browse files
MaxGekk authored and HyukjinKwon committed
[SPARK-25048][SQL] Pivoting by multiple columns in Scala/Java
## What changes were proposed in this pull request? In the PR, I propose to extend implementation of existing method: ``` def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset ``` to support values of the struct type. This allows pivoting by multiple columns combined by `struct`: ``` trainingSales .groupBy($"sales.year") .pivot( pivotColumn = struct(lower($"sales.course"), $"training"), values = Seq( struct(lit("dotnet"), lit("Experts")), struct(lit("java"), lit("Dummies"))) ).agg(sum($"sales.earnings")) ``` ## How was this patch tested? Added a test for values specified via `struct` in Java and Scala. Closes apache#22316 from MaxGekk/pivoting-by-multiple-columns2. Lead-authored-by: Maxim Gekk <[email protected]> Co-authored-by: Maxim Gekk <[email protected]> Signed-off-by: hyukjinkwon <[email protected]>
1 parent dcb9a97 commit 623c2ec

File tree

3 files changed

+54
-2
lines changed

3 files changed

+54
-2
lines changed

sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,15 @@ class RelationalGroupedDataset protected[sql](
330330
* df.groupBy("year").pivot("course").sum("earnings")
331331
* }}}
332332
*
333+
   * From Spark 2.4.0, values can be literal columns, for instance, struct. For pivoting by
334+
* multiple columns, use the `struct` function to combine the columns and values:
335+
*
336+
* {{{
337+
* df.groupBy("year")
338+
* .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts"))))
339+
* .agg(sum($"earnings"))
340+
* }}}
341+
*
333342
* @param pivotColumn Name of the column to pivot.
334343
* @param values List of values that will be translated to columns in the output DataFrame.
335344
* @since 1.6.0
@@ -413,10 +422,14 @@ class RelationalGroupedDataset protected[sql](
413422
def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
414423
groupType match {
415424
case RelationalGroupedDataset.GroupByType =>
425+
val valueExprs = values.map(_ match {
426+
case c: Column => c.expr
427+
case v => Literal.apply(v)
428+
})
416429
new RelationalGroupedDataset(
417430
df,
418431
groupingExprs,
419-
RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply)))
432+
RelationalGroupedDataset.PivotType(pivotColumn.expr, valueExprs))
420433
case _: RelationalGroupedDataset.PivotType =>
421434
throw new UnsupportedOperationException("repeated pivots are not supported")
422435
case _ =>
@@ -561,5 +574,5 @@ private[sql] object RelationalGroupedDataset {
561574
/**
562575
* To indicate it's the PIVOT
563576
*/
564-
private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
577+
private[sql] case class PivotType(pivotCol: Expression, values: Seq[Expression]) extends GroupType
565578
}

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,22 @@ public void pivot() {
317317
Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
318318
}
319319

320+
@Test
321+
public void pivotColumnValues() {
322+
Dataset<Row> df = spark.table("courseSales");
323+
List<Row> actual = df.groupBy("year")
324+
.pivot(col("course"), Arrays.asList(lit("dotNET"), lit("Java")))
325+
.agg(sum("earnings")).orderBy("year").collectAsList();
326+
327+
Assert.assertEquals(2012, actual.get(0).getInt(0));
328+
Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01);
329+
Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01);
330+
331+
Assert.assertEquals(2013, actual.get(1).getInt(0));
332+
Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01);
333+
Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
334+
}
335+
320336
private String getResource(String resource) {
321337
try {
322338
// The following "getResource" has different behaviors in SBT and Maven.

sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,4 +308,27 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
308308

309309
assert(exception.getMessage.contains("aggregate functions are not allowed"))
310310
}
311+
312+
test("pivoting column list with values") {
313+
val expected = Row(2012, 10000.0, null) :: Row(2013, 48000.0, 30000.0) :: Nil
314+
val df = trainingSales
315+
.groupBy($"sales.year")
316+
.pivot(struct(lower($"sales.course"), $"training"), Seq(
317+
struct(lit("dotnet"), lit("Experts")),
318+
struct(lit("java"), lit("Dummies")))
319+
).agg(sum($"sales.earnings"))
320+
321+
checkAnswer(df, expected)
322+
}
323+
324+
test("pivoting column list") {
325+
val exception = intercept[RuntimeException] {
326+
trainingSales
327+
.groupBy($"sales.year")
328+
.pivot(struct(lower($"sales.course"), $"training"))
329+
.agg(sum($"sales.earnings"))
330+
.collect()
331+
}
332+
assert(exception.getMessage.contains("Unsupported literal type"))
333+
}
311334
}

0 commit comments

Comments (0)