Commit 5a5bf04
[SPARK-51874][CORE][SQL] Add TypedConfigBuilder for Scala Enumeration
### What changes were proposed in this pull request?

This PR introduces TypedConfigBuilder for Scala Enumeration and leverages it for existing configurations that take an Enumeration as a parameter. Before this PR, each such configuration had to convert between Enumeration and String in both directions, apply an upper-case transformation, and run `.checkValues` validation one entry at a time. After this PR, those steps are centralized in the builder.

### Why are the changes needed?

Better support for Enumeration-like configurations.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50674 from yaooqinn/enum.

Authored-by: Kent Yao <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
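For illustration, here is a minimal sketch of the boilerplate this centralizes. The `Mode` enumeration and the `spark.test.mode` key are hypothetical; only `enumConf` itself comes from this commit, and `ConfigBuilder` is `private[spark]`, so this is a sketch rather than user-facing API:

```scala
import java.util.Locale

import org.apache.spark.internal.config.ConfigBuilder

object Mode extends Enumeration {
  val FAST, SAFE = Value
}

// Before: store a string; normalize case, validate, and convert by hand.
val oldConf = ConfigBuilder("spark.test.mode")
  .stringConf
  .transform(_.toUpperCase(Locale.ROOT))
  .checkValues(Mode.values.map(_.toString))
  .createWithDefault(Mode.SAFE.toString)
// ...and every read site repeats the conversion:
//   Mode.withName(conf.get(oldConf))

// After: casing, validation, and conversion live in the builder.
val newConf = ConfigBuilder("spark.test.mode")
  .enumConf(Mode)
  .createWithDefault(Mode.SAFE)
// ...and read sites get a Mode.Value directly:
//   conf.get(newConf)
```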
1 parent 516859f commit 5a5bf04

22 files changed (+127, −117 lines)

common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala

Lines changed: 16 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.internal.config
 
+import java.util.Locale
 import java.util.concurrent.TimeUnit
 import java.util.regex.PatternSyntaxException
 
@@ -46,6 +47,16 @@ private object ConfigHelpers {
     }
   }
 
+  def toEnum[E <: Enumeration](s: String, enumClass: E, key: String): enumClass.Value = {
+    try {
+      enumClass.withName(s.trim.toUpperCase(Locale.ROOT))
+    } catch {
+      case _: NoSuchElementException =>
+        throw new IllegalArgumentException(
+          s"$key should be one of ${enumClass.values.mkString(", ")}, but was $s")
+    }
+  }
+
   def stringToSeq[T](str: String, converter: String => T): Seq[T] = {
     SparkStringUtils.stringToSeq(str).map(converter)
   }
@@ -271,6 +282,11 @@ private[spark] case class ConfigBuilder(key: String) {
     new TypedConfigBuilder(this, v => v)
   }
 
+  def enumConf(e: Enumeration): TypedConfigBuilder[e.Value] = {
+    checkPrependConfig
+    new TypedConfigBuilder(this, toEnum(_, e, key))
+  }
+
   def timeConf(unit: TimeUnit): TypedConfigBuilder[Long] = {
     checkPrependConfig
     new TypedConfigBuilder(this, timeFromString(_, unit), timeToString(_, unit))
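In the spirit of the new test below, a quick sketch of the normalization `toEnum` applies. The `Policy` enumeration and the `spark.test.policy` key are hypothetical, and `conf.get`/`conf.set` on config entries are `private[spark]`, so this only runs inside Spark's own packages, as in the test suite:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.internal.config.ConfigBuilder

object Policy extends Enumeration {
  val EXCEPTION, CORRECTED = Value
}

val policyConf = ConfigBuilder("spark.test.policy")
  .enumConf(Policy)
  .createWithDefault(Policy.EXCEPTION)

val conf = new SparkConf()
// toEnum trims and upper-cases before the withName lookup,
// so lower-case and padded values still resolve:
conf.set(policyConf.key, " corrected ")
assert(conf.get(policyConf) == Policy.CORRECTED)
// Unknown names fail with the centralized message:
//   "spark.test.policy should be one of EXCEPTION, CORRECTED, but was bogus"
```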

core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala

Lines changed: 21 additions & 0 deletions
@@ -387,4 +387,25 @@ class ConfigEntrySuite extends SparkFunSuite {
     ConfigBuilder(testKey("oc5")).onCreate(_ => onCreateCalled = true).fallbackConf(fallback)
     assert(onCreateCalled)
   }
+
+
+  test("SPARK-51874: Add Enum support to ConfigBuilder") {
+    object MyTestEnum extends Enumeration {
+      val X, Y, Z = Value
+    }
+    val conf = new SparkConf()
+    val enumConf = ConfigBuilder("spark.test.enum.key")
+      .enumConf(MyTestEnum)
+      .createWithDefault(MyTestEnum.X)
+    assert(conf.get(enumConf) === MyTestEnum.X)
+    conf.set(enumConf, MyTestEnum.Y)
+    assert(conf.get(enumConf) === MyTestEnum.Y)
+    conf.set(enumConf.key, "Z")
+    assert(conf.get(enumConf) === MyTestEnum.Z)
+    val e = intercept[IllegalArgumentException] {
+      conf.set(enumConf.key, "A")
+      conf.get(enumConf)
+    }
+    assert(e.getMessage === s"${enumConf.key} should be one of X, Y, Z, but was A")
+  }
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
 
     val cteDefs = ArrayBuffer.empty[CTERelationDef]
     val (substituted, firstSubstituted) =
-      LegacyBehaviorPolicy.withName(conf.getConf(LEGACY_CTE_PRECEDENCE_POLICY)) match {
+      conf.getConf(LEGACY_CTE_PRECEDENCE_POLICY) match {
         case LegacyBehaviorPolicy.EXCEPTION =>
           assertNoNameConflictsInCTE(plan)
           traverseAndSubstituteCTE(plan, forceInline, Seq.empty, cteDefs, None)
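This direct `match` only type-checks because the entry itself is now typed. The SQLConf hunk is not among the files shown here, but given the new API the definition presumably changed along these lines (a sketch, not the verbatim diff):

```scala
// Before: a validated string entry, converted with withName at every read site.
val LEGACY_CTE_PRECEDENCE_POLICY = buildConf("spark.sql.legacy.ctePrecedencePolicy")
  .stringConf
  .transform(_.toUpperCase(Locale.ROOT))
  .checkValues(LegacyBehaviorPolicy.values.map(_.toString))
  .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString)

// After: a typed entry; conf.getConf returns a LegacyBehaviorPolicy.Value.
val LEGACY_CTE_PRECEDENCE_POLICY = buildConf("spark.sql.legacy.ctePrecedencePolicy")
  .enumConf(LegacyBehaviorPolicy)
  .createWithDefault(LegacyBehaviorPolicy.EXCEPTION)
```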

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala

Lines changed: 1 addition & 1 deletion
@@ -426,7 +426,7 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper {
       Some("hiveCaseSensitiveInferenceMode")
     } else if (conf.getConf(SQLConf.LEGACY_INLINE_CTE_IN_COMMANDS)) {
       Some("legacyInlineCTEInCommands")
-    } else if (LegacyBehaviorPolicy.withName(conf.getConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY)) !=
+    } else if (conf.getConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY) !=
         LegacyBehaviorPolicy.CORRECTED) {
       Some("legacyCTEPrecedencePolicy")
     } else {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] extends Logging {
 
   def createObject(in: IN): OUT = {
     // We are allowed to choose codegen-only or no-codegen modes if under tests.
-    val fallbackMode = CodegenObjectFactoryMode.withName(SQLConf.get.codegenFactoryMode)
+    val fallbackMode = SQLConf.get.codegenFactoryMode
 
     fallbackMode match {
       case CodegenObjectFactoryMode.CODEGEN_ONLY =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala

Lines changed: 1 addition & 1 deletion
@@ -457,7 +457,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression =>
 object ToStringBase {
   def getBinaryFormatter: BinaryFormatter = {
     val style = SQLConf.get.getConf(SQLConf.BINARY_OUTPUT_STYLE)
-    style.map(BinaryOutputStyle.withName) match {
+    style match {
       case Some(BinaryOutputStyle.UTF8) =>
         (array: Array[Byte]) => UTF8String.fromBytes(array)
       case Some(BinaryOutputStyle.BASIC) =>
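Side note: `style` here is an `Option`, so `SQLConf.BINARY_OUTPUT_STYLE` is presumably now built with `enumConf(BinaryOutputStyle).createOptional`, which is what removes the per-call-site `style.map(BinaryOutputStyle.withName)` step.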

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala

Lines changed: 2 additions & 2 deletions
@@ -84,9 +84,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
       keys.append(keyNormalized)
       values.append(value)
     } else {
-      if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.EXCEPTION.toString) {
+      if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.EXCEPTION) {
         throw QueryExecutionErrors.duplicateMapKeyFoundError(key)
-      } else if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
+      } else if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.LAST_WIN) {
         // Overwrite the previous value, as the policy is last wins.
         values(index) = value
       } else {
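Note: `mapKeyDedupPolicy` presumably reads `SQLConf.MAP_KEY_DEDUP_POLICY`, which after this change yields a `MapKeyDedupPolicy.Value` rather than a `String`, so the enum values compare directly and the `.toString` round-trips can be dropped.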
