
Commit 6f18ac9

viirya authored and dongjoon-hyun committed
[SPARK-27241][SQL] Support map_keys and map_values in SelectedField
## What changes were proposed in this pull request?

`SelectedField` does not currently support map_keys and map_values. When a map key or value is a complex struct, we should be able to prune unnecessary fields from the keys/values. This PR adds map_keys and map_values support to `SelectedField`.

## How was this patch tested?

Added tests.

Closes apache#24179 from viirya/SPARK-27241.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 01e6305 commit 6f18ac9
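
For a concrete sense of the query shape this targets, here is a minimal sketch: selecting a single nested field out of a map column's struct values, which is the access pattern `SelectedField` can now describe for the nested-field pruning machinery. The dataset path and column names below are invented for illustration and are not taken from the patch.

```scala
// Illustrative sketch only: assumes a Parquet dataset at /tmp/events with a column
// `props: map<string, struct<id: int, payload: string>>`. None of these names come from the PR.
import org.apache.spark.sql.SparkSession

object MapValuesPruningSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("map_values pruning sketch").getOrCreate()

    val events = spark.read.parquet("/tmp/events")

    // Only `id` is needed from the struct stored in the map values. With map_values
    // understood by SelectedField, the pruning rules can describe this access as
    // needing just props.value.id rather than the whole value struct.
    events.selectExpr("map_values(props).id AS ids").show()

    spark.stop()
  }
}
```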

File tree

2 files changed: +121 additions, -4 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala

Lines changed: 22 additions & 0 deletions
@@ -97,6 +97,28 @@ object SelectedField {
         val MapType(keyType, _, valueContainsNull) = child.dataType
         val opt = dataTypeOpt.map(dt => MapType(keyType, dt, valueContainsNull))
         selectField(child, opt)
+      case MapValues(child) =>
+        val MapType(keyType, _, valueContainsNull) = child.dataType
+        // MapValues does not select a field from a struct (i.e. prune the struct) so it can't be
+        // the top-level extractor. However it can be part of an extractor chain.
+        val opt = dataTypeOpt.map {
+          case ArrayType(dataType, _) => MapType(keyType, dataType, valueContainsNull)
+          case x =>
+            // This should not happen.
+            throw new AnalysisException(s"DataType '$x' is not supported by MapValues.")
+        }
+        selectField(child, opt)
+      case MapKeys(child) =>
+        val MapType(_, valueType, valueContainsNull) = child.dataType
+        // MapKeys does not select a field from a struct (i.e. prune the struct) so it can't be
+        // the top-level extractor. However it can be part of an extractor chain.
+        val opt = dataTypeOpt.map {
+          case ArrayType(dataType, _) => MapType(dataType, valueType, valueContainsNull)
+          case x =>
+            // This should not happen.
+            throw new AnalysisException(s"DataType '$x' is not supported by MapKeys.")
+        }
+        selectField(child, opt)
       case GetArrayItem(child, _) =>
         // GetArrayItem does not select a field from a struct (i.e. prune the struct) so it can't be
         // the top-level extractor. However it can be part of an extractor chain.

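As the new comments note, MapValues and MapKeys never act as the top-level extractor on their own; they translate a pruned ArrayType handed down from the extractor above them back into a MapType for the map-typed child below. The following standalone sketch hand-builds roughly the expression tree the analyzer would produce for `map_values(m).field1` and runs it through the extractor; it is a reader's aid under those assumptions, not code from the patch.

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// Rough sketch of the extractor chain for `map_values(m).field1`, where
// m: map<string, struct<field1: int, field2: string>>. Hand-built for illustration.
object SelectedFieldSketch {
  def main(args: Array[String]): Unit = {
    val valueStruct = StructType(
      StructField("field1", IntegerType) ::
      StructField("field2", StringType) :: Nil)
    val m = AttributeReference("m", MapType(StringType, valueStruct))()

    // The analyzer resolves `.field1` over an array of structs to GetArrayStructFields,
    // which here wraps MapValues(m), whose result type is array<struct<...>>.
    // Arguments: child, field, ordinal = 0, numFields = 2, containsNull = true.
    val expr = GetArrayStructFields(MapValues(m), valueStruct("field1"), 0, 2, true)

    // With this patch the chain resolves to the map column pruned down to just
    // value.field1; previously MapValues was not recognized, so no field was extracted.
    val pruned: Option[StructField] = SelectedField.unapply(expr)
    println(pruned)
    // Expected (roughly): Some(StructField(m, MapType(StringType,
    //   StructType(StructField(field1,IntegerType,true)), true), true))
  }
}
```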
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala

Lines changed: 99 additions & 4 deletions
@@ -17,16 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.scalatest.BeforeAndAfterAll
 import org.scalatest.exceptions.TestFailedException
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.AnalysisTest
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types._
 
-class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll {
+class SelectedFieldSuite extends AnalysisTest {
   private val ignoredField = StructField("col1", StringType, nullable = false)
 
   // The test schema as a tree string, i.e. `schema.treeString`
@@ -317,6 +316,18 @@ class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll {
         StructField("subfield1", IntegerType) :: Nil)) :: Nil), valueContainsNull = false)))
   }
 
+  testSelect(arrayOfStruct, "map_values(col5[0]).field1.subfield1 as foo") {
+    StructField("col5", ArrayType(MapType(StringType, StructType(
+      StructField("field1", StructType(
+        StructField("subfield1", IntegerType) :: Nil)) :: Nil), valueContainsNull = false)))
+  }
+
+  testSelect(arrayOfStruct, "map_values(col5[0]).field1.subfield2 as foo") {
+    StructField("col5", ArrayType(MapType(StringType, StructType(
+      StructField("field1", StructType(
+        StructField("subfield2", IntegerType) :: Nil)) :: Nil), valueContainsNull = false)))
+  }
+
   // |-- col1: string (nullable = false)
   // |-- col6: map (nullable = true)
   // |    |-- key: string
@@ -394,6 +405,90 @@ class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll {
       :: Nil)))
   }
 
+  // |-- col1: string (nullable = false)
+  // |-- col2: map (nullable = true)
+  // |    |-- key: struct (containsNull = false)
+  // |    |    |-- field1: string (nullable = true)
+  // |    |    |-- field2: integer (nullable = true)
+  // |    |-- value: array (valueContainsNull = true)
+  // |    |    |-- element: struct (containsNull = false)
+  // |    |    |    |-- field3: struct (nullable = true)
+  // |    |    |    |    |-- subfield1: integer (nullable = true)
+  // |    |    |    |    |-- subfield2: integer (nullable = true)
+  private val mapWithStructKey = StructType(Array(ignoredField,
+    StructField("col2", MapType(
+      StructType(
+        StructField("field1", StringType) ::
+        StructField("field2", IntegerType) :: Nil),
+      ArrayType(StructType(
+        StructField("field3", StructType(
+          StructField("subfield1", IntegerType) ::
+          StructField("subfield2", IntegerType) :: Nil)) :: Nil), containsNull = false)))))
+
+  testSelect(mapWithStructKey, "map_keys(col2).field1 as foo") {
+    StructField("col2", MapType(
+      StructType(StructField("field1", StringType) :: Nil),
+      ArrayType(StructType(
+        StructField("field3", StructType(
+          StructField("subfield1", IntegerType) ::
+          StructField("subfield2", IntegerType) :: Nil)) :: Nil), containsNull = false)))
+  }
+
+  testSelect(mapWithStructKey, "map_keys(col2).field2 as foo") {
+    StructField("col2", MapType(
+      StructType(StructField("field2", IntegerType) :: Nil),
+      ArrayType(StructType(
+        StructField("field3", StructType(
+          StructField("subfield1", IntegerType) ::
+          StructField("subfield2", IntegerType) :: Nil)) :: Nil), containsNull = false)))
+  }
+
+  // |-- col1: string (nullable = false)
+  // |-- col2: map (nullable = true)
+  // |    |-- key: array (valueContainsNull = true)
+  // |    |    |-- element: struct (containsNull = false)
+  // |    |    |    |-- field1: string (nullable = true)
+  // |    |    |    |-- field2: struct (containsNull = false)
+  // |    |    |    |    |-- subfield1: integer (nullable = true)
+  // |    |    |    |    |-- subfield2: long (nullable = true)
+  // |    |-- value: array (valueContainsNull = true)
+  // |    |    |-- element: struct (containsNull = false)
+  // |    |    |    |-- field3: struct (nullable = true)
+  // |    |    |    |    |-- subfield3: integer (nullable = true)
+  // |    |    |    |    |-- subfield4: integer (nullable = true)
+  private val mapWithArrayOfStructKey = StructType(Array(ignoredField,
+    StructField("col2", MapType(
+      ArrayType(StructType(
+        StructField("field1", StringType) ::
+        StructField("field2", StructType(
+          StructField("subfield1", IntegerType) ::
+          StructField("subfield2", LongType) :: Nil)) :: Nil), containsNull = false),
+      ArrayType(StructType(
+        StructField("field3", StructType(
+          StructField("subfield3", IntegerType) ::
+          StructField("subfield4", IntegerType) :: Nil)) :: Nil), containsNull = false)))))
+
+  testSelect(mapWithArrayOfStructKey, "map_keys(col2)[0].field1 as foo") {
+    StructField("col2", MapType(
+      ArrayType(StructType(
+        StructField("field1", StringType) :: Nil), containsNull = false),
+      ArrayType(StructType(
+        StructField("field3", StructType(
+          StructField("subfield3", IntegerType) ::
+          StructField("subfield4", IntegerType) :: Nil)) :: Nil), containsNull = false)))
+  }
+
+  testSelect(mapWithArrayOfStructKey, "map_keys(col2)[0].field2.subfield1 as foo") {
+    StructField("col2", MapType(
+      ArrayType(StructType(
+        StructField("field2", StructType(
+          StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = false),
+      ArrayType(StructType(
+        StructField("field3", StructType(
+          StructField("subfield3", IntegerType) ::
+          StructField("subfield4", IntegerType) :: Nil)) :: Nil), containsNull = false)))
+  }
+
   def assertResult(expected: StructField)(actual: StructField)(selectExpr: String): Unit = {
     try {
       super.assertResult(expected)(actual)
@@ -439,7 +534,7 @@ class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll {
   private def unapplySelect(expr: String, relation: LocalRelation) = {
     val parsedExpr = parseAsCatalystExpression(Seq(expr)).head
     val select = relation.select(parsedExpr)
-    val analyzed = select.analyze
+    val analyzed = caseSensitiveAnalyzer.execute(select)
     SelectedField.unapply(analyzed.expressions.head)
   }

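To read the expected values in the new tests more easily, one can wrap an expected StructField in a StructType and print its treeString, which lines up with the schema comment blocks in the suite. This is only a reader's aid; the field below is copied from the `map_keys(col2).field1` expectation in the diff, and the object name is made up.

```scala
import org.apache.spark.sql.types._

// Reader's aid: pretty-print one of the expected pruned schemas from the new tests.
object PrunedSchemaTreeString {
  def main(args: Array[String]): Unit = {
    // Expected result of `map_keys(col2).field1 as foo` over mapWithStructKey (copied from the
    // diff): only field1 survives in the key struct; the value side is kept as-is.
    val pruned = StructField("col2", MapType(
      StructType(StructField("field1", StringType) :: Nil),
      ArrayType(StructType(
        StructField("field3", StructType(
          StructField("subfield1", IntegerType) ::
          StructField("subfield2", IntegerType) :: Nil)) :: Nil), containsNull = false)))

    // Printing the tree makes the assertion easy to compare with the comments in the suite.
    println(StructType(pruned :: Nil).treeString)
  }
}
```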