diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 3224ccafafec3..912f5211a2d95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -175,6 +175,9 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { u.copy(child = newChild) } + case d @ DefaultValueExpression(c: Expression, _, _) => + d.copy(child = resolveLiteralColumns(c)) + case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) } resolved.copyTagsFrom(e) @@ -195,6 +198,13 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { } } + private def resolveLiteralColumns(e: Expression) = { + e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { + case u @ UnresolvedAttribute(nameParts) => + LiteralFunctionResolution.resolve(nameParts).getOrElse(u) + } + } + // Resolves `UnresolvedAttribute` to `OuterReference`. 
protected def resolveOuterRef(e: Expression): Expression = { val outerPlan = AnalysisContext.get.outerPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index ea65a580821af..9c1cd1cbb1056 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -1590,4 +1590,79 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { Row(1, Map(Row(20, null) -> Row("d", null)), "sales"))) } } + + private def checkConflictSpecialColNameResult(table: String): Unit = { + val result = sql(s"SELECT * FROM $table").collect() + assert(result.length == 1) + assert(!result(0).getBoolean(0)) + assert(result(0).get(1) != null) + } + + test("Add column with special column name default value conflicting with column name") { + withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> s"$v2Format, ") { + val t = fullTableName("table_name") + // There is a default value that is a special column name 'current_timestamp'. + withTable(t) { + sql(s"CREATE TABLE $t (i boolean) USING $v2Format") + sql(s"ALTER TABLE $t ADD COLUMN s timestamp DEFAULT current_timestamp") + sql(s"INSERT INTO $t(i) VALUES(false)") + checkConflictSpecialColNameResult(t) + } + // There is a default value with special column name 'current_user' but in uppercase. 
+ withTable(t) { + sql(s"CREATE TABLE $t (i boolean) USING $v2Format") + sql(s"ALTER TABLE $t ADD COLUMN s string DEFAULT CURRENT_USER") + sql(s"INSERT INTO $t(i) VALUES(false)") + checkConflictSpecialColNameResult(t) + } + // There is a default value with special column name same as current column name + withTable(t) { + sql(s"CREATE TABLE $t (b boolean) USING $v2Format") + sql(s"ALTER TABLE $t ADD COLUMN current_timestamp timestamp DEFAULT current_timestamp") + sql(s"INSERT INTO $t(b) VALUES(false)") + checkConflictSpecialColNameResult(t) + } + // There is a default value with special column name same as another column name + withTable(t) { + sql(s"CREATE TABLE $t (current_date boolean) USING $v2Format") + sql(s"ALTER TABLE $t ADD COLUMN s date DEFAULT current_date") + sql(s"INSERT INTO $t(current_date) VALUES(false)") + checkConflictSpecialColNameResult(t) + } + } + } + + test("Set default value for existing column conflicting with special column names") { + withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> s"$v2Format, ") { + val t = fullTableName("table_name") + // There is a default value that is a special column name 'current_timestamp'. + withTable(t) { + sql(s"CREATE TABLE $t (i boolean, s timestamp) USING $v2Format") + sql(s"ALTER TABLE $t ALTER COLUMN s SET DEFAULT current_timestamp") + sql(s"INSERT INTO $t(i) VALUES(false)") + checkConflictSpecialColNameResult(t) + } + // There is a default value with special column name 'current_user' but in uppercase. 
+ withTable(t) { + sql(s"CREATE TABLE $t (i boolean, s string) USING $v2Format") + sql(s"ALTER TABLE $t ALTER COLUMN s SET DEFAULT CURRENT_USER") + sql(s"INSERT INTO $t(i) VALUES(false)") + checkConflictSpecialColNameResult(t) + } + // There is a default value with special column name same as current column name + withTable(t) { + sql(s"CREATE TABLE $t (b boolean, current_timestamp timestamp) USING $v2Format") + sql(s"ALTER TABLE $t ALTER COLUMN current_timestamp SET DEFAULT current_timestamp") + sql(s"INSERT INTO $t(b) VALUES(false)") + checkConflictSpecialColNameResult(t) + } + // There is a default value with special column name same as another column name + withTable(t) { + sql(s"CREATE TABLE $t (current_date boolean, s date) USING $v2Format") + sql(s"ALTER TABLE $t ALTER COLUMN s SET DEFAULT current_date") + sql(s"INSERT INTO $t(current_date) VALUES(false)") + checkConflictSpecialColNameResult(t) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 9171e44571e88..d649e81ca6c78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -3870,6 +3870,60 @@ class DataSourceV2SQLSuiteV1Filter } } + test("test default value special column name conflicting with real column name") { + val t = "testcat.ns.t" + withTable(t) { + sql(s"""CREATE table $t ( + c1 STRING, + current_date DATE DEFAULT CURRENT_DATE, + current_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + current_time time DEFAULT CURRENT_TIME, + current_user STRING DEFAULT CURRENT_USER, + session_user STRING DEFAULT SESSION_USER, + user STRING DEFAULT USER, + current_database STRING DEFAULT CURRENT_DATABASE(), + current_catalog STRING DEFAULT CURRENT_CATALOG())""") + sql(s"INSERT INTO $t (c1) VALUES ('a')") + val result = sql(s"SELECT * 
FROM $t").collect() + assert(result.length == 1) + assert(result(0).getString(0) == "a") + Seq(1 to 8: _*).foreach(i => assert(result(0).get(i) != null)) + } + } + + test("test default value special column name nested in function") { + val t = "testcat.ns.t" + withTable(t) { + sql(s"""CREATE table $t ( + c1 STRING, + current_date DATE DEFAULT DATE_ADD(current_date, 7))""") + sql(s"INSERT INTO $t (c1) VALUES ('a')") + val result = sql(s"SELECT * FROM $t").collect() + assert(result.length == 1) + assert(result(0).getString(0) == "a") + } + } + + test("test default value should not refer to real column") { + val t = "testcat.ns.t" + withTable(t) { + checkError( + exception = intercept[AnalysisException] { + sql(s"""CREATE table $t ( + c1 timestamp, + current_timestamp TIMESTAMP DEFAULT c1)""") + }, + condition = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", + parameters = Map( + "statement" -> "CREATE TABLE", + "colName" -> "`current_timestamp`", + "defaultValue" -> "c1" + ), + sqlState = Some("42623") + ) + } + } + private def testNotSupportedV2Command( sqlCommand: String, sqlParams: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala index a167e9299a2b2..ac0bf3bdba9ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala @@ -660,6 +660,33 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { Row(1, 42, "hr") :: Row(2, 2, "software") :: Row(3, 42, "hr") :: Nil) } + test("update with current_timestamp default value using DEFAULT keyword") { + sql(s"""CREATE TABLE $tableNameAsString + | (pk INT NOT NULL, current_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP)""".stripMargin) + append("pk INT NOT NULL, current_timestamp TIMESTAMP", + """{ "pk": 1, "i": false, "current_timestamp": 
"2023-01-01 10:00:00" } + |{ "pk": 2, "i": true, "current_timestamp": "2023-01-01 11:00:00" } + |""".stripMargin) + + val initialResult = sql(s"SELECT * FROM $tableNameAsString").collect() + assert(initialResult.length == 2) + val initialTimestamp1 = initialResult(0).getTimestamp(1) + val initialTimestamp2 = initialResult(1).getTimestamp(1) + + sql(s"UPDATE $tableNameAsString SET current_timestamp = DEFAULT WHERE pk = 1") + + val updatedResult = sql(s"SELECT * FROM $tableNameAsString").collect() + assert(updatedResult.length == 2) + + val updatedRow = updatedResult.find(_.getInt(0) == 1).get + val unchangedRow = updatedResult.find(_.getInt(0) == 2).get + + // The timestamp should be different (newer) after the update for pk=1 + assert(updatedRow.getTimestamp(1).getTime > initialTimestamp1.getTime) + // The timestamp should remain unchanged for pk=2 + assert(unchangedRow.getTimestamp(1).getTime == initialTimestamp2.getTime) + } + test("update char/varchar columns") { createTable("pk INT NOT NULL, s STRUCT, dep STRING") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 6678f9535fe0d..0e779f9532d63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -1097,6 +1097,44 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("insert into t select false, default") checkAnswer(spark.table("t"), Row(false, 42L)) } + // There is a default value that is a special column name 'current_timestamp'. 
+ withTable("t") { + sql("create table t(i boolean, s timestamp default current_timestamp) using parquet") + sql("insert into t(i) values(false)") + val result = spark.table("t").collect() + assert(result.length == 1) + assert(!result(0).getBoolean(0)) + assert(result(0).getTimestamp(1) != null) + } + // There is a default value with special column name 'current_user' but in uppercase. + withTable("t") { + sql("create table t(i boolean, s string default CURRENT_USER) using parquet") + sql("insert into t(i) values(false)") + val result = spark.table("t").collect() + assert(result.length == 1) + assert(!result(0).getBoolean(0)) + assert(result(0).getString(1) != null) + } + // There is a default value with special column name same as current column name + withTable("t") { + sql("create table t(current_timestamp timestamp default current_timestamp, b boolean) " + + "using parquet") + sql("insert into t(b) values(false)") + val result = spark.table("t").collect() + assert(result.length == 1) + assert(result(0).getTimestamp(0) != null) + assert(!result(0).getBoolean(1)) + } + // There is a default value with special column name same as another column name + withTable("t") { + sql("create table t(current_date boolean, s date default current_date) " + + "using parquet") + sql("insert into t(current_date) values(false)") + val result = spark.table("t").collect() + assert(result.length == 1) + assert(!result(0).getBoolean(0)) + assert(result(0).getDate(1) != null) + } // There is a complex query plan in the SELECT query in the INSERT INTO statement. withTable("t") { sql("create table t(i boolean default false, s bigint default 42) using parquet")