
Commit 512099b

pan3793 authored and cloud-fan committed
[SPARK-54760][SQL] DelegatingCatalogExtension as session catalog supports both V1 and V2 functions
### What changes were proposed in this pull request?

This PR fixes a bug: when the user uses a custom `DelegatingCatalogExtension` as the session catalog, Spark cannot properly load the V2 functions provided by that catalog. A typical use case is Iceberg's `SparkSessionCatalog`:

```
$ spark-sql \
  --conf spark.sql.catalog.spark_catalog=org.apache.iceberg.spark.SparkSessionCatalog \
  ...
```

```
spark-sql (default)> SELECT spark_catalog.system.iceberg_version();
[ROUTINE_NOT_FOUND] The routine `system`.`iceberg_version` cannot be found. Verify the spelling and correctness of the schema and catalog. If you did not qualify the name with a schema and catalog, verify the current_schema() output, or qualify the name with the correct schema and catalog. To tolerate the error on drop use DROP ... IF EXISTS. SQLSTATE: 42883; line 1 pos 7
```

### Why are the changes needed?

Fix the bug described above.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes a bug.

### How was this patch tested?

Added a new UT. Also manually tested with Iceberg:

```
spark-sql (default)> SELECT spark_catalog.system.iceberg_version();
1.10.0
Time taken: 1.715 seconds, Fetched 1 row(s)
```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53531 from pan3793/SPARK-54760.

Authored-by: Cheng Pan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent db36f74 commit 512099b

File tree

10 files changed: +197 −126 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala

Lines changed: 12 additions & 11 deletions
```diff
@@ -26,17 +26,16 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{
   CatalogManager,
-  CatalogV2Util,
-  FunctionCatalog,
-  Identifier,
   LookupCatalog
 }
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.connector.catalog.functions.{
   AggregateFunction => V2AggregateFunction,
-  ScalarFunction
+  ScalarFunction,
+  UnboundFunction
 }
 import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors}
+import org.apache.spark.sql.internal.connector.V1Function
 import org.apache.spark.sql.types._
 
 class FunctionResolution(
@@ -52,10 +51,14 @@ class FunctionResolution(
       resolveBuiltinOrTempFunction(u.nameParts, u.arguments, u).getOrElse {
         val CatalogAndIdentifier(catalog, ident) =
           relationResolution.expandIdentifier(u.nameParts)
-        if (CatalogV2Util.isSessionCatalog(catalog)) {
-          resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
-        } else {
-          resolveV2Function(catalog.asFunctionCatalog, ident, u.arguments, u)
+        catalog.asFunctionCatalog.loadFunction(ident) match {
+          case V1Function(_) =>
+            // this triggers the second time v1 function resolution but should be cheap
+            // (no RPC to external catalog), since the metadata has been already cached
+            // in FunctionRegistry during the above `catalog.loadFunction` call.
+            resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
+          case unboundV2Func =>
+            resolveV2Function(unboundV2Func, u.arguments, u)
         }
       }
     }
@@ -272,11 +275,9 @@
   }
 
   private def resolveV2Function(
-      catalog: FunctionCatalog,
-      ident: Identifier,
+      unbound: UnboundFunction,
       arguments: Seq[Expression],
       u: UnresolvedFunction): Expression = {
-    val unbound = catalog.loadFunction(ident)
     val inputType = StructType(arguments.zipWithIndex.map {
      case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
    })
```
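
To make the scenario concrete, here is a minimal sketch of the kind of `DelegatingCatalogExtension` whose V2 functions the new dispatch above can now resolve. This is illustration only, not part of the commit; the class `MySessionCatalog` and the function `system.my_version` are hypothetical stand-ins for something like Iceberg's `SparkSessionCatalog` and `system.iceberg_version`.

```scala
// Illustration only (not from this commit): a custom session catalog that exposes one
// extra V2 function and delegates everything else to the built-in session catalog.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{DelegatingCatalogExtension, Identifier}
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types.{DataType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

// A zero-argument V2 scalar function that returns a hard-coded version string.
object MyVersion extends UnboundFunction with ScalarFunction[UTF8String] {
  override def name(): String = "my_version"
  override def description(): String = "my_version: returns a hard-coded version string"
  override def bind(inputType: StructType): BoundFunction = this
  override def inputTypes(): Array[DataType] = Array.empty
  override def resultType(): DataType = StringType
  override def produceResult(input: InternalRow): UTF8String = UTF8String.fromString("1.0.0")
}

class MySessionCatalog extends DelegatingCatalogExtension {
  override def loadFunction(ident: Identifier): UnboundFunction = {
    if (ident.namespace().sameElements(Array("system")) && ident.name() == "my_version") {
      MyVersion // a V2 function served directly by the extension
    } else {
      super.loadFunction(ident) // V1 functions still come from the built-in session catalog
    }
  }
}
```

Before this change, registering such a catalog as `spark_catalog` and running `SELECT spark_catalog.system.my_version()` failed with `ROUTINE_NOT_FOUND`, because the analyzer unconditionally took the V1 resolution path for the session catalog; now it asks the catalog's `loadFunction` first and only falls back to V1 resolution when a `V1Function` is returned.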

sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out

Lines changed: 4 additions & 4 deletions
```diff
@@ -972,11 +972,11 @@ VALUES(IDENTIFIER('a.b.c.d')())
 -- !query analysis
 org.apache.spark.sql.AnalysisException
 {
-  "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
-  "sqlState" : "42601",
+  "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+  "sqlState" : "42K05",
   "messageParameters" : {
-    "identifier" : "`a`.`b`.`c`.`d`",
-    "limit" : "2"
+    "namespace" : "`a`.`b`.`c`",
+    "sessionCatalog" : "spark_catalog"
   },
   "queryContext" : [ {
     "objectType" : "",
```

sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out

Lines changed: 4 additions & 4 deletions
```diff
@@ -972,11 +972,11 @@ VALUES(IDENTIFIER('a.b.c.d')())
 -- !query analysis
 org.apache.spark.sql.AnalysisException
 {
-  "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
-  "sqlState" : "42601",
+  "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+  "sqlState" : "42K05",
   "messageParameters" : {
-    "identifier" : "`a`.`b`.`c`.`d`",
-    "limit" : "2"
+    "namespace" : "`a`.`b`.`c`",
+    "sessionCatalog" : "spark_catalog"
   },
   "queryContext" : [ {
     "objectType" : "",
```

sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out

Lines changed: 4 additions & 4 deletions
```diff
@@ -1112,11 +1112,11 @@ struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
 {
-  "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
-  "sqlState" : "42601",
+  "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+  "sqlState" : "42K05",
   "messageParameters" : {
-    "identifier" : "`a`.`b`.`c`.`d`",
-    "limit" : "2"
+    "namespace" : "`a`.`b`.`c`",
+    "sessionCatalog" : "spark_catalog"
   },
   "queryContext" : [ {
     "objectType" : "",
```

sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out

Lines changed: 4 additions & 4 deletions
```diff
@@ -1112,11 +1112,11 @@ struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
 {
-  "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
-  "sqlState" : "42601",
+  "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+  "sqlState" : "42K05",
   "messageParameters" : {
-    "identifier" : "`a`.`b`.`c`.`d`",
-    "limit" : "2"
+    "namespace" : "`a`.`b`.`c`",
+    "sessionCatalog" : "spark_catalog"
   },
   "queryContext" : [ {
     "objectType" : "",
```

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala

Lines changed: 7 additions & 1 deletion
```diff
@@ -168,6 +168,10 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio
     spark.sessionState.catalogManager.catalog(name)
   }
 
+  protected def sessionCatalog: Catalog = {
+    catalog(SESSION_CATALOG_NAME).asInstanceOf[Catalog]
+  }
+
   protected val v2Format: String = classOf[FakeV2ProviderWithCustomSchema].getName
 
   protected val catalogClassName: String = classOf[InMemoryTableSessionCatalog].getName
@@ -178,7 +182,9 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio
 
   override def afterEach(): Unit = {
     super.afterEach()
-    catalog(SESSION_CATALOG_NAME).asInstanceOf[Catalog].clearTables()
+    sessionCatalog.checkUsage()
+    sessionCatalog.clearTables()
+    sessionCatalog.clearFunctions()
     spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
   }
 
```
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala

Lines changed: 91 additions & 91 deletions
```diff
@@ -702,127 +702,127 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
     comparePlans(df1.queryExecution.optimizedPlan, df2.queryExecution.optimizedPlan)
     checkAnswer(df1, Row(3) :: Nil)
   }
+}
 
-  private case object StrLenDefault extends ScalarFunction[Int] {
-    override def inputTypes(): Array[DataType] = Array(StringType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "strlen_default"
+case object StrLenDefault extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "strlen_default"
 
-    override def produceResult(input: InternalRow): Int = {
-      val s = input.getString(0)
-      s.length
-    }
+  override def produceResult(input: InternalRow): Int = {
+    val s = input.getString(0)
+    s.length
   }
+}
 
-  case object StrLenMagic extends ScalarFunction[Int] {
-    override def inputTypes(): Array[DataType] = Array(StringType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "strlen_magic"
+case object StrLenMagic extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "strlen_magic"
 
-    def invoke(input: UTF8String): Int = {
-      input.toString.length
-    }
+  def invoke(input: UTF8String): Int = {
+    input.toString.length
   }
+}
 
-  case object StrLenBadMagic extends ScalarFunction[Int] {
-    override def inputTypes(): Array[DataType] = Array(StringType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "strlen_bad_magic"
+case object StrLenBadMagic extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "strlen_bad_magic"
 
-    def invoke(input: String): Int = {
-      input.length
-    }
+  def invoke(input: String): Int = {
+    input.length
   }
+}
 
-  case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
-    override def inputTypes(): Array[DataType] = Array(StringType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "strlen_bad_magic"
-
-    def invoke(input: String): Int = {
-      input.length
-    }
+case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "strlen_bad_magic"
 
-    override def produceResult(input: InternalRow): Int = {
-      val s = input.getString(0)
-      s.length
-    }
+  def invoke(input: String): Int = {
+    input.length
   }
 
-  private case object StrLenNoImpl extends ScalarFunction[Int] {
-    override def inputTypes(): Array[DataType] = Array(StringType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "strlen_noimpl"
+  override def produceResult(input: InternalRow): Int = {
+    val s = input.getString(0)
+    s.length
   }
+}
 
-  // input type doesn't match arguments accepted by `UnboundFunction.bind`
-  private case object StrLenBadInputTypes extends ScalarFunction[Int] {
-    override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "strlen_bad_input_types"
-  }
+case object StrLenNoImpl extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "strlen_noimpl"
+}
 
-  private case object BadBoundFunction extends BoundFunction {
-    override def inputTypes(): Array[DataType] = Array(StringType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "bad_bound_func"
-  }
+// input type doesn't match arguments accepted by `UnboundFunction.bind`
+case object StrLenBadInputTypes extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "strlen_bad_input_types"
+}
 
-  object UnboundDecimalAverage extends UnboundFunction {
-    override def name(): String = "decimal_avg"
+case object BadBoundFunction extends BoundFunction {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "bad_bound_func"
+}
 
-    override def bind(inputType: StructType): BoundFunction = {
-      if (inputType.fields.length > 1) {
-        throw new UnsupportedOperationException("Too many arguments")
-      }
+object UnboundDecimalAverage extends UnboundFunction {
+  override def name(): String = "decimal_avg"
 
-      // put interval type here for testing purpose
-      inputType.fields(0).dataType match {
-        case _: NumericType | _: DayTimeIntervalType => DecimalAverage
-        case dataType =>
-          throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
-      }
+  override def bind(inputType: StructType): BoundFunction = {
+    if (inputType.fields.length > 1) {
+      throw new UnsupportedOperationException("Too many arguments")
     }
 
-    override def description(): String =
-      "decimal_avg: produces an average using decimal division"
+    // put interval type here for testing purpose
+    inputType.fields(0).dataType match {
+      case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+      case dataType =>
+        throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
+    }
   }
 
-  object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
-    override def name(): String = "decimal_avg"
-    override def inputTypes(): Array[DataType] = Array(DecimalType.SYSTEM_DEFAULT)
-    override def resultType(): DataType = DecimalType.SYSTEM_DEFAULT
+  override def description(): String =
+    "decimal_avg: produces an average using decimal division"
+}
 
-    override def newAggregationState(): (Decimal, Int) = (Decimal.ZERO, 0)
+object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
+  override def name(): String = "decimal_avg"
+  override def inputTypes(): Array[DataType] = Array(DecimalType.SYSTEM_DEFAULT)
+  override def resultType(): DataType = DecimalType.SYSTEM_DEFAULT
 
-    override def update(state: (Decimal, Int), input: InternalRow): (Decimal, Int) = {
-      if (input.isNullAt(0)) {
-        state
-      } else {
-        val l = input.getDecimal(0, DecimalType.SYSTEM_DEFAULT.precision,
-          DecimalType.SYSTEM_DEFAULT.scale)
-        state match {
-          case (_, d) if d == 0 =>
-            (l, 1)
-          case (total, count) =>
-            (total + l, count + 1)
-        }
-      }
-    }
+  override def newAggregationState(): (Decimal, Int) = (Decimal.ZERO, 0)
 
-    override def merge(leftState: (Decimal, Int), rightState: (Decimal, Int)): (Decimal, Int) = {
-      (leftState._1 + rightState._1, leftState._2 + rightState._2)
+  override def update(state: (Decimal, Int), input: InternalRow): (Decimal, Int) = {
+    if (input.isNullAt(0)) {
+      state
+    } else {
+      val l = input.getDecimal(0, DecimalType.SYSTEM_DEFAULT.precision,
+        DecimalType.SYSTEM_DEFAULT.scale)
+      state match {
+        case (_, d) if d == 0 =>
+          (l, 1)
+        case (total, count) =>
+          (total + l, count + 1)
+      }
     }
+  }
 
-    override def produceResult(state: (Decimal, Int)): Decimal = state._1 / Decimal(state._2)
+  override def merge(leftState: (Decimal, Int), rightState: (Decimal, Int)): (Decimal, Int) = {
+    (leftState._1 + rightState._1, leftState._2 + rightState._2)
   }
 
-  object NoImplAverage extends UnboundFunction {
-    override def name(): String = "no_impl_avg"
-    override def description(): String = name()
+  override def produceResult(state: (Decimal, Int)): Decimal = state._1 / Decimal(state._2)
+}
+
+object NoImplAverage extends UnboundFunction {
+  override def name(): String = "no_impl_avg"
+  override def description(): String = name()
 
-    override def bind(inputType: StructType): BoundFunction = {
-      throw SparkUnsupportedOperationException()
-    }
+  override def bind(inputType: StructType): BoundFunction = {
+    throw SparkUnsupportedOperationException()
   }
 }
```

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala

Lines changed: 8 additions & 1 deletion
```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.connector
 
-import org.apache.spark.sql.{DataFrame, SaveMode}
+import org.apache.spark.sql.{DataFrame, Row, SaveMode}
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, Table, TableCatalog}
 
 class DataSourceV2SQLSessionCatalogSuite
@@ -79,4 +79,11 @@ class DataSourceV2SQLSessionCatalogSuite
       assert(getTableMetadata("default.t").columns().map(_.name()) === Seq("c2", "c1"))
     }
   }
+
+  test("SPARK-54760: DelegatingCatalogExtension supports both V1 and V2 functions") {
+    sessionCatalog.createFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
+    checkAnswer(
+      sql("SELECT char_length('Hello') as v1, ns.strlen('Spark') as v2"),
+      Row(5, 5))
+  }
 }
```
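
The new test above registers a V2 function directly in the in-memory test session catalog and checks that a built-in V1 function and the catalog-provided V2 function resolve in the same query. Outside the test harness, the same behavior can be exercised end to end roughly as sketched below; this assumes the hypothetical `MySessionCatalog` from the earlier sketch is on the classpath and is not part of this commit.

```scala
// Sketch only: end-to-end check of the fixed resolution path, assuming the
// hypothetical MySessionCatalog (see the earlier sketch) is on the classpath.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.catalog.spark_catalog", "MySessionCatalog")
  .getOrCreate()

// Built-in / V1 functions keep resolving through the V1 path...
spark.sql("SELECT char_length('Hello')").show()
// ...while V2 functions exposed by the session catalog extension now resolve as well.
spark.sql("SELECT spark_catalog.system.my_version()").show()
```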
