Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
u.copy(child = newChild)
}

case d @ DefaultValueExpression(c: Expression, _, _) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this works because we resolve expressions top to bottom, hence we see DefaultValueExpression before we see the unresolved attribute?

d.copy(child = resolveLiteralColumns(c))

case _ => e.mapChildren(innerResolve(_, isTopLevel = false))
}
resolved.copyTagsFrom(e)
Expand All @@ -195,6 +198,13 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
}
}

private def resolveLiteralColumns(e: Expression) = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a little bit worried that this doesn't fix the root problem that ResolvedIdentifier output in CREATE actually is used as candidates for resolving default values. Let me think.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, maybe it does solve the root problem. Am I right we will not recurse into DefaultValueExpression child so we effectively ensure that default value resolution doesn't have access to attributes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this is matched before children

e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) {
case u @ UnresolvedAttribute(nameParts) =>
LiteralFunctionResolution.resolve(nameParts).getOrElse(u)
}
}

// Resolves `UnresolvedAttribute` to `OuterReference`.
protected def resolveOuterRef(e: Expression): Expression = {
val outerPlan = AnalysisContext.get.outerPlan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1590,4 +1590,79 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase {
Row(1, Map(Row(20, null) -> Row("d", null)), "sales")))
}
}

private def checkConflictSpecialColNameResult(table: String): Unit = {
val result = sql(s"SELECT * FROM $table").collect()
assert(result.length == 1)
assert(!result(0).getBoolean(0))
assert(result(0).get(1) != null)
}

test("Add column with special column name default value conflicting with column name") {
withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> s"$v2Format, ") {
val t = fullTableName("table_name")
// There is a default value that is a special column name 'current_timestamp'.
withTable(t) {
sql(s"CREATE TABLE $t (i boolean) USING $v2Format")
sql(s"ALTER TABLE $t ADD COLUMN s timestamp DEFAULT current_timestamp")
sql(s"INSERT INTO $t(i) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
// There is a default value with special column name 'current_user' but in uppercase.
withTable(t) {
sql(s"CREATE TABLE $t (i boolean) USING $v2Format")
sql(s"ALTER TABLE $t ADD COLUMN s string DEFAULT CURRENT_USER")
sql(s"INSERT INTO $t(i) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
// There is a default value with special column name same as current column name
withTable(t) {
sql(s"CREATE TABLE $t (b boolean) USING $v2Format")
sql(s"ALTER TABLE $t ADD COLUMN current_timestamp timestamp DEFAULT current_timestamp")
sql(s"INSERT INTO $t(b) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
// There is a default value with special column name same as another column name
withTable(t) {
sql(s"CREATE TABLE $t (current_date boolean) USING $v2Format")
sql(s"ALTER TABLE $t ADD COLUMN s date DEFAULT current_date")
sql(s"INSERT INTO $t(current_date) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
}
}

test("Set default value for existing column conflicting with special column names") {
withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> s"$v2Format, ") {
val t = fullTableName("table_name")
// There is a default value that is a special column name 'current_timestamp'.
withTable(t) {
sql(s"CREATE TABLE $t (i boolean, s timestamp) USING $v2Format")
sql(s"ALTER TABLE $t ALTER COLUMN s SET DEFAULT current_timestamp")
sql(s"INSERT INTO $t(i) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
// There is a default value with special column name 'current_user' but in uppercase.
withTable(t) {
sql(s"CREATE TABLE $t (i boolean, s string) USING $v2Format")
sql(s"ALTER TABLE $t ALTER COLUMN s SET DEFAULT CURRENT_USER")
sql(s"INSERT INTO $t(i) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
// There is a default value with special column name same as current column name
withTable(t) {
sql(s"CREATE TABLE $t (b boolean, current_timestamp timestamp) USING $v2Format")
sql(s"ALTER TABLE $t ALTER COLUMN current_timestamp SET DEFAULT current_timestamp")
sql(s"INSERT INTO $t(b) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
// There is a default value with special column name same as another column name
withTable(t) {
sql(s"CREATE TABLE $t (current_date boolean, s date) USING $v2Format")
sql(s"ALTER TABLE $t ALTER COLUMN s SET DEFAULT current_date")
sql(s"INSERT INTO $t(current_date) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,139 @@ class DataSourceV2DataFrameSuite
}
}

test("test default value special column name conflicting with real column name: CREATE") {
val t = "testcat.ns.t"
withTable("t") {
val createExec = executeAndKeepPhysicalPlan[CreateTableExec] {
sql(s"""CREATE table $t (
c1 STRING,
current_date DATE DEFAULT CURRENT_DATE,
current_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
current_time time DEFAULT CURRENT_TIME,
current_user STRING DEFAULT CURRENT_USER,
session_user STRING DEFAULT SESSION_USER,
user STRING DEFAULT USER,
current_database STRING DEFAULT CURRENT_DATABASE(),
current_catalog STRING DEFAULT CURRENT_CATALOG())""")
}

val columns = createExec.columns
checkDefaultValues(
columns,
Array(
null, // c1 has no default value
new ColumnDefaultValue("CURRENT_DATE", null),
new ColumnDefaultValue("CURRENT_TIMESTAMP", null),
new ColumnDefaultValue("CURRENT_TIME", null),
new ColumnDefaultValue("CURRENT_USER", null),
new ColumnDefaultValue("SESSION_USER", null),
new ColumnDefaultValue("USER", null),
new ColumnDefaultValue("CURRENT_DATABASE()", null),
new ColumnDefaultValue("CURRENT_CATALOG()", null)),
compareValue = false)

sql(s"INSERT INTO $t (c1) VALUES ('a')")
val result = sql(s"SELECT * FROM $t").collect()
assert(result.length == 1)
assert(result(0).getString(0) == "a")
Seq(1 to 8: _*).foreach(i => assert(result(0).get(i) != null))
}
}

test("test default value special column name conflicting with real column name: REPLACE") {
val t = "testcat.ns.t"
withTable("t") {
sql(s"""CREATE table $t (
c1 STRING)""")
val replaceExec = executeAndKeepPhysicalPlan[ReplaceTableExec] {
sql(
s"""REPLACE table $t (
c1 STRING,
current_date DATE DEFAULT CURRENT_DATE,
current_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
current_time time DEFAULT CURRENT_TIME,
current_user STRING DEFAULT CURRENT_USER,
session_user STRING DEFAULT SESSION_USER,
user STRING DEFAULT USER,
current_database STRING DEFAULT CURRENT_DATABASE(),
current_catalog STRING DEFAULT CURRENT_CATALOG())""")
}

val columns = replaceExec.columns
checkDefaultValues(
columns,
Array(
null, // c1 has no default value
new ColumnDefaultValue("CURRENT_DATE", null),
new ColumnDefaultValue("CURRENT_TIMESTAMP", null),
new ColumnDefaultValue("CURRENT_TIME", null),
new ColumnDefaultValue("CURRENT_USER", null),
new ColumnDefaultValue("SESSION_USER", null),
new ColumnDefaultValue("USER", null),
new ColumnDefaultValue("CURRENT_DATABASE()", null),
new ColumnDefaultValue("CURRENT_CATALOG()", null)),
compareValue = false)

sql(s"INSERT INTO $t (c1) VALUES ('a')")
val result = sql(s"SELECT * FROM $t").collect()
assert(result.length == 1)
assert(result(0).getString(0) == "a")
Seq(1 to 8: _*).foreach(i => assert(result(0).get(i) != null))
}
}

test("create table with conflicting literal function value in nested default value") {
val tableName = "testcat.ns1.ns2.tbl"
withTable(tableName) {
val createExec = executeAndKeepPhysicalPlan[CreateTableExec] {
sql(
s"""
|CREATE TABLE $tableName (
| c1 STRING,
| current_date DATE DEFAULT DATE_ADD(current_date, 7)
|) USING foo
|""".stripMargin)
}

// Check that the table was created with the expected default value
val columns = createExec.columns
checkDefaultValues(
columns,
Array(
null, // c1 has no default value
new ColumnDefaultValue("DATE_ADD(current_date, 7)", null)),
compareValue = false)

val df1 = Seq("test1", "test2").toDF("c1")
df1.writeTo(tableName).append()

val result = sql(s"SELECT * FROM $tableName")
assert(result.count() == 2)
assert(result.collect().map(_.getString(0)).toSet == Set("test1", "test2"))
assert(result.collect().forall(_.get(1) != null))
}
}

test("test default value should not refer to real column") {
val t = "testcat.ns.t"
withTable("t") {
checkError(
exception = intercept[AnalysisException] {
sql(s"""CREATE table $t (
c1 timestamp,
current_timestamp TIMESTAMP DEFAULT c1)""")
},
condition = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION",
parameters = Map(
"statement" -> "CREATE TABLE",
"colName" -> "`current_timestamp`",
"defaultValue" -> "c1"
),
sqlState = Some("42623")
)
}
}

private def executeAndKeepPhysicalPlan[T <: SparkPlan](func: => Unit): T = {
val qe = withQueryExecutionsCaptured(spark) {
func
Expand All @@ -865,15 +998,15 @@ class DataSourceV2DataFrameSuite

private def checkDefaultValues(
columns: Array[Column],
expectedDefaultValues: Array[ColumnDefaultValue]): Unit = {
expectedDefaultValues: Array[ColumnDefaultValue],
compareValue: Boolean = true): Unit = {
assert(columns.length == expectedDefaultValues.length)

columns.zip(expectedDefaultValues).foreach {
case (column, expectedDefault) =>
assert(
column.defaultValue == expectedDefault,
s"Default value mismatch for column '${column.name}': " +
s"expected $expectedDefault but found ${column.defaultValue}")
assert(compareColumnDefaultValue(column.defaultValue(), expectedDefault, compareValue),
s"Default value mismatch for column '${column.toString}': " +
s"expected $expectedDefault but found ${column.defaultValue}")
}
}

Expand Down Expand Up @@ -912,4 +1045,17 @@ class DataSourceV2DataFrameSuite
s"Default value mismatch for column '${column.toString}': " +
s"expected empty but found ${column.newCurrentDefault()}")
}

private def compareColumnDefaultValue(
left: ColumnDefaultValue,
right: ColumnDefaultValue,
compareValue: Boolean) = {
(left, right) match {
case (null, null) => true
case (null, _) | (_, null) => false
case _ => left.getSql == right.getSql &&
left.getExpression == right.getExpression &&
(!compareValue || left.getValue == right.getValue)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,33 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase {
Row(1, 42, "hr") :: Row(2, 2, "software") :: Row(3, 42, "hr") :: Nil)
}

test("update with current_timestamp default value using DEFAULT keyword") {
sql(s"""CREATE TABLE $tableNameAsString
| (pk INT NOT NULL, current_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP)""".stripMargin)
append("pk INT NOT NULL, current_timestamp TIMESTAMP",
"""{ "pk": 1, "i": false, "current_timestamp": "2023-01-01 10:00:00" }
|{ "pk": 2, "i": true, "current_timestamp": "2023-01-01 11:00:00" }
|""".stripMargin)

val initialResult = sql(s"SELECT * FROM $tableNameAsString").collect()
assert(initialResult.length == 2)
val initialTimestamp1 = initialResult(0).getTimestamp(1)
val initialTimestamp2 = initialResult(1).getTimestamp(1)

sql(s"UPDATE $tableNameAsString SET current_timestamp = DEFAULT WHERE pk = 1")

val updatedResult = sql(s"SELECT * FROM $tableNameAsString").collect()
assert(updatedResult.length == 2)

val updatedRow = updatedResult.find(_.getInt(0) == 1).get
val unchangedRow = updatedResult.find(_.getInt(0) == 2).get

// The timestamp should be different (newer) after the update for pk=1
assert(updatedRow.getTimestamp(1).getTime > initialTimestamp1.getTime)
// The timestamp should remain unchanged for pk=2
assert(unchangedRow.getTimestamp(1).getTime == initialTimestamp2.getTime)
}

test("update char/varchar columns") {
createTable("pk INT NOT NULL, s STRUCT<n_c: CHAR(3), n_vc: VARCHAR(5)>, dep STRING")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,44 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
sql("insert into t select false, default")
checkAnswer(spark.table("t"), Row(false, 42L))
}
// There is a default value that is a special column name 'current_timestamp'.
withTable("t") {
sql("create table t(i boolean, s timestamp default current_timestamp) using parquet")
sql("insert into t(i) values(false)")
val result = spark.table("t").collect()
assert(result.length == 1)
assert(!result(0).getBoolean(0))
assert(result(0).getTimestamp(1) != null)
}
// There is a default value with special column name 'current_user' but in uppercase.
withTable("t") {
sql("create table t(i boolean, s string default CURRENT_USER) using parquet")
sql("insert into t(i) values(false)")
val result = spark.table("t").collect()
assert(result.length == 1)
assert(!result(0).getBoolean(0))
assert(result(0).getString(1) != null)
}
// There is a default value with special column name same as current column name
withTable("t") {
sql("create table t(current_timestamp timestamp default current_timestamp, b boolean) " +
"using parquet")
sql("insert into t(b) values(false)")
val result = spark.table("t").collect()
assert(result.length == 1)
assert(result(0).getTimestamp(0) != null)
assert(!result(0).getBoolean(1))
}
// There is a default value with special column name same as another column name
withTable("t") {
sql("create table t(current_date boolean, s date default current_date) " +
"using parquet")
sql("insert into t(current_date) values(false)")
val result = spark.table("t").collect()
assert(result.length == 1)
assert(!result(0).getBoolean(0))
assert(result(0).getDate(1) != null)
}
// There is a complex query plan in the SELECT query in the INSERT INTO statement.
withTable("t") {
sql("create table t(i boolean default false, s bigint default 42) using parquet")
Expand Down