
Commit 482b03c

Revert "[SPARK-54760][SQL] DelegatingCatalogExtension as session catalog supports both V1 and V2 functions"
This reverts commit 51042d6.
1 parent cb76c66 · commit 482b03c

File tree

10 files changed: +126 −197 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala

Lines changed: 11 additions & 12 deletions
@@ -26,16 +26,17 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{
   CatalogManager,
+  CatalogV2Util,
+  FunctionCatalog,
+  Identifier,
   LookupCatalog
 }
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.connector.catalog.functions.{
   AggregateFunction => V2AggregateFunction,
-  ScalarFunction,
-  UnboundFunction
+  ScalarFunction
 }
 import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors}
-import org.apache.spark.sql.internal.connector.V1Function
 import org.apache.spark.sql.types._
 
 class FunctionResolution(
@@ -51,14 +52,10 @@ class FunctionResolution(
     resolveBuiltinOrTempFunction(u.nameParts, u.arguments, u).getOrElse {
       val CatalogAndIdentifier(catalog, ident) =
         relationResolution.expandIdentifier(u.nameParts)
-      catalog.asFunctionCatalog.loadFunction(ident) match {
-        case V1Function(_) =>
-          // this triggers the second time v1 function resolution but should be cheap
-          // (no RPC to external catalog), since the metadata has been already cached
-          // in FunctionRegistry during the above `catalog.loadFunction` call.
-          resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
-        case unboundV2Func =>
-          resolveV2Function(unboundV2Func, u.arguments, u)
+      if (CatalogV2Util.isSessionCatalog(catalog)) {
+        resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
+      } else {
+        resolveV2Function(catalog.asFunctionCatalog, ident, u.arguments, u)
       }
     }
   }
@@ -275,9 +272,11 @@ class FunctionResolution(
   }
 
   private def resolveV2Function(
-      unbound: UnboundFunction,
+      catalog: FunctionCatalog,
+      ident: Identifier,
       arguments: Seq[Expression],
       u: UnresolvedFunction): Expression = {
+    val unbound = catalog.loadFunction(ident)
    val inputType = StructType(arguments.zipWithIndex.map {
      case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
    })
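
The heart of the revert is the hunk above: instead of loading the function first and branching on whether a V1Function came back, resolution now branches on the catalog itself, so the session catalog never consults the V2 FunctionCatalog API for lookup. A minimal self-contained sketch of the restored dispatch, using hypothetical stand-in types (the real logic is the FunctionResolution code above):

object ResolutionSketch {
  // Hypothetical stand-ins for Spark's connector interfaces; only the
  // shape of the dispatch matters here, not the real types.
  trait UnboundFunction { def name: String }
  case class Identifier(namespace: Seq[String], name: String)
  trait FunctionCatalog { def loadFunction(ident: Identifier): UnboundFunction }

  def resolve(isSessionCatalog: Boolean, catalog: FunctionCatalog, ident: Identifier): String = {
    if (isSessionCatalog) {
      // Post-revert: anything addressed through spark_catalog goes straight
      // to the V1 FunctionRegistry path, so V2 functions registered there
      // are unreachable.
      s"V1 resolution of ${ident.name}"
    } else {
      // Any other catalog loads the V2 UnboundFunction and binds it.
      s"V2 binding of ${catalog.loadFunction(ident).name}"
    }
  }
}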

sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out

Lines changed: 4 additions & 4 deletions
@@ -972,11 +972,11 @@ VALUES(IDENTIFIER('a.b.c.d')())
 -- !query analysis
 org.apache.spark.sql.AnalysisException
 {
-  "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
-  "sqlState" : "42K05",
+  "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+  "sqlState" : "42601",
   "messageParameters" : {
-    "namespace" : "`a`.`b`.`c`",
-    "sessionCatalog" : "spark_catalog"
+    "identifier" : "`a`.`b`.`c`.`d`",
+    "limit" : "2"
   },
   "queryContext" : [ {
     "objectType" : "",

sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out

Lines changed: 4 additions & 4 deletions
@@ -972,11 +972,11 @@ VALUES(IDENTIFIER('a.b.c.d')())
 -- !query analysis
 org.apache.spark.sql.AnalysisException
 {
-  "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
-  "sqlState" : "42K05",
+  "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+  "sqlState" : "42601",
   "messageParameters" : {
-    "namespace" : "`a`.`b`.`c`",
-    "sessionCatalog" : "spark_catalog"
+    "identifier" : "`a`.`b`.`c`.`d`",
+    "limit" : "2"
   },
   "queryContext" : [ {
     "objectType" : "",

sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out

Lines changed: 4 additions & 4 deletions
@@ -1112,11 +1112,11 @@ struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
 {
-  "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
-  "sqlState" : "42K05",
+  "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+  "sqlState" : "42601",
   "messageParameters" : {
-    "namespace" : "`a`.`b`.`c`",
-    "sessionCatalog" : "spark_catalog"
+    "identifier" : "`a`.`b`.`c`.`d`",
+    "limit" : "2"
   },
   "queryContext" : [ {
     "objectType" : "",

sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out

Lines changed: 4 additions & 4 deletions
@@ -1112,11 +1112,11 @@ struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
 {
-  "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
-  "sqlState" : "42K05",
+  "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+  "sqlState" : "42601",
   "messageParameters" : {
-    "namespace" : "`a`.`b`.`c`",
-    "sessionCatalog" : "spark_catalog"
+    "identifier" : "`a`.`b`.`c`.`d`",
+    "limit" : "2"
  },
  "queryContext" : [ {
    "objectType" : "",

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala

Lines changed: 1 addition & 7 deletions
@@ -168,10 +168,6 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio
     spark.sessionState.catalogManager.catalog(name)
   }
 
-  protected def sessionCatalog: Catalog = {
-    catalog(SESSION_CATALOG_NAME).asInstanceOf[Catalog]
-  }
-
   protected val v2Format: String = classOf[FakeV2ProviderWithCustomSchema].getName
 
   protected val catalogClassName: String = classOf[InMemoryTableSessionCatalog].getName
@@ -182,9 +178,7 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio
 
   override def afterEach(): Unit = {
     super.afterEach()
-    sessionCatalog.checkUsage()
-    sessionCatalog.clearTables()
-    sessionCatalog.clearFunctions()
+    catalog(SESSION_CATALOG_NAME).asInstanceOf[Catalog].clearTables()
     spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
   }
 
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala

Lines changed: 91 additions & 91 deletions
@@ -702,127 +702,127 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
     comparePlans(df1.queryExecution.optimizedPlan, df2.queryExecution.optimizedPlan)
     checkAnswer(df1, Row(3) :: Nil)
   }
-}
 
-case object StrLenDefault extends ScalarFunction[Int] {
-  override def inputTypes(): Array[DataType] = Array(StringType)
-  override def resultType(): DataType = IntegerType
-  override def name(): String = "strlen_default"
+  private case object StrLenDefault extends ScalarFunction[Int] {
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "strlen_default"
 
-  override def produceResult(input: InternalRow): Int = {
-    val s = input.getString(0)
-    s.length
+    override def produceResult(input: InternalRow): Int = {
+      val s = input.getString(0)
+      s.length
+    }
   }
-}
 
-case object StrLenMagic extends ScalarFunction[Int] {
-  override def inputTypes(): Array[DataType] = Array(StringType)
-  override def resultType(): DataType = IntegerType
-  override def name(): String = "strlen_magic"
+  case object StrLenMagic extends ScalarFunction[Int] {
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "strlen_magic"
 
-  def invoke(input: UTF8String): Int = {
-    input.toString.length
+    def invoke(input: UTF8String): Int = {
+      input.toString.length
+    }
   }
-}
 
-case object StrLenBadMagic extends ScalarFunction[Int] {
-  override def inputTypes(): Array[DataType] = Array(StringType)
-  override def resultType(): DataType = IntegerType
-  override def name(): String = "strlen_bad_magic"
+  case object StrLenBadMagic extends ScalarFunction[Int] {
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "strlen_bad_magic"
 
-  def invoke(input: String): Int = {
-    input.length
+    def invoke(input: String): Int = {
+      input.length
+    }
   }
-}
 
-case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
-  override def inputTypes(): Array[DataType] = Array(StringType)
-  override def resultType(): DataType = IntegerType
-  override def name(): String = "strlen_bad_magic"
+  case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "strlen_bad_magic"
+
+    def invoke(input: String): Int = {
+      input.length
+    }
 
-  def invoke(input: String): Int = {
-    input.length
+    override def produceResult(input: InternalRow): Int = {
+      val s = input.getString(0)
+      s.length
+    }
   }
 
-  override def produceResult(input: InternalRow): Int = {
-    val s = input.getString(0)
-    s.length
+  private case object StrLenNoImpl extends ScalarFunction[Int] {
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "strlen_noimpl"
  }
-}
 
-case object StrLenNoImpl extends ScalarFunction[Int] {
-  override def inputTypes(): Array[DataType] = Array(StringType)
-  override def resultType(): DataType = IntegerType
-  override def name(): String = "strlen_noimpl"
-}
+  // input type doesn't match arguments accepted by `UnboundFunction.bind`
+  private case object StrLenBadInputTypes extends ScalarFunction[Int] {
+    override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "strlen_bad_input_types"
+  }
 
-// input type doesn't match arguments accepted by `UnboundFunction.bind`
-case object StrLenBadInputTypes extends ScalarFunction[Int] {
-  override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
-  override def resultType(): DataType = IntegerType
-  override def name(): String = "strlen_bad_input_types"
-}
+  private case object BadBoundFunction extends BoundFunction {
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "bad_bound_func"
+  }
 
-case object BadBoundFunction extends BoundFunction {
-  override def inputTypes(): Array[DataType] = Array(StringType)
-  override def resultType(): DataType = IntegerType
-  override def name(): String = "bad_bound_func"
-}
+  object UnboundDecimalAverage extends UnboundFunction {
+    override def name(): String = "decimal_avg"
 
-object UnboundDecimalAverage extends UnboundFunction {
-  override def name(): String = "decimal_avg"
+    override def bind(inputType: StructType): BoundFunction = {
+      if (inputType.fields.length > 1) {
+        throw new UnsupportedOperationException("Too many arguments")
+      }
 
-  override def bind(inputType: StructType): BoundFunction = {
-    if (inputType.fields.length > 1) {
-      throw new UnsupportedOperationException("Too many arguments")
+      // put interval type here for testing purpose
+      inputType.fields(0).dataType match {
+        case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+        case dataType =>
+          throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
+      }
    }
 
-    // put interval type here for testing purpose
-    inputType.fields(0).dataType match {
-      case _: NumericType | _: DayTimeIntervalType => DecimalAverage
-      case dataType =>
-        throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
-    }
+    override def description(): String =
+      "decimal_avg: produces an average using decimal division"
  }
 
-  override def description(): String =
-    "decimal_avg: produces an average using decimal division"
-}
-
-object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
-  override def name(): String = "decimal_avg"
-  override def inputTypes(): Array[DataType] = Array(DecimalType.SYSTEM_DEFAULT)
-  override def resultType(): DataType = DecimalType.SYSTEM_DEFAULT
+  object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
+    override def name(): String = "decimal_avg"
+    override def inputTypes(): Array[DataType] = Array(DecimalType.SYSTEM_DEFAULT)
+    override def resultType(): DataType = DecimalType.SYSTEM_DEFAULT
 
-  override def newAggregationState(): (Decimal, Int) = (Decimal.ZERO, 0)
+    override def newAggregationState(): (Decimal, Int) = (Decimal.ZERO, 0)
 
-  override def update(state: (Decimal, Int), input: InternalRow): (Decimal, Int) = {
-    if (input.isNullAt(0)) {
-      state
-    } else {
-      val l = input.getDecimal(0, DecimalType.SYSTEM_DEFAULT.precision,
-        DecimalType.SYSTEM_DEFAULT.scale)
-      state match {
-        case (_, d) if d == 0 =>
-          (l, 1)
-        case (total, count) =>
-          (total + l, count + 1)
+    override def update(state: (Decimal, Int), input: InternalRow): (Decimal, Int) = {
+      if (input.isNullAt(0)) {
+        state
+      } else {
+        val l = input.getDecimal(0, DecimalType.SYSTEM_DEFAULT.precision,
+          DecimalType.SYSTEM_DEFAULT.scale)
+        state match {
+          case (_, d) if d == 0 =>
+            (l, 1)
+          case (total, count) =>
+            (total + l, count + 1)
+        }
      }
    }
-  }
 
-  override def merge(leftState: (Decimal, Int), rightState: (Decimal, Int)): (Decimal, Int) = {
-    (leftState._1 + rightState._1, leftState._2 + rightState._2)
-  }
+    override def merge(leftState: (Decimal, Int), rightState: (Decimal, Int)): (Decimal, Int) = {
+      (leftState._1 + rightState._1, leftState._2 + rightState._2)
+    }
 
-  override def produceResult(state: (Decimal, Int)): Decimal = state._1 / Decimal(state._2)
-}
+    override def produceResult(state: (Decimal, Int)): Decimal = state._1 / Decimal(state._2)
+  }
 
-object NoImplAverage extends UnboundFunction {
-  override def name(): String = "no_impl_avg"
-  override def description(): String = name()
+  object NoImplAverage extends UnboundFunction {
+    override def name(): String = "no_impl_avg"
+    override def description(): String = name()
 
-  override def bind(inputType: StructType): BoundFunction = {
-    throw SparkUnsupportedOperationException()
+    override def bind(inputType: StructType): BoundFunction = {
+      throw SparkUnsupportedOperationException()
+    }
  }
 }
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala

Lines changed: 1 addition & 8 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.connector
 
-import org.apache.spark.sql.{DataFrame, Row, SaveMode}
+import org.apache.spark.sql.{DataFrame, SaveMode}
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, Table, TableCatalog}
 
 class DataSourceV2SQLSessionCatalogSuite
@@ -79,11 +79,4 @@ class DataSourceV2SQLSessionCatalogSuite
       assert(getTableMetadata("default.t").columns().map(_.name()) === Seq("c2", "c1"))
    }
  }
-
-  test("SPARK-54760: DelegatingCatalogExtension supports both V1 and V2 functions") {
-    sessionCatalog.createFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
-    checkAnswer(
-      sql("SELECT char_length('Hello') as v1, ns.strlen('Spark') as v2"),
-      Row(5, 5))
-  }
 }
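
The deleted test was the regression coverage for SPARK-54760 itself, and it depended on two members this commit also removes: the sessionCatalog helper in SessionCatalogTest and the shared StrLen fixtures in DataSourceV2FunctionSuite. A hedged sketch of the behavior difference, using the test's own names:

// Pre-revert: a V2 function registered in the session catalog resolved by name.
//   sessionCatalog.createFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
//   spark.sql("SELECT char_length('Hello') AS v1, ns.strlen('Spark') AS v2")  // Row(5, 5)
// Post-revert: spark_catalog lookups go only through the V1 FunctionRegistry,
// so ns.strlen no longer resolves and the query fails analysis.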
