Skip to content

Commit 52ab9ba

Browse files
ilicmarkodbcloud-fan
authored andcommitted
[SPARK-52976][PYTHON] Fix Python UDF not accepting collated string as input param/return type
### What changes were proposed in this pull request? Fix Python UDF not accepting collated strings as input param/return type. ### Why are the changes needed? Bug fix. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #51925 from ilicmarkodb/python_udf_collation_fix. Authored-by: ilicmarkodb <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent aed5516 commit 52ab9ba

File tree

8 files changed

+104
-22
lines changed

8 files changed

+104
-22
lines changed

python/pyspark/sql/tests/test_udf.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,6 +1337,29 @@ def doubleInDoubleOut(d):
13371337
with self.subTest(chain=chain):
13381338
assertDataFrameEqual(actual=actual, expected=expected)
13391339

1340+
def test_udf_with_collated_string_types(self):
1341+
@udf("string collate fr")
1342+
def my_udf(input_val):
1343+
return "%s - %s" % (type(input_val), input_val)
1344+
1345+
string_types = [
1346+
StringType(),
1347+
StringType("UTF8_BINARY"),
1348+
StringType("UTF8_LCASE"),
1349+
StringType("UNICODE"),
1350+
]
1351+
data = [("hello",)]
1352+
expected = "<class 'str'> - hello"
1353+
1354+
for string_type in string_types:
1355+
schema = StructType([StructField("input_col", string_type, True)])
1356+
df = self.spark.createDataFrame(data, schema=schema)
1357+
df_result = df.select(my_udf(df.input_col).alias("result"))
1358+
row = df_result.collect()[0][0]
1359+
self.assertEqual(row, expected)
1360+
result_type = df_result.schema["result"].dataType
1361+
self.assertEqual(result_type, StringType("fr"))
1362+
13401363

13411364
class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
13421365
@classmethod

python/pyspark/sql/tests/test_udtf.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3100,6 +3100,41 @@ def eval(self):
31003100
udtf(TestUDTF, returnType=ret_type)().collect()
31013101

31023102

3103+
def test_udtf_with_collated_string_types(self):
3104+
@udtf(
3105+
"out1 string, out2 string collate UTF8_BINARY, out3 string collate UTF8_LCASE,"
3106+
" out4 string collate UNICODE"
3107+
)
3108+
class MyUDTF:
3109+
def eval(self, v1, v2, v3, v4):
3110+
yield (v1 + "1", v2 + "2", v3 + "3", v4 + "4")
3111+
3112+
schema = StructType(
3113+
[
3114+
StructField("col1", StringType(), True),
3115+
StructField("col2", StringType("UTF8_BINARY"), True),
3116+
StructField("col3", StringType("UTF8_LCASE"), True),
3117+
StructField("col4", StringType("UNICODE"), True),
3118+
]
3119+
)
3120+
df = self.spark.createDataFrame([("hello",) * 4], schema=schema)
3121+
3122+
df_out = df.select(MyUDTF(df.col1, df.col2, df.col3, df.col4).alias("out"))
3123+
result_df = df_out.select("out.*")
3124+
3125+
expected_row = ("hello1", "hello2", "hello3", "hello4")
3126+
self.assertEqual(result_df.collect()[0], expected_row)
3127+
3128+
expected_output_types = [
3129+
StringType(),
3130+
StringType("UTF8_BINARY"),
3131+
StringType("UTF8_LCASE"),
3132+
StringType("UNICODE"),
3133+
]
3134+
for idx, field in enumerate(result_df.schema.fields):
3135+
self.assertEqual(field.dataType, expected_output_types[idx])
3136+
3137+
31033138
class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
31043139
@classmethod
31053140
def setUpClass(cls):

sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -447,15 +447,35 @@ object DataType {
447447
}
448448

449449
/**
450-
* Check if `from` is equal to `to` type except for collations, which are checked to be
451-
* compatible so that data of type `from` can be interpreted as of type `to`.
450+
* Compares two data types, ignoring compatible collation of StringType. If `checkComplexTypes`
451+
* is true, it will also ignore collations for nested types.
452452
*/
453-
private[sql] def equalsIgnoreCompatibleCollation(from: DataType, to: DataType): Boolean = {
454-
(from, to) match {
455-
// String types with possibly different collations are compatible.
456-
case (a: StringType, b: StringType) => a.constraint == b.constraint
453+
private[sql] def equalsIgnoreCompatibleCollation(
454+
from: DataType,
455+
to: DataType,
456+
checkComplexTypes: Boolean = true): Boolean = {
457+
def transform: PartialFunction[DataType, DataType] = {
458+
case dt @ (_: CharType | _: VarcharType) => dt
459+
case _: StringType => StringType
460+
}
457461

458-
case (fromDataType, toDataType) => fromDataType == toDataType
462+
if (checkComplexTypes) {
463+
from.transformRecursively(transform) == to.transformRecursively(transform)
464+
} else {
465+
(from, to) match {
466+
case (a: StringType, b: StringType) => a.constraint == b.constraint
467+
468+
case (fromDataType, toDataType) => fromDataType == toDataType
469+
}
470+
}
471+
}
472+
473+
private[sql] def equalsIgnoreCompatibleCollation(
474+
from: Seq[DataType],
475+
to: Seq[DataType]): Boolean = {
476+
from.length == to.length &&
477+
from.zip(to).forall { case (fromDataType, toDataType) =>
478+
equalsIgnoreCompatibleCollation(fromDataType, toDataType)
459479
}
460480
}
461481

sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,7 @@ class DataTypeSuite extends SparkFunSuite {
878878
checkEqualsIgnoreCompatibleCollation(
879879
ArrayType(StringType),
880880
ArrayType(StringType("UTF8_LCASE")),
881-
expected = false
881+
expected = true
882882
)
883883
checkEqualsIgnoreCompatibleCollation(
884884
ArrayType(StringType),
@@ -888,7 +888,7 @@ class DataTypeSuite extends SparkFunSuite {
888888
checkEqualsIgnoreCompatibleCollation(
889889
ArrayType(ArrayType(StringType)),
890890
ArrayType(ArrayType(StringType("UTF8_LCASE"))),
891-
expected = false
891+
expected = true
892892
)
893893
checkEqualsIgnoreCompatibleCollation(
894894
ArrayType(ArrayType(StringType)),
@@ -913,12 +913,12 @@ class DataTypeSuite extends SparkFunSuite {
913913
checkEqualsIgnoreCompatibleCollation(
914914
MapType(StringType, StringType),
915915
MapType(StringType, StringType("UTF8_LCASE")),
916-
expected = false
916+
expected = true
917917
)
918918
checkEqualsIgnoreCompatibleCollation(
919919
MapType(StringType("UTF8_LCASE"), StringType),
920920
MapType(StringType, StringType),
921-
expected = false
921+
expected = true
922922
)
923923
checkEqualsIgnoreCompatibleCollation(
924924
MapType(StringType("UTF8_LCASE"), StringType),
@@ -943,7 +943,7 @@ class DataTypeSuite extends SparkFunSuite {
943943
checkEqualsIgnoreCompatibleCollation(
944944
MapType(StringType("UTF8_LCASE"), ArrayType(StringType)),
945945
MapType(StringType("UTF8_LCASE"), ArrayType(StringType("UTF8_LCASE"))),
946-
expected = false
946+
expected = true
947947
)
948948
checkEqualsIgnoreCompatibleCollation(
949949
MapType(StringType("UTF8_LCASE"), ArrayType(StringType)),
@@ -968,7 +968,7 @@ class DataTypeSuite extends SparkFunSuite {
968968
checkEqualsIgnoreCompatibleCollation(
969969
MapType(ArrayType(StringType), IntegerType),
970970
MapType(ArrayType(StringType("UTF8_LCASE")), IntegerType),
971-
expected = false
971+
expected = true
972972
)
973973
checkEqualsIgnoreCompatibleCollation(
974974
MapType(ArrayType(StringType("UTF8_LCASE")), IntegerType),
@@ -998,7 +998,7 @@ class DataTypeSuite extends SparkFunSuite {
998998
checkEqualsIgnoreCompatibleCollation(
999999
StructType(StructField("a", StringType) :: Nil),
10001000
StructType(StructField("a", StringType("UTF8_LCASE")) :: Nil),
1001-
expected = false
1001+
expected = true
10021002
)
10031003
checkEqualsIgnoreCompatibleCollation(
10041004
StructType(StructField("a", StringType) :: Nil),
@@ -1023,7 +1023,7 @@ class DataTypeSuite extends SparkFunSuite {
10231023
checkEqualsIgnoreCompatibleCollation(
10241024
StructType(StructField("a", ArrayType(StringType)) :: Nil),
10251025
StructType(StructField("a", ArrayType(StringType("UTF8_LCASE"))) :: Nil),
1026-
expected = false
1026+
expected = true
10271027
)
10281028
checkEqualsIgnoreCompatibleCollation(
10291029
StructType(StructField("a", ArrayType(StringType)) :: Nil),
@@ -1048,7 +1048,7 @@ class DataTypeSuite extends SparkFunSuite {
10481048
checkEqualsIgnoreCompatibleCollation(
10491049
StructType(StructField("a", MapType(StringType, IntegerType)) :: Nil),
10501050
StructType(StructField("a", MapType(StringType("UTF8_LCASE"), IntegerType)) :: Nil),
1051-
expected = false
1051+
expected = true
10521052
)
10531053
checkEqualsIgnoreCompatibleCollation(
10541054
StructType(StructField("a", MapType(StringType, IntegerType)) :: Nil),

sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ case class AlterTableChangeColumnCommand(
463463
// when altering column. Only changes in collation of data type or its nested types (recursively)
464464
// are allowed.
465465
private def canEvolveType(from: StructField, to: StructField): Boolean = {
466-
DataType.equalsIgnoreCompatibleCollation(from.dataType, to.dataType)
466+
DataType.equalsIgnoreCompatibleCollation(from.dataType, to.dataType, checkComplexTypes = false)
467467
}
468468
}
469469

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._
2626
import org.apache.spark.sql.execution.SparkPlan
2727
import org.apache.spark.sql.execution.metric.SQLMetric
2828
import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
29+
import org.apache.spark.sql.types.DataType.equalsIgnoreCompatibleCollation
2930
import org.apache.spark.sql.types.StructType
3031

3132
/**
@@ -125,8 +126,9 @@ class ArrowEvalPythonEvaluatorFactory(
125126

126127
columnarBatchIter.flatMap { batch =>
127128
val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
128-
assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " +
129-
s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
129+
assert(equalsIgnoreCompatibleCollation(outputTypes, actualDataTypes),
130+
s"Invalid schema from arrow-enabled Python UDTF: expected ${outputTypes.mkString(", ")}," +
131+
s" got ${actualDataTypes.mkString(", ")}")
130132
batch.rowIterator.asScala
131133
}
132134
}

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
2424
import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.execution.SparkPlan
2626
import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
27+
import org.apache.spark.sql.types.DataType.equalsIgnoreCompatibleCollation
2728
import org.apache.spark.sql.types.StructType
2829
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
2930

@@ -81,8 +82,9 @@ case class ArrowEvalPythonUDTFExec(
8182

8283
val actualDataTypes = (0 until flattenedBatch.numCols()).map(
8384
i => flattenedBatch.column(i).dataType())
84-
assert(outputTypes == actualDataTypes, "Invalid schema from arrow-enabled Python UDTF: " +
85-
s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
85+
assert(equalsIgnoreCompatibleCollation(outputTypes, actualDataTypes),
86+
s"Invalid schema from arrow-enabled Python UDTF: expected ${outputTypes.mkString(", ")}," +
87+
s" got ${actualDataTypes.mkString(", ")}")
8688

8789
flattenedBatch.setNumRows(batch.numRows())
8890
flattenedBatch.rowIterator().asScala

sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ object EvaluatePython {
7878

7979
case (d: Decimal, _) => d.toJavaBigDecimal
8080

81-
case (s: UTF8String, StringType) => s.toString
81+
case (s: UTF8String, _: StringType) => s.toString
8282

8383
case (other, _) => other
8484
}

0 commit comments

Comments
 (0)