SparkOptimizer.scala
@@ -40,11 +40,11 @@ class SparkOptimizer(
SchemaPruning,
GroupBasedRowLevelOperationScanPlanning,
V1Writes,
PushVariantIntoScan,
V2ScanRelationPushDown,
V2ScanPartitioningAndOrdering,
V2Writes,
PruneFileSourcePartitions,
PushVariantIntoScan)
PruneFileSourcePartitions)

override def preCBORules: Seq[Rule[LogicalPlan]] =
Seq(OptimizeMetadataOnlyDeleteFromTable)
PushVariantIntoScan.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -279,6 +280,8 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
relation @ LogicalRelationWithTable(
hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) =>
rewritePlan(p, projectList, filters, relation, hadoopFsRelation)
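      // For DSv2 tables the relation must be rewritten while it is still a DataSourceV2Relation,
      // i.e. before V2ScanRelationPushDown replaces it with a DataSourceV2ScanRelation; this is
      // why PushVariantIntoScan now runs earlier in SparkOptimizer's scan push-down rules.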
case p@PhysicalOperation(projectList, filters, relation: DataSourceV2Relation) =>
rewriteV2RelationPlan(p, projectList, filters, relation.output, relation)
}
}

@@ -343,4 +346,102 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
}
Project(newProjectList, withFilter)
}

private def rewriteV2RelationPlan(
originalPlan: LogicalPlan,
projectList: Seq[NamedExpression],
filters: Seq[Expression],
relationOutput: Seq[AttributeReference],
relation: LogicalPlan): LogicalPlan = {

// Collect variant fields from the relation output
val (variants, attributeMap) = collectAndRewriteVariants(relationOutput)
if (attributeMap.isEmpty) return originalPlan

// Collect requested fields from projections and filters
projectList.foreach(variants.collectRequestedFields)
filters.foreach(variants.collectRequestedFields)
if (variants.mapping.forall(_._2.isEmpty)) return originalPlan

// Build attribute map with rewritten types
val finalAttributeMap = buildAttributeMap(relationOutput, variants)

// Rewrite the relation with new output
val newRelation = relation match {
case r: DataSourceV2Relation =>
val newOutput = r.output.map(a => finalAttributeMap.getOrElse(a.exprId, a))
r.copy(output = newOutput.toIndexedSeq)
case _ => return originalPlan
}

// Build filter and project with rewritten expressions
buildFilterAndProject(newRelation, projectList, filters, variants, finalAttributeMap)
}

  /**
   * Collect variant fields and return the initialized VariantInRelation together with a map of
   * the relation's output attributes, keyed by expression ID.
   */
private def collectAndRewriteVariants(
schemaAttributes: Seq[AttributeReference]): (VariantInRelation, Map[ExprId, Attribute]) = {
val variants = new VariantInRelation
val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType(
schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))

for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
}

val attributeMap = if (variants.mapping.isEmpty) {
Map.empty[ExprId, Attribute]
} else {
schemaAttributes.map(a => (a.exprId, a)).toMap
}

(variants, attributeMap)
}

/**
* Build attribute map with rewritten variant types.
*/
private def buildAttributeMap(
schemaAttributes: Seq[AttributeReference],
variants: VariantInRelation): Map[ExprId, AttributeReference] = {
schemaAttributes.map { a =>
if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
qualifier = a.qualifier)
(a.exprId, newAttr)
} else {
(a.exprId, a)
}
}.toMap
}

/**
* Build the final Project(Filter(relation)) plan with rewritten expressions.
*/
private def buildFilterAndProject(
relation: LogicalPlan,
projectList: Seq[NamedExpression],
filters: Seq[Expression],
variants: VariantInRelation,
attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = {

val withFilter = if (filters.nonEmpty) {
Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), relation)
} else {
relation
}

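    // If a rewritten projection is no longer a NamedExpression, re-alias it under the original
    // name and exprId so the operator's output attributes remain stable.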
val newProjectList = projectList.map { e =>
val rewritten = variants.rewriteExpr(e, attributeMap)
rewritten match {
case n: NamedExpression => n
case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier)
}
}

Project(newProjectList, withFilter)
}
}
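
For context, the new DSv2 path can be exercised end to end outside the test suite. The following is a minimal sketch, assuming the catalyst test jar providing InMemoryTableCatalog is on the classpath and using placeholder catalog, namespace, and object names; with PUSH_VARIANT_INTO_SCAN enabled, the optimized plan should show the variant column rewritten to a struct of the requested paths.

// Usage sketch only, not part of this change. "testcat", "ns.users" and the object name
// are placeholders mirroring the suite below.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object VariantV2PushdownExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.catalog.testcat",
        "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog")
      .config(SQLConf.PUSH_VARIANT_INTO_SCAN.key, "true")
      .config(SQLConf.VARIANT_ALLOW_READING_SHREDDED.key, "true")
      .getOrCreate()

    spark.sql("CREATE TABLE testcat.ns.users (id bigint, v variant) USING parquet")

    val df = spark.sql(
      """SELECT id, variant_get(v, '$.username', 'string') AS username
        |FROM testcat.ns.users
        |WHERE variant_get(v, '$.status', 'string') = 'active'""".stripMargin)

    // After PushVariantIntoScan, the DSv2 relation should expose `v` as a struct of the
    // requested paths (with VariantMetadata on each field) instead of a raw VariantType.
    println(df.queryExecution.optimizedPlan)

    spark.stop()
  }
}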
VariantV2ReadSuite.scala (new file)
@@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.datasources.VariantMetadata
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType, VariantType}

class VariantV2ReadSuite extends QueryTest with SharedSparkSession {

private val testCatalogClass = "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog"

private def withV2Catalog(f: => Unit): Unit = {
withSQLConf(
SQLConf.DEFAULT_CATALOG.key -> "testcat",
s"spark.sql.catalog.testcat" -> testCatalogClass,
SQLConf.USE_V1_SOURCE_LIST.key -> "",
SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "true",
SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "true") {
f
}
}

test("DSV2: push variant_get fields") {
withV2Catalog {
sql("DROP TABLE IF EXISTS testcat.ns.users")
sql(
"""CREATE TABLE testcat.ns.users (
| id bigint,
| name string,
| v variant,
| vd variant default parse_json('1')
|) USING parquet""".stripMargin)

val out = sql(
"""
|SELECT
| id,
| variant_get(v, '$.username', 'string') as username,
| variant_get(v, '$.age', 'int') as age
|FROM testcat.ns.users
|WHERE variant_get(v, '$.status', 'string') = 'active'
|""".stripMargin)

checkAnswer(out, Seq.empty)

// Verify variant column rewrite
val optimized = out.queryExecution.optimizedPlan
val relOutput = optimized.collectFirst {
case s: DataSourceV2ScanRelation => s.output
}.getOrElse(fail("Expected DSv2 relation in optimized plan"))

val vAttr = relOutput.find(_.name == "v").getOrElse(fail("Missing 'v' column"))
vAttr.dataType match {
case s: StructType =>
assert(s.fields.length == 3,
s"Expected 3 fields (username, age, status), got ${s.fields.length}")
assert(s.fields.forall(_.metadata.contains(VariantMetadata.METADATA_KEY)),
"All fields should have VariantMetadata")

val paths = s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).toSet
assert(paths == Set("$.username", "$.age", "$.status"),
s"Expected username, age, status paths, got: $paths")

val fieldTypes = s.fields.map(_.dataType).toSet
assert(fieldTypes.contains(StringType), "Expected StringType for string fields")
assert(fieldTypes.contains(IntegerType), "Expected IntegerType for age")

case other =>
fail(s"Expected StructType for 'v', got: $other")
}

// Verify variant with default value is NOT rewritten
relOutput.find(_.name == "vd").foreach { vdAttr =>
assert(vdAttr.dataType == VariantType,
"Variant column with default value should not be rewritten")
}
}
}

test("DSV2: nested column pruning for variant struct") {
withV2Catalog {
sql("DROP TABLE IF EXISTS testcat.ns.users2")
sql(
"""CREATE TABLE testcat.ns.users2 (
| id bigint,
| name string,
| v variant
|) USING parquet""".stripMargin)

val out = sql(
"""
|SELECT id, variant_get(v, '$.username', 'string') as username
|FROM testcat.ns.users2
|""".stripMargin)

checkAnswer(out, Seq.empty)

val scan = out.queryExecution.executedPlan.collectFirst {
case b: BatchScanExec => b.scan
}.getOrElse(fail("Expected BatchScanExec in physical plan"))

val readSchema = scan.readSchema()

// Verify 'v' field exists and is a struct
val vField = readSchema.fields.find(_.name == "v").getOrElse(
fail("Expected 'v' field in read schema")
)

vField.dataType match {
case s: StructType =>
assert(s.fields.length == 1,
"Expected only 1 field ($.username) in pruned schema, got " + s.fields.length + ": " +
s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).mkString(", "))

val field = s.fields(0)
assert(field.metadata.contains(VariantMetadata.METADATA_KEY),
"Field should have VariantMetadata")

val metadata = VariantMetadata.fromMetadata(field.metadata)
assert(metadata.path == "$.username",
"Expected path '$.username', got '" + metadata.path + "'")
assert(field.dataType == StringType,
s"Expected StringType, got ${field.dataType}")

case other =>
fail(s"Expected StructType for 'v' after rewrite and pruning, got: $other")
}
}
}
}