Commit 3e69b40

[SPARK-49683][SQL] Block trim collation
### What changes were proposed in this pull request?

Trim collation is currently in the implementation phase. This change blocks all paths from using it; afterwards, as trim collation support is implemented for individual expressions, it will be gradually whitelisted.

### Why are the changes needed?

Trim collation is currently in the implementation phase. This change blocks all paths from using it; afterwards, as trim collation support is implemented for individual expressions, it will be gradually whitelisted.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

No additional tests; this just adds a field that is not used yet.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48336 from jovanpavl-db/block-collation-trim.

Lead-authored-by: Jovan Pavlovic <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent d8c04cf commit 3e69b40

File tree: 6 files changed (+89, -37 lines)
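
As a rough illustration of the user-visible effect (a sketch, not part of this commit): once trim collation is blocked, any expression whose accepted string type keeps the default `supportsTrimCollation = false` rejects RTRIM-collated input at analysis time, which is what the updated CollationSQLExpressionsSuite below asserts for `str_to_map`. The few expressions touched here (CollationKey, Collate, Collation, HllSketchAgg) pass `supportsTrimCollation = true` explicitly, so they keep accepting every collation.

```scala
// Hypothetical spark-shell session on a build with collation support; the suite below
// triggers the same rejection by setting the session default collation to an RTRIM variant.
spark.sql("SELECT str_to_map('a:1,b:2,c:3' COLLATE UTF8_BINARY_RTRIM, ',', ':')").show()
// org.apache.spark.sql.AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE]
//   paramIndex: first, inputType: "STRING COLLATE UTF8_BINARY_RTRIM", requiredType: "STRING"
//   (SQLSTATE 42K09) -- error parameters abbreviated from the suite's expectations below
```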

sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala

Lines changed: 46 additions & 10 deletions

@@ -21,43 +21,79 @@ import org.apache.spark.sql.internal.SqlApiConf
 import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
 
 /**
- * AbstractStringType is an abstract class for StringType with collation support.
+ * AbstractStringType is an abstract class for StringType with collation support. As every type of
+ * collation can support trim specifier this class is parametrized with it.
  */
-abstract class AbstractStringType extends AbstractDataType {
+abstract class AbstractStringType(private[sql] val supportsTrimCollation: Boolean = false)
+    extends AbstractDataType {
   override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType
   override private[sql] def simpleString: String = "string"
+  private[sql] def canUseTrimCollation(other: DataType): Boolean =
+    supportsTrimCollation || !other.asInstanceOf[StringType].usesTrimCollation
 }
 
 /**
  * Use StringTypeBinary for expressions supporting only binary collation.
  */
-case object StringTypeBinary extends AbstractStringType {
+case class StringTypeBinary(override val supportsTrimCollation: Boolean = false)
+    extends AbstractStringType(supportsTrimCollation) {
   override private[sql] def acceptsType(other: DataType): Boolean =
-    other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality
+    other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality &&
+      canUseTrimCollation(other)
+}
+
+object StringTypeBinary extends StringTypeBinary(false) {
+  def apply(supportsTrimCollation: Boolean): StringTypeBinary = {
+    new StringTypeBinary(supportsTrimCollation)
+  }
 }
 
 /**
  * Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation.
  */
-case object StringTypeBinaryLcase extends AbstractStringType {
+case class StringTypeBinaryLcase(override val supportsTrimCollation: Boolean = false)
+    extends AbstractStringType(supportsTrimCollation) {
   override private[sql] def acceptsType(other: DataType): Boolean =
     other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].supportsBinaryEquality ||
-      other.asInstanceOf[StringType].isUTF8LcaseCollation)
+      other.asInstanceOf[StringType].isUTF8LcaseCollation) && canUseTrimCollation(other)
+}
+
+object StringTypeBinaryLcase extends StringTypeBinaryLcase(false) {
+  def apply(supportsTrimCollation: Boolean): StringTypeBinaryLcase = {
+    new StringTypeBinaryLcase(supportsTrimCollation)
+  }
 }
 
 /**
  * Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary
  * and ICU) but limited to using case and accent sensitivity specifiers.
  */
-case object StringTypeWithCaseAccentSensitivity extends AbstractStringType {
-  override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType]
+case class StringTypeWithCaseAccentSensitivity(
+    override val supportsTrimCollation: Boolean = false)
+    extends AbstractStringType(supportsTrimCollation) {
+  override private[sql] def acceptsType(other: DataType): Boolean =
+    other.isInstanceOf[StringType] && canUseTrimCollation(other)
+}
+
+object StringTypeWithCaseAccentSensitivity extends StringTypeWithCaseAccentSensitivity(false) {
+  def apply(supportsTrimCollation: Boolean): StringTypeWithCaseAccentSensitivity = {
+    new StringTypeWithCaseAccentSensitivity(supportsTrimCollation)
+  }
 }
 
 /**
  * Use StringTypeNonCSAICollation for expressions supporting all possible collation types except
 * CS_AI collation types.
  */
-case object StringTypeNonCSAICollation extends AbstractStringType {
+case class StringTypeNonCSAICollation(override val supportsTrimCollation: Boolean = false)
+    extends AbstractStringType(supportsTrimCollation) {
   override private[sql] def acceptsType(other: DataType): Boolean =
-    other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI
+    other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI &&
+      canUseTrimCollation(other)
+}
+
+object StringTypeNonCSAICollation extends StringTypeNonCSAICollation(false) {
+  def apply(supportsTrimCollation: Boolean): StringTypeNonCSAICollation = {
+    new StringTypeNonCSAICollation(supportsTrimCollation)
+  }
 }
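
To make the gating above concrete, here is a minimal self-contained sketch (simplified stand-ins, not the real Spark types): a type built with the default `supportsTrimCollation = false` rejects a trim-collated string, while an explicitly whitelisted one accepts it, and non-trim collations are unaffected either way.

```scala
// Simplified stand-ins mirroring AbstractStringType.canUseTrimCollation from the diff above.
object TrimGateSketch {
  // Stand-in for StringType: usesTrimCollation marks an *_RTRIM collation.
  final case class StrType(usesTrimCollation: Boolean)

  abstract class AbstractStrType(val supportsTrimCollation: Boolean = false) {
    // Same shape as canUseTrimCollation in the diff: trim is allowed only if opted in.
    def canUseTrimCollation(other: StrType): Boolean =
      supportsTrimCollation || !other.usesTrimCollation
    def acceptsType(other: StrType): Boolean
  }

  // Mirrors StringTypeWithCaseAccentSensitivity: accepts any string type, gated on trim.
  case class AnyCollationType(override val supportsTrimCollation: Boolean = false)
      extends AbstractStrType(supportsTrimCollation) {
    override def acceptsType(other: StrType): Boolean = canUseTrimCollation(other)
  }

  def main(args: Array[String]): Unit = {
    val rtrim = StrType(usesTrimCollation = true)
    val plain = StrType(usesTrimCollation = false)
    println(AnyCollationType().acceptsType(rtrim))     // false: trim collation blocked by default
    println(AnyCollationType(true).acceptsType(rtrim)) // true: expression explicitly whitelisted
    println(AnyCollationType().acceptsType(plain))     // true: non-trim collations are unaffected
  }
}
```

Turning the case objects into case classes while keeping a companion object with the old parameterless shape (plus an `apply(supportsTrimCollation: Boolean)`) is what lets every existing call site compile unchanged while new call sites opt in.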

sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala

Lines changed: 3 additions & 0 deletions

@@ -47,6 +47,9 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ
   private[sql] def isNonCSAI: Boolean =
     !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId)
 
+  private[sql] def usesTrimCollation: Boolean =
+    CollationFactory.usesTrimCollation(collationId)
+
   private[sql] def isUTF8BinaryCollation: Boolean =
     collationId == CollationFactory.UTF8_BINARY_COLLATION_ID
 

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala

Lines changed: 2 additions & 1 deletion

@@ -24,7 +24,8 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
   override def dataType: DataType = BinaryType
 
   final lazy val collationId: Int = expr.dataType match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala

Lines changed: 5 additions & 1 deletion

@@ -106,7 +106,11 @@ case class HllSketchAgg(
 
   override def inputTypes: Seq[AbstractDataType] =
     Seq(
-      TypeCollection(IntegerType, LongType, StringTypeWithCaseAccentSensitivity, BinaryType),
+      TypeCollection(
+        IntegerType,
+        LongType,
+        StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true),
+        BinaryType),
       IntegerType)
 
   override def dataType: DataType = BinaryType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala

Lines changed: 4 additions & 2 deletions

@@ -77,7 +77,8 @@ case class Collate(child: Expression, collationName: String)
   extends UnaryExpression with ExpectsInputTypes {
   private val collationId = CollationFactory.collationNameToId(collationName)
   override def dataType: DataType = StringType(collationId)
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
 
   override protected def withNewChildInternal(
     newChild: Expression): Expression = copy(newChild)
@@ -115,5 +116,6 @@ case class Collation(child: Expression)
     val collationName = CollationFactory.fetchCollation(collationId).collationName
     Literal.create(collationName, SQLConf.get.defaultStringType)
   }
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
 }

sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala

Lines changed: 29 additions & 23 deletions

@@ -982,7 +982,11 @@ class CollationSQLExpressionsSuite
       StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI",
         Map("1" -> "A", "2" -> "B", "3" -> "C"))
     )
-    val unsupportedTestCase = StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null)
+    val unsupportedTestCases = Seq(
+      StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null),
+      StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_RTRIM", null),
+      StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_BINARY_RTRIM", null),
+      StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_LCASE_RTRIM", null))
     testCases.foreach(t => {
       // Unit test.
      val text = Literal.create(t.text, StringType(t.collation))
@@ -998,28 +1002,30 @@ class CollationSQLExpressionsSuite
       }
     })
     // Test unsupported collation.
-    withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) {
-      val query =
-        s"select str_to_map('${unsupportedTestCase.text}', '${unsupportedTestCase.pairDelim}', " +
-          s"'${unsupportedTestCase.keyValueDelim}')"
-      checkError(
-        exception = intercept[AnalysisException] {
-          sql(query).collect()
-        },
-        condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
-        sqlState = Some("42K09"),
-        parameters = Map(
-          "sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate UNICODE_AI, " +
-            "'?' collate UNICODE_AI, '?' collate UNICODE_AI)\""),
-          "paramIndex" -> "first",
-          "inputSql" -> "\"'a:1,b:2,c:3' collate UNICODE_AI\"",
-          "inputType" -> "\"STRING COLLATE UNICODE_AI\"",
-          "requiredType" -> "\"STRING\""),
-        context = ExpectedContext(
-          fragment = "str_to_map('a:1,b:2,c:3', '?', '?')",
-          start = 7,
-          stop = 41))
-    }
+    unsupportedTestCases.foreach(t => {
+      withSQLConf(SQLConf.DEFAULT_COLLATION.key -> t.collation) {
+        val query =
+          s"select str_to_map('${t.text}', '${t.pairDelim}', " +
+            s"'${t.keyValueDelim}')"
+        checkError(
+          exception = intercept[AnalysisException] {
+            sql(query).collect()
+          },
+          condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+          sqlState = Some("42K09"),
+          parameters = Map(
+            "sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate " + s"${t.collation}, " +
+              "'?' collate " + s"${t.collation}, '?' collate ${t.collation})" + "\""),
+            "paramIndex" -> "first",
+            "inputSql" -> ("\"'a:1,b:2,c:3' collate " + s"${t.collation}" + "\""),
+            "inputType" -> ("\"STRING COLLATE " + s"${t.collation}" + "\""),
+            "requiredType" -> "\"STRING\""),
+          context = ExpectedContext(
+            fragment = "str_to_map('a:1,b:2,c:3', '?', '?')",
+            start = 7,
+            stop = 41))
+      }
+    })
   }
 
   test("Support RaiseError misc expression with collation") {
