|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.catalyst.expressions
|
19 | 19 |
|
20 |
| -import org.scalatest.BeforeAndAfterAll |
21 | 20 | import org.scalatest.exceptions.TestFailedException
|
22 | 21 |
|
23 |
| -import org.apache.spark.SparkFunSuite |
| 22 | +import org.apache.spark.sql.catalyst.analysis.AnalysisTest |
24 | 23 | import org.apache.spark.sql.catalyst.dsl.plans._
|
25 | 24 | import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
|
26 | 25 | import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
|
27 | 26 | import org.apache.spark.sql.types._
|
28 | 27 |
|
29 |
| -class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll { |
| 28 | +class SelectedFieldSuite extends AnalysisTest { |
30 | 29 | private val ignoredField = StructField("col1", StringType, nullable = false)
|
31 | 30 |
|
32 | 31 | // The test schema as a tree string, i.e. `schema.treeString`
|
@@ -317,6 +316,18 @@ class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll {
|
317 | 316 | StructField("subfield1", IntegerType) :: Nil)) :: Nil), valueContainsNull = false)))
|
318 | 317 | }
|
319 | 318 |
|
| 319 | + testSelect(arrayOfStruct, "map_values(col5[0]).field1.subfield1 as foo") { |
| 320 | + StructField("col5", ArrayType(MapType(StringType, StructType( |
| 321 | + StructField("field1", StructType( |
| 322 | + StructField("subfield1", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) |
| 323 | + } |
| 324 | + |
| 325 | + testSelect(arrayOfStruct, "map_values(col5[0]).field1.subfield2 as foo") { |
| 326 | + StructField("col5", ArrayType(MapType(StringType, StructType( |
| 327 | + StructField("field1", StructType( |
| 328 | + StructField("subfield2", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) |
| 329 | + } |
| 330 | + |
320 | 331 | // |-- col1: string (nullable = false)
|
321 | 332 | // |-- col6: map (nullable = true)
|
322 | 333 | // | |-- key: string
|
@@ -394,6 +405,90 @@ class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll {
|
394 | 405 | :: Nil)))
|
395 | 406 | }
|
396 | 407 |
|
| 408 | + // |-- col1: string (nullable = false) |
| 409 | + // |-- col2: map (nullable = true) |
| 410 | + // | |-- key: struct (containsNull = false) |
| 411 | + // | | |-- field1: string (nullable = true) |
| 412 | + // | | |-- field2: integer (nullable = true) |
| 413 | + // | |-- value: array (valueContainsNull = true) |
| 414 | + // | | |-- element: struct (containsNull = false) |
| 415 | + // | | | |-- field3: struct (nullable = true) |
| 416 | + // | | | | |-- subfield1: integer (nullable = true) |
| 417 | + // | | | | |-- subfield2: integer (nullable = true) |
| 418 | + private val mapWithStructKey = StructType(Array(ignoredField, |
| 419 | + StructField("col2", MapType( |
| 420 | + StructType( |
| 421 | + StructField("field1", StringType) :: |
| 422 | + StructField("field2", IntegerType) :: Nil), |
| 423 | + ArrayType(StructType( |
| 424 | + StructField("field3", StructType( |
| 425 | + StructField("subfield1", IntegerType) :: |
| 426 | + StructField("subfield2", IntegerType) :: Nil)) :: Nil), containsNull = false))))) |
| 427 | + |
| 428 | + testSelect(mapWithStructKey, "map_keys(col2).field1 as foo") { |
| 429 | + StructField("col2", MapType( |
| 430 | + StructType(StructField("field1", StringType) :: Nil), |
| 431 | + ArrayType(StructType( |
| 432 | + StructField("field3", StructType( |
| 433 | + StructField("subfield1", IntegerType) :: |
| 434 | + StructField("subfield2", IntegerType) :: Nil)) :: Nil), containsNull = false))) |
| 435 | + } |
| 436 | + |
| 437 | + testSelect(mapWithStructKey, "map_keys(col2).field2 as foo") { |
| 438 | + StructField("col2", MapType( |
| 439 | + StructType(StructField("field2", IntegerType) :: Nil), |
| 440 | + ArrayType(StructType( |
| 441 | + StructField("field3", StructType( |
| 442 | + StructField("subfield1", IntegerType) :: |
| 443 | + StructField("subfield2", IntegerType) :: Nil)) :: Nil), containsNull = false))) |
| 444 | + } |
| 445 | + |
| 446 | + // |-- col1: string (nullable = false) |
| 447 | + // |-- col2: map (nullable = true) |
| 448 | + // | |-- key: array (valueContainsNull = true) |
| 449 | + // | | |-- element: struct (containsNull = false) |
| 450 | + // | | | |-- field1: string (nullable = true) |
| 451 | + // | | | |-- field2: struct (containsNull = false) |
| 452 | + // | | | | |-- subfield1: integer (nullable = true) |
| 453 | + // | | | | |-- subfield2: long (nullable = true) |
| 454 | + // | |-- value: array (valueContainsNull = true) |
| 455 | + // | | |-- element: struct (containsNull = false) |
| 456 | + // | | | |-- field3: struct (nullable = true) |
| 457 | + // | | | | |-- subfield3: integer (nullable = true) |
| 458 | + // | | | | |-- subfield4: integer (nullable = true) |
| 459 | + private val mapWithArrayOfStructKey = StructType(Array(ignoredField, |
| 460 | + StructField("col2", MapType( |
| 461 | + ArrayType(StructType( |
| 462 | + StructField("field1", StringType) :: |
| 463 | + StructField("field2", StructType( |
| 464 | + StructField("subfield1", IntegerType) :: |
| 465 | + StructField("subfield2", LongType) :: Nil)) :: Nil), containsNull = false), |
| 466 | + ArrayType(StructType( |
| 467 | + StructField("field3", StructType( |
| 468 | + StructField("subfield3", IntegerType) :: |
| 469 | + StructField("subfield4", IntegerType) :: Nil)) :: Nil), containsNull = false))))) |
| 470 | + |
| 471 | + testSelect(mapWithArrayOfStructKey, "map_keys(col2)[0].field1 as foo") { |
| 472 | + StructField("col2", MapType( |
| 473 | + ArrayType(StructType( |
| 474 | + StructField("field1", StringType) :: Nil), containsNull = false), |
| 475 | + ArrayType(StructType( |
| 476 | + StructField("field3", StructType( |
| 477 | + StructField("subfield3", IntegerType) :: |
| 478 | + StructField("subfield4", IntegerType) :: Nil)) :: Nil), containsNull = false))) |
| 479 | + } |
| 480 | + |
| 481 | + testSelect(mapWithArrayOfStructKey, "map_keys(col2)[0].field2.subfield1 as foo") { |
| 482 | + StructField("col2", MapType( |
| 483 | + ArrayType(StructType( |
| 484 | + StructField("field2", StructType( |
| 485 | + StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = false), |
| 486 | + ArrayType(StructType( |
| 487 | + StructField("field3", StructType( |
| 488 | + StructField("subfield3", IntegerType) :: |
| 489 | + StructField("subfield4", IntegerType) :: Nil)) :: Nil), containsNull = false))) |
| 490 | + } |
| 491 | + |
397 | 492 | def assertResult(expected: StructField)(actual: StructField)(selectExpr: String): Unit = {
|
398 | 493 | try {
|
399 | 494 | super.assertResult(expected)(actual)
|
@@ -439,7 +534,7 @@ class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll {
|
439 | 534 | private def unapplySelect(expr: String, relation: LocalRelation) = {
|
440 | 535 | val parsedExpr = parseAsCatalystExpression(Seq(expr)).head
|
441 | 536 | val select = relation.select(parsedExpr)
|
442 |
| - val analyzed = select.analyze |
| 537 | + val analyzed = caseSensitiveAnalyzer.execute(select) |
443 | 538 | SelectedField.unapply(analyzed.expressions.head)
|
444 | 539 | }
|
445 | 540 |
|
|
0 commit comments