Commit fc1cb78

[SPARK-51834][SQL] Support end-to-end table constraint management
### What changes were proposed in this pull request?

Support end-to-end table constraint management:
- Create a DSV2 table with constraints
- Replace a DSV2 table with constraints
- ALTER a DSV2 table to add a new constraint
- ALTER a DSV2 table to drop a constraint

### Why are the changes needed?

Allow users to define and modify table constraints in connectors that support them.

### Does this PR introduce _any_ user-facing change?

No, it is for the DSV2 framework.

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #50631 from gengliangwang/constraintE2E.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
1 parent 0225903 commit fc1cb78
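As a rough sketch of the user-facing DDL this enables (catalog, namespace, table, and constraint names below are illustrative, and the target connector must support constraints):

```scala
// Hypothetical statements against a DSV2 catalog ("cat") whose connector
// supports table constraints.
spark.sql("""
  CREATE TABLE cat.ns.orders (
    id BIGINT,
    amount DOUBLE,
    CONSTRAINT positive_amount CHECK (amount > 0)
  )
""")

// ALTER the table to add a new constraint, then drop it again.
spark.sql("ALTER TABLE cat.ns.orders ADD CONSTRAINT id_not_null CHECK (id IS NOT NULL)")
spark.sql("ALTER TABLE cat.ns.orders DROP CONSTRAINT id_not_null")
```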

19 files changed: +888 −48 lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 6 additions & 0 deletions
@@ -4095,6 +4095,12 @@
     ],
     "sqlState" : "HV091"
   },
+  "NON_DETERMINISTIC_CHECK_CONSTRAINT" : {
+    "message" : [
+      "The check constraint `<checkCondition>` is non-deterministic. Check constraints must only contain deterministic expressions."
+    ],
+    "sqlState" : "42621"
+  },
   "NON_FOLDABLE_ARGUMENT" : {
     "message" : [
       "The function <funcName> requires the parameter <paramName> to be a foldable expression of the type <paramType>, but the actual argument is a non-foldable."

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 6 additions & 0 deletions
@@ -1047,6 +1047,12 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
     case RenameColumn(table: ResolvedTable, col: ResolvedFieldName, newName) =>
       checkColumnNotExists("rename", col.path :+ newName, table.schema)

+    case AddConstraint(_: ResolvedTable, check: CheckConstraint) if !check.deterministic =>
+      check.child.failAnalysis(
+        errorClass = "NON_DETERMINISTIC_CHECK_CONSTRAINT",
+        messageParameters = Map("checkCondition" -> check.condition)
+      )
+
     case AlterColumns(table: ResolvedTable, specs) =>
       val groupedColumns = specs.groupBy(_.column.name)
       groupedColumns.collect {
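This guard covers the ALTER path: a check constraint whose condition is non-deterministic is rejected during analysis. For example (hypothetical table name), the following would fail with NON_DETERMINISTIC_CHECK_CONSTRAINT because `rand()` is non-deterministic:

```scala
// Rejected at analysis time: rand() is non-deterministic.
spark.sql("ALTER TABLE cat.ns.orders ADD CONSTRAINT c CHECK (rand() > 0.5)")
```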

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala

Lines changed: 36 additions & 8 deletions
@@ -22,9 +22,11 @@ import scala.jdk.CollectionConverters._
 import org.apache.spark.SparkException
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.SqlScriptingLocalVariableManager
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, Identifier, LookupCatalog, SupportsNamespaces}
+import org.apache.spark.sql.catalyst.util.SparkCharVarcharUtils.replaceCharVarcharWithString
+import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.util.ArrayImplicits._
@@ -77,14 +79,19 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
       assertValidSessionVariableNameParts(nameParts, resolved)
       d.copy(name = resolved)

+    // For CREATE TABLE and REPLACE TABLE statements, resolve the table identifier and include
+    // the table columns as output. This allows expressions (e.g., constraints) referencing these
+    // columns to be resolved correctly.
+    case c @ CreateTable(UnresolvedIdentifier(nameParts, allowTemp), columns, _, _, _) =>
+      val resolvedIdentifier = resolveIdentifier(nameParts, allowTemp, columns)
+      c.copy(name = resolvedIdentifier)
+
+    case r @ ReplaceTable(UnresolvedIdentifier(nameParts, allowTemp), columns, _, _, _) =>
+      val resolvedIdentifier = resolveIdentifier(nameParts, allowTemp, columns)
+      r.copy(name = resolvedIdentifier)
+
     case UnresolvedIdentifier(nameParts, allowTemp) =>
-      if (allowTemp && catalogManager.v1SessionCatalog.isTempView(nameParts)) {
-        val ident = Identifier.of(nameParts.dropRight(1).toArray, nameParts.last)
-        ResolvedIdentifier(FakeSystemCatalog, ident)
-      } else {
-        val CatalogAndIdentifier(catalog, identifier) = nameParts
-        ResolvedIdentifier(catalog, identifier)
-      }
+      resolveIdentifier(nameParts, allowTemp, Nil)

     case CurrentNamespace =>
       ResolvedNamespace(currentCatalog, catalogManager.currentNamespace.toImmutableArraySeq)
@@ -94,6 +101,27 @@
       resolveNamespace(catalog, ns, fetchMetadata)
   }

+  private def resolveIdentifier(
+      nameParts: Seq[String],
+      allowTemp: Boolean,
+      columns: Seq[ColumnDefinition]): ResolvedIdentifier = {
+    val columnOutput = columns.map { col =>
+      val dataType = if (conf.preserveCharVarcharTypeInfo) {
+        col.dataType
+      } else {
+        replaceCharVarcharWithString(col.dataType)
+      }
+      AttributeReference(col.name, dataType, col.nullable, col.metadata)()
+    }
+    if (allowTemp && catalogManager.v1SessionCatalog.isTempView(nameParts)) {
+      val ident = Identifier.of(nameParts.dropRight(1).toArray, nameParts.last)
+      ResolvedIdentifier(FakeSystemCatalog, ident, columnOutput)
+    } else {
+      val CatalogAndIdentifier(catalog, identifier) = nameParts
+      ResolvedIdentifier(catalog, identifier, columnOutput)
+    }
+  }
+
   private def resolveNamespace(
       catalog: CatalogPlugin,
       ns: Seq[String],
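The `columnOutput` matters because a constraint such as `CHECK (amount > 0)` must bind `amount` before the table exists. A standalone sketch of the idea (column names and types are illustrative, not the rule itself):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.{DoubleType, LongType}

// Expose the declared columns as attributes so that downstream resolution has
// a scope in which constraint expressions can bind their column references.
val columnOutput = Seq(
  AttributeReference("id", LongType, nullable = false)(),
  AttributeReference("amount", DoubleType, nullable = true)())
// These become the output of the ResolvedIdentifier for the CREATE/REPLACE
// TABLE statement being analyzed.
```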

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableSpec.scala

Lines changed: 16 additions & 3 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.analysis

 import org.apache.spark.SparkThrowable
-import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -61,7 +61,7 @@ object ResolveTableSpec extends Rule[LogicalPlan] {
       input: LogicalPlan,
       tableSpec: TableSpecBase,
       withNewSpec: TableSpecBase => LogicalPlan): LogicalPlan = tableSpec match {
-    case u: UnresolvedTableSpec if u.optionExpression.resolved =>
+    case u: UnresolvedTableSpec if u.childrenResolved =>
       val newOptions: Seq[(String, String)] = u.optionExpression.options.map {
         case (key: String, null) =>
           (key, null)
@@ -86,6 +86,18 @@ object ResolveTableSpec extends Rule[LogicalPlan] {
           }
           (key, newValue)
       }
+
+      u.constraints.foreach {
+        case check: CheckConstraint =>
+          if (!check.child.deterministic) {
+            check.child.failAnalysis(
+              errorClass = "NON_DETERMINISTIC_CHECK_CONSTRAINT",
+              messageParameters = Map("checkCondition" -> check.condition)
+            )
+          }
+        case _ =>
+      }
+
       val newTableSpec = TableSpec(
         properties = u.properties,
         provider = u.provider,
@@ -94,7 +106,8 @@ object ResolveTableSpec extends Rule[LogicalPlan] {
         comment = u.comment,
         collation = u.collation,
         serde = u.serde,
-        external = u.external)
+        external = u.external,
+        constraints = u.constraints.map(_.toV2Constraint))
       withNewSpec(newTableSpec)
     case _ =>
       input
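The same determinism check therefore also applies to constraints that arrive through a CREATE or REPLACE TABLE spec (hypothetical names again):

```scala
// Also rejected at analysis time, on the CREATE path handled by this rule.
spark.sql("CREATE TABLE cat.ns.t (c INT, CONSTRAINT chk CHECK (rand() < 1))")
```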

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala

Lines changed: 7 additions & 2 deletions
@@ -252,8 +252,13 @@ case class ResolvedNonPersistentFunc(
  */
 case class ResolvedIdentifier(
     catalog: CatalogPlugin,
-    identifier: Identifier) extends LeafNodeWithoutStats {
-  override def output: Seq[Attribute] = Nil
+    identifier: Identifier,
+    override val output: Seq[Attribute] = Nil) extends LeafNodeWithoutStats
+
+object ResolvedIdentifier {
+  def unapply(ri: ResolvedIdentifier): Option[(CatalogPlugin, Identifier)] = {
+    Some((ri.catalog, ri.identifier))
+  }
 }

 // A fake v2 catalog to hold temp views.
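A note on the extractor: once `ResolvedIdentifier` gains a third field, the two-field constructor pattern no longer matches, so the hand-written companion `unapply` keeps existing match sites compiling unchanged, e.g. (sketch):

```scala
// Existing match sites continue to work; `output` simply isn't exposed
// through the two-field extractor.
plan match {
  case ResolvedIdentifier(catalog, ident) => s"${catalog.name}.$ident"
  case _ => "other"
}
```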

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala

Lines changed: 69 additions & 8 deletions
@@ -18,11 +18,17 @@ package org.apache.spark.sql.catalyst.expressions

 import java.util.UUID

+import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
-import org.apache.spark.sql.types.{DataType, StringType}
+import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
+import org.apache.spark.sql.connector.catalog.constraints.Constraint
+import org.apache.spark.sql.connector.expressions.FieldReference
+import org.apache.spark.sql.types.DataType

-trait TableConstraint {
+trait TableConstraint extends Expression with Unevaluable {
+  /** Convert to a data source v2 constraint */
+  def toV2Constraint: Constraint

   /** Returns the user-provided name of the constraint */
   def userProvidedName: String
@@ -92,6 +98,10 @@ trait TableConstraint {
       )
     }
   }
+
+  override def nullable: Boolean = throw new UnresolvedException("nullable")
+
+  override def dataType: DataType = throw new UnresolvedException("dataType")
 }

 case class ConstraintCharacteristic(enforced: Option[Boolean], rely: Option[Boolean])
@@ -108,10 +118,25 @@ case class CheckConstraint(
     override val tableName: String = null,
     override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
   extends UnaryExpression
-  with Unevaluable
   with TableConstraint {
   // scalastyle:on line.size.limit

+  def toV2Constraint: Constraint = {
+    val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull
+    val enforced = userProvidedCharacteristic.enforced.getOrElse(true)
+    val rely = userProvidedCharacteristic.rely.getOrElse(false)
+    // TODO(SPARK-51903): Change the status to VALIDATED when we support validation on ALTER TABLE
+    val validateStatus = Constraint.ValidationStatus.UNVALIDATED
+    Constraint
+      .check(name)
+      .predicateSql(condition)
+      .predicate(predicate)
+      .rely(rely)
+      .enforced(enforced)
+      .validationStatus(validateStatus)
+      .build()
+  }
+
   override protected def withNewChildInternal(newChild: Expression): Expression =
     copy(child = newChild)

@@ -121,8 +146,6 @@ case class CheckConstraint(

   override def sql: String = s"CONSTRAINT $userProvidedName CHECK ($condition)"

-  override def dataType: DataType = StringType
-
   override def withUserProvidedName(name: String): TableConstraint = copy(userProvidedName = name)

   override def withTableName(tableName: String): TableConstraint = copy(tableName = tableName)
@@ -137,9 +160,20 @@ case class PrimaryKeyConstraint(
     override val userProvidedName: String = null,
     override val tableName: String = null,
     override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
-  extends TableConstraint {
+  extends LeafExpression with TableConstraint {
   // scalastyle:on line.size.limit

+  override def toV2Constraint: Constraint = {
+    val enforced = userProvidedCharacteristic.enforced.getOrElse(false)
+    val rely = userProvidedCharacteristic.rely.getOrElse(false)
+    Constraint
+      .primaryKey(name, columns.map(FieldReference.column).toArray)
+      .rely(rely)
+      .enforced(enforced)
+      .validationStatus(Constraint.ValidationStatus.UNVALIDATED)
+      .build()
+  }
+
   override protected def generateName(tableName: String): String = s"${tableName}_pk"

   override def withUserProvidedName(name: String): TableConstraint = copy(userProvidedName = name)
@@ -158,9 +192,20 @@ case class UniqueConstraint(
     override val userProvidedName: String = null,
     override val tableName: String = null,
     override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
-  extends TableConstraint {
+  extends LeafExpression with TableConstraint {
   // scalastyle:on line.size.limit

+  override def toV2Constraint: Constraint = {
+    val enforced = userProvidedCharacteristic.enforced.getOrElse(false)
+    val rely = userProvidedCharacteristic.rely.getOrElse(false)
+    Constraint
+      .unique(name, columns.map(FieldReference.column).toArray)
+      .rely(rely)
+      .enforced(enforced)
+      .validationStatus(Constraint.ValidationStatus.UNVALIDATED)
+      .build()
+  }
+
   override protected def generateName(tableName: String): String = {
     s"${tableName}_uniq_$randomSuffix"
   }
@@ -183,9 +228,25 @@ case class ForeignKeyConstraint(
     override val userProvidedName: String = null,
     override val tableName: String = null,
     override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
-  extends TableConstraint {
+  extends LeafExpression with TableConstraint {
   // scalastyle:on line.size.limit

+  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+  override def toV2Constraint: Constraint = {
+    val enforced = userProvidedCharacteristic.enforced.getOrElse(false)
+    val rely = userProvidedCharacteristic.rely.getOrElse(false)
+    Constraint
+      .foreignKey(name,
+        childColumns.map(FieldReference.column).toArray,
+        parentTableId.asIdentifier,
+        parentColumns.map(FieldReference.column).toArray)
+      .rely(rely)
+      .enforced(enforced)
+      .validationStatus(Constraint.ValidationStatus.UNVALIDATED)
+      .build()
+  }
+
   override protected def generateName(tableName: String): String =
     s"${tableName}_${parentTableId.last}_fk_$randomSuffix"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala

Lines changed: 13 additions & 3 deletions
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.plans.logical

-import org.apache.spark.sql.catalyst.analysis.{FieldName, FieldPosition, ResolvedFieldName, UnresolvedException}
+import org.apache.spark.sql.catalyst.analysis.{FieldName, FieldPosition, ResolvedFieldName, ResolvedTable, UnresolvedException}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.catalog.ClusterBySpec
 import org.apache.spark.sql.catalyst.expressions.{Expression, TableConstraint, Unevaluable}
@@ -295,7 +295,16 @@ case class AlterTableCollation(
 case class AddConstraint(
     table: LogicalPlan,
     tableConstraint: TableConstraint) extends AlterTableCommand {
-  override def changes: Seq[TableChange] = Seq.empty
+  override def changes: Seq[TableChange] = {
+    val constraint = tableConstraint.toV2Constraint
+    val validatedTableVersion = table match {
+      case t: ResolvedTable if constraint.enforced() =>
+        t.table.currentVersion()
+      case _ =>
+        null
+    }
+    Seq(TableChange.addConstraint(constraint, validatedTableVersion))
+  }

   protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild)
 }
@@ -308,7 +317,8 @@ case class DropConstraint(
     name: String,
     ifExists: Boolean,
     cascade: Boolean) extends AlterTableCommand {
-  override def changes: Seq[TableChange] = Seq.empty
+  override def changes: Seq[TableChange] =
+    Seq(TableChange.dropConstraint(name, ifExists, cascade))

   protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild)
 }
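When the command executes, these `TableChange`s reach the connector through the usual `TableCatalog.alterTable` call; roughly (a sketch, with illustrative identifiers):

```scala
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog, TableChange}

// What a connector ultimately receives for DROP CONSTRAINT (sketch).
def dropOrdersConstraint(catalog: TableCatalog): Unit = {
  val ident = Identifier.of(Array("ns"), "orders")
  catalog.alterTable(
    ident,
    TableChange.dropConstraint("positive_amount", /* ifExists = */ false, /* cascade = */ false))
}
```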

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala

Lines changed: 12 additions & 6 deletions
@@ -1512,19 +1512,25 @@ case class UnresolvedTableSpec(
     serde: Option[SerdeInfo],
     external: Boolean,
     constraints: Seq[TableConstraint])
-  extends UnaryExpression with Unevaluable with TableSpecBase {
+  extends Expression with Unevaluable with TableSpecBase {

   override def dataType: DataType =
     throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3113")

-  override def child: Expression = optionExpression
-
-  override protected def withNewChildInternal(newChild: Expression): Expression =
-    this.copy(optionExpression = newChild.asInstanceOf[OptionList])
-
   override def simpleString(maxFields: Int): String = {
     this.copy(properties = Utils.redact(properties).toMap).toString
   }
+
+  override def nullable: Boolean = true
+
+  override def children: Seq[Expression] = optionExpression +: constraints
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): Expression = {
+    copy(
+      optionExpression = newChildren.head.asInstanceOf[OptionList],
+      constraints = newChildren.tail.asInstanceOf[Seq[TableConstraint]])
+  }
 }

 /**

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ case class CreateTableExec(
       .withColumns(columns)
       .withPartitions(partitioning.toArray)
       .withProperties(tableProperties.asJava)
+      .withConstraints(tableSpec.constraints.toArray)
       .build()
     catalog.createTable(identifier, tableInfo)
   } catch {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ case class ReplaceTableExec(
       .withColumns(columns)
       .withPartitions(partitioning.toArray)
       .withProperties(tableProperties.asJava)
+      .withConstraints(tableSpec.constraints.toArray)
       .build()
     catalog.createTable(ident, tableInfo)
     Seq.empty
