diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 8edb59f49282e..9699d8a2563fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -40,11 +40,11 @@ class SparkOptimizer( SchemaPruning, GroupBasedRowLevelOperationScanPlanning, V1Writes, + PushVariantIntoScan, V2ScanRelationPushDown, V2ScanPartitioningAndOrdering, V2Writes, - PruneFileSourcePartitions, - PushVariantIntoScan) + PruneFileSourcePartitions) override def preCBORules: Seq[Rule[LogicalPlan]] = Seq(OptimizeMetadataOnlyDeleteFromTable) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala index 5960cf8c38ced..c30b8adbdd05c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -279,6 +280,8 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { relation @ LogicalRelationWithTable( hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) => rewritePlan(p, projectList, filters, relation, hadoopFsRelation) + case p@PhysicalOperation(projectList, filters, relation: DataSourceV2Relation) => + rewriteV2RelationPlan(p, projectList, filters, relation.output, relation) } } @@ -343,4 +346,102 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { } Project(newProjectList, withFilter) } + + private def rewriteV2RelationPlan( + originalPlan: LogicalPlan, + projectList: Seq[NamedExpression], + filters: Seq[Expression], + relationOutput: Seq[AttributeReference], + relation: LogicalPlan): LogicalPlan = { + + // Collect variant fields from the relation output + val (variants, attributeMap) = collectAndRewriteVariants(relationOutput) + if (attributeMap.isEmpty) return originalPlan + + // Collect requested fields from projections and filters + projectList.foreach(variants.collectRequestedFields) + filters.foreach(variants.collectRequestedFields) + if (variants.mapping.forall(_._2.isEmpty)) return originalPlan + + // Build attribute map with rewritten types + val finalAttributeMap = buildAttributeMap(relationOutput, variants) + + // Rewrite the relation with new output + val newRelation = relation match { + case r: DataSourceV2Relation => + val newOutput = r.output.map(a => finalAttributeMap.getOrElse(a.exprId, a)) + r.copy(output = newOutput.toIndexedSeq) + case _ => return originalPlan + } + + // Build filter and project with rewritten expressions + buildFilterAndProject(newRelation, projectList, filters, variants, finalAttributeMap) + } + + /** + * Collect variant fields and return initialized VariantInRelation. 
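+ * An empty attribute map means the relation output contains no variant columns, in which case the caller returns the original plan unchanged.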
+ */ + private def collectAndRewriteVariants( + schemaAttributes: Seq[AttributeReference]): (VariantInRelation, Map[ExprId, Attribute]) = { + val variants = new VariantInRelation + val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType( + schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) + + for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) { + variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil) + } + + val attributeMap = if (variants.mapping.isEmpty) { + Map.empty[ExprId, Attribute] + } else { + schemaAttributes.map(a => (a.exprId, a)).toMap + } + + (variants, attributeMap) + } + + /** + * Build attribute map with rewritten variant types. + */ + private def buildAttributeMap( + schemaAttributes: Seq[AttributeReference], + variants: VariantInRelation): Map[ExprId, AttributeReference] = { + schemaAttributes.map { a => + if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) { + val newType = variants.rewriteType(a.exprId, a.dataType, Nil) + val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)( + qualifier = a.qualifier) + (a.exprId, newAttr) + } else { + (a.exprId, a) + } + }.toMap + } + + /** + * Build the final Project(Filter(relation)) plan with rewritten expressions. + */ + private def buildFilterAndProject( + relation: LogicalPlan, + projectList: Seq[NamedExpression], + filters: Seq[Expression], + variants: VariantInRelation, + attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = { + + val withFilter = if (filters.nonEmpty) { + Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), relation) + } else { + relation + } + + val newProjectList = projectList.map { e => + val rewritten = variants.rewriteExpr(e, attributeMap) + rewritten match { + case n: NamedExpression => n + case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier) + } + } + + Project(newProjectList, withFilter) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala new file mode 100644 index 0000000000000..a6521dfe76da1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.execution.datasources.VariantMetadata +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType, StructType, VariantType} + +class VariantV2ReadSuite extends QueryTest with SharedSparkSession { + + private val testCatalogClass = "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog" + + private def withV2Catalog(f: => Unit): Unit = { + withSQLConf( + SQLConf.DEFAULT_CATALOG.key -> "testcat", + s"spark.sql.catalog.testcat" -> testCatalogClass, + SQLConf.USE_V1_SOURCE_LIST.key -> "", + SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "true", + SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "true") { + f + } + } + + test("DSV2: push variant_get fields") { + withV2Catalog { + sql("DROP TABLE IF EXISTS testcat.ns.users") + sql( + """CREATE TABLE testcat.ns.users ( + | id bigint, + | name string, + | v variant, + | vd variant default parse_json('1') + |) USING parquet""".stripMargin) + + val out = sql( + """ + |SELECT + | id, + | variant_get(v, '$.username', 'string') as username, + | variant_get(v, '$.age', 'int') as age + |FROM testcat.ns.users + |WHERE variant_get(v, '$.status', 'string') = 'active' + |""".stripMargin) + + checkAnswer(out, Seq.empty) + + // Verify variant column rewrite + val optimized = out.queryExecution.optimizedPlan + val relOutput = optimized.collectFirst { + case s: DataSourceV2ScanRelation => s.output + }.getOrElse(fail("Expected DSv2 relation in optimized plan")) + + val vAttr = relOutput.find(_.name == "v").getOrElse(fail("Missing 'v' column")) + vAttr.dataType match { + case s: StructType => + assert(s.fields.length == 3, + s"Expected 3 fields (username, age, status), got ${s.fields.length}") + assert(s.fields.forall(_.metadata.contains(VariantMetadata.METADATA_KEY)), + "All fields should have VariantMetadata") + + val paths = s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).toSet + assert(paths == Set("$.username", "$.age", "$.status"), + s"Expected username, age, status paths, got: $paths") + + val fieldTypes = s.fields.map(_.dataType).toSet + assert(fieldTypes.contains(StringType), "Expected StringType for string fields") + assert(fieldTypes.contains(IntegerType), "Expected IntegerType for age") + + case other => + fail(s"Expected StructType for 'v', got: $other") + } + + // Verify variant with default value is NOT rewritten + relOutput.find(_.name == "vd").foreach { vdAttr => + assert(vdAttr.dataType == VariantType, + "Variant column with default value should not be rewritten") + } + } + } + + test("DSV2: nested column pruning for variant struct") { + withV2Catalog { + sql("DROP TABLE IF EXISTS testcat.ns.users2") + sql( + """CREATE TABLE testcat.ns.users2 ( + | id bigint, + | name string, + | v variant + |) USING parquet""".stripMargin) + + val out = sql( + """ + |SELECT id, variant_get(v, '$.username', 'string') as username + |FROM testcat.ns.users2 + |""".stripMargin) + + checkAnswer(out, Seq.empty) + + val scan = out.queryExecution.executedPlan.collectFirst { + case b: BatchScanExec => b.scan + }.getOrElse(fail("Expected BatchScanExec in physical plan")) + + val readSchema = scan.readSchema() + + // Verify 'v' field exists and is a struct + val vField = readSchema.fields.find(_.name == "v").getOrElse( + fail("Expected 'v' field in read schema") + ) + + vField.dataType match { + case s: StructType => + 
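+ // Nested column pruning should leave only the single requested '$.username' path in the variant struct.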
+ assert(s.fields.length == 1,
+ "Expected only 1 field ($.username) in pruned schema, got " + s.fields.length + ": " +
+ s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).mkString(", "))
+
+ val field = s.fields(0)
+ assert(field.metadata.contains(VariantMetadata.METADATA_KEY),
+ "Field should have VariantMetadata")
+
+ val metadata = VariantMetadata.fromMetadata(field.metadata)
+ assert(metadata.path == "$.username",
+ "Expected path '$.username', got '" + metadata.path + "'")
+ assert(field.dataType == StringType,
+ s"Expected StringType, got ${field.dataType}")
+
+ case other =>
+ fail(s"Expected StructType for 'v' after rewrite and pruning, got: $other")
+ }
+ }
+ }
+}
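
Conceptually, the new DSv2 branch performs the same rewrite as the existing V1 branch: a variant column's VariantType is replaced with a struct whose fields each carry VariantMetadata recording the requested extraction path, which V2ScanRelationPushDown (now ordered after this rule) can then prune down to the paths actually read. A minimal Scala sketch of that rewritten shape, using a simplified "path" metadata key and positional field names as stand-ins for the actual VariantMetadata encoding:

  import org.apache.spark.sql.types._

  // Illustrative only: the real rule encodes paths via VariantMetadata;
  // the "path" key and positional field names below are simplified stand-ins.
  def rewrittenVariantType(requested: Seq[(String, DataType)]): StructType =
    StructType(requested.zipWithIndex.map { case ((path, dt), i) =>
      val md = new MetadataBuilder().putString("path", path).build()
      StructField(i.toString, dt, nullable = true, metadata = md)
    })

  // For SELECT variant_get(v, '$.username', 'string'), variant_get(v, '$.age', 'int') ...
  val vType = rewrittenVariantType(Seq("$.username" -> StringType, "$.age" -> IntegerType))
  // vType renders as struct<0:string,1:int>, each field tagged with its extraction path.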