diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 3224ccafafec3..5f2be9605e86e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -140,10 +140,14 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
         }
         matched(ordinal)
 
+      // Try to resolve literal functions first (for DefaultValueExpression)
+      case u @ UnresolvedLiteralFunction(nameParts) => withPosition(u) {
+        LiteralFunctionResolution.resolve(nameParts).getOrElse(u)
+      }
+
      case u @ UnresolvedAttribute(nameParts) =>
        val result = withPosition(u) {
          resolveColumnByName(nameParts)
-            .orElse(LiteralFunctionResolution.resolve(nameParts))
            .map {
              // We trim unnecessary alias here. Note that, we cannot trim the alias at top-level,
              // as we should resolve `UnresolvedAttribute` to a named expression. The caller side
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index b759c70266f7a..5e22fcefa0100 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -270,6 +270,30 @@ object UnresolvedTVFAliases {
   }
 }
 
+/**
+ * Holds the name of a literal function that has yet to be resolved.
+ * Similar to UnresolvedAttribute but specifically for literal functions.
+ */
+case class UnresolvedLiteralFunction(nameParts: Seq[String]) extends Expression with Unevaluable {
+
+  def name: String =
+    nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
+
+  override def dataType: DataType = throw new UnresolvedException("dataType")
+  override def nullable: Boolean = throw new UnresolvedException("nullable")
+  override lazy val resolved = false
+  override def children: Seq[Expression] = Nil
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_LITERAL_FUNCTION)
+
+  override def toString: String = s"'$name"
+
+  override def sql: String = nameParts.map(quoteIfNeeded).mkString(".")
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): Expression = this
+}
+
 /**
  * Holds the name of an attribute that has yet to be resolved.
  */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index e43e32f04fbf1..e255249ceeed0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -2932,9 +2932,9 @@ class AstBuilder extends DataTypeAstBuilder
        CurrentUser()
      }
    } else {
-      // If the parser is not in ansi mode, we should return `UnresolvedAttribute`, in case there
-      // are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP` or `CURRENT_TIME`
-      UnresolvedAttribute.quoted(ctx.name.getText)
+      // If the parser is not in ansi mode, we should return `UnresolvedLiteralFunction`,
+      // in case there are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP` or `CURRENT_TIME`
+      UnresolvedLiteralFunction(Seq(ctx.name.getText))
    }
  }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index c35aa7403d767..90ca42ddd77f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -109,6 +109,7 @@ object TreePattern extends Enumeration {
   val UNRESOLVED_DF_STAR: Value = Value
   val UNRESOLVED_FUNCTION: Value = Value
   val UNRESOLVED_IDENTIFIER: Value = Value
+  val UNRESOLVED_LITERAL_FUNCTION: Value = Value
   val UNRESOLVED_ORDINAL: Value = Value
   val UNRESOLVED_PLAN_ID: Value = Value
   val UNRESOLVED_WINDOW_EXPRESSION: Value = Value
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
index ea65a580821af..9c1cd1cbb1056 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
@@ -1590,4 +1590,79 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase {
        Row(1, Map(Row(20, null) -> Row("d", null)), "sales")))
    }
  }
+
+  private def checkConflictSpecialColNameResult(table: String): Unit = {
+    val result = sql(s"SELECT * FROM $table").collect()
+    assert(result.length == 1)
+    assert(!result(0).getBoolean(0))
+    assert(result(0).get(1) != null)
+  }
+
+  test("Add column with special column name default value conflicting with column name") {
+    withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> s"$v2Format, ") {
+      val t = fullTableName("table_name")
+      // There is a default value that is a special column name 'current_timestamp'.
+      withTable(t) {
+        sql(s"CREATE TABLE $t (i boolean) USING $v2Format")
+        sql(s"ALTER TABLE $t ADD COLUMN s timestamp DEFAULT current_timestamp")
+        sql(s"INSERT INTO $t(i) VALUES(false)")
+        checkConflictSpecialColNameResult(t)
+      }
+      // There is a default value with special column name 'current_user' but in uppercase.
+      withTable(t) {
+        sql(s"CREATE TABLE $t (i boolean) USING $v2Format")
+        sql(s"ALTER TABLE $t ADD COLUMN s string DEFAULT CURRENT_USER")
+        sql(s"INSERT INTO $t(i) VALUES(false)")
+        checkConflictSpecialColNameResult(t)
+      }
+      // There is a default value with special column name same as the current column name.
+      withTable(t) {
+        sql(s"CREATE TABLE $t (b boolean) USING $v2Format")
+        sql(s"ALTER TABLE $t ADD COLUMN current_timestamp timestamp DEFAULT current_timestamp")
+        sql(s"INSERT INTO $t(b) VALUES(false)")
+        checkConflictSpecialColNameResult(t)
+      }
+      // There is a default value with special column name same as another column name.
+      withTable(t) {
+        sql(s"CREATE TABLE $t (current_date boolean) USING $v2Format")
+        sql(s"ALTER TABLE $t ADD COLUMN s date DEFAULT current_date")
+        sql(s"INSERT INTO $t(current_date) VALUES(false)")
+        checkConflictSpecialColNameResult(t)
+      }
+    }
+  }
+
+  test("Set default value for existing column conflicting with special column names") {
+    withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> s"$v2Format, ") {
+      val t = fullTableName("table_name")
+      // There is a default value that is a special column name 'current_timestamp'.
+      withTable(t) {
+        sql(s"CREATE TABLE $t (i boolean, s timestamp) USING $v2Format")
+        sql(s"ALTER TABLE $t ALTER COLUMN s SET DEFAULT current_timestamp")
+        sql(s"INSERT INTO $t(i) VALUES(false)")
+        checkConflictSpecialColNameResult(t)
+      }
+      // There is a default value with special column name 'current_user' but in uppercase.
+      withTable(t) {
+        sql(s"CREATE TABLE $t (i boolean, s string) USING $v2Format")
+        sql(s"ALTER TABLE $t ALTER COLUMN s SET DEFAULT CURRENT_USER")
+        sql(s"INSERT INTO $t(i) VALUES(false)")
+        checkConflictSpecialColNameResult(t)
+      }
+      // There is a default value with special column name same as the current column name.
+      withTable(t) {
+        sql(s"CREATE TABLE $t (b boolean, current_timestamp timestamp) USING $v2Format")
+        sql(s"ALTER TABLE $t ALTER COLUMN current_timestamp SET DEFAULT current_timestamp")
+        sql(s"INSERT INTO $t(b) VALUES(false)")
+        checkConflictSpecialColNameResult(t)
+      }
+      // There is a default value with special column name same as another column name.
+      withTable(t) {
+        sql(s"CREATE TABLE $t (current_date boolean, s date) USING $v2Format")
+        sql(s"ALTER TABLE $t ALTER COLUMN s SET DEFAULT current_date")
+        sql(s"INSERT INTO $t(current_date) VALUES(false)")
+        checkConflictSpecialColNameResult(t)
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 9171e44571e88..c9a03d9a471e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -3870,6 +3870,60 @@ class DataSourceV2SQLSuiteV1Filter
    }
  }
 
+  test("test default value special column name conflicting with real column name") {
+    val t = "testcat.ns.t"
+    withTable(t) {
+      sql(s"""CREATE table $t (
+        c1 STRING,
+        current_date DATE DEFAULT CURRENT_DATE,
+        current_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+        current_time time DEFAULT CURRENT_TIME,
+        current_user STRING DEFAULT CURRENT_USER,
+        session_user STRING DEFAULT SESSION_USER,
+        user STRING DEFAULT USER,
+        current_database STRING DEFAULT CURRENT_DATABASE(),
+        current_catalog STRING DEFAULT CURRENT_CATALOG())""")
+      sql(s"INSERT INTO $t (c1) VALUES ('a')")
+      val result = sql(s"SELECT * FROM $t").collect()
+      assert(result.length == 1)
+      assert(result(0).getString(0) == "a")
+      (1 to 8).foreach(i => assert(result(0).get(i) != null))
+    }
+  }
+
+  test("test default value special column name nested in function") {
+    val t = "testcat.ns.t"
+    withTable(t) {
+      sql(s"""CREATE table $t (
+        c1 STRING,
+        current_date DATE DEFAULT DATE_ADD(current_date, 7))""")
+      sql(s"INSERT INTO $t (c1) VALUES ('a')")
+      val result = sql(s"SELECT * FROM $t").collect()
+      assert(result.length == 1)
+      assert(result(0).getString(0) == "a")
+    }
+  }
+
+  test("test default value should not refer to real column") {
+    val t = "testcat.ns.t"
+    withTable(t) {
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql(s"""CREATE table $t (
+            c1 timestamp,
+            current_timestamp TIMESTAMP DEFAULT c1)""")
+        },
+        condition = "INVALID_DEFAULT_VALUE.NOT_CONSTANT",
+        parameters = Map(
+          "statement" -> "CREATE TABLE",
+          "colName" -> "`current_timestamp`",
+          "defaultValue" -> "c1"
+        ),
+        sqlState = Some("42623")
+      )
+    }
+  }
+
   private def testNotSupportedV2Command(
       sqlCommand: String,
       sqlParams: String,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
index a167e9299a2b2..ac0bf3bdba9ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
@@ -660,6 +660,33 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase {
      Row(1, 42, "hr") :: Row(2, 2, "software") :: Row(3, 42, "hr") :: Nil)
  }
 
+  test("update with current_timestamp default value using DEFAULT keyword") {
+    sql(s"""CREATE TABLE $tableNameAsString
+         | (pk INT NOT NULL, current_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP)""".stripMargin)
+    append("pk INT NOT NULL, current_timestamp TIMESTAMP",
+      """{ "pk": 1, "current_timestamp": "2023-01-01 10:00:00" }
+        |{ "pk": 2, "current_timestamp": "2023-01-01 11:00:00" }
+        |""".stripMargin)
+
+    val initialResult = sql(s"SELECT * FROM $tableNameAsString").collect()
+    assert(initialResult.length == 2)
+    val initialTimestamp1 = initialResult(0).getTimestamp(1)
+    val initialTimestamp2 = initialResult(1).getTimestamp(1)
+
+    sql(s"UPDATE $tableNameAsString SET current_timestamp = DEFAULT WHERE pk = 1")
+
+    val updatedResult = sql(s"SELECT * FROM $tableNameAsString").collect()
+    assert(updatedResult.length == 2)
+
+    val updatedRow = updatedResult.find(_.getInt(0) == 1).get
+    val unchangedRow = updatedResult.find(_.getInt(0) == 2).get
+
+    // The timestamp should be different (newer) after the update for pk=1.
+    assert(updatedRow.getTimestamp(1).getTime > initialTimestamp1.getTime)
+    // The timestamp should remain unchanged for pk=2.
+    assert(unchangedRow.getTimestamp(1).getTime == initialTimestamp2.getTime)
+  }
+
  test("update char/varchar columns") {
    createTable("pk INT NOT NULL, s STRUCT, dep STRING")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 6678f9535fe0d..0e779f9532d63 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -1097,6 +1097,44 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
      sql("insert into t select false, default")
      checkAnswer(spark.table("t"), Row(false, 42L))
    }
+    // There is a default value that is a special column name 'current_timestamp'.
+    withTable("t") {
+      sql("create table t(i boolean, s timestamp default current_timestamp) using parquet")
+      sql("insert into t(i) values(false)")
+      val result = spark.table("t").collect()
+      assert(result.length == 1)
+      assert(!result(0).getBoolean(0))
+      assert(result(0).getTimestamp(1) != null)
+    }
+    // There is a default value with special column name 'current_user' but in uppercase.
+    withTable("t") {
+      sql("create table t(i boolean, s string default CURRENT_USER) using parquet")
+      sql("insert into t(i) values(false)")
+      val result = spark.table("t").collect()
+      assert(result.length == 1)
+      assert(!result(0).getBoolean(0))
+      assert(result(0).getString(1) != null)
+    }
+    // There is a default value with special column name same as the current column name.
+    withTable("t") {
+      sql("create table t(current_timestamp timestamp default current_timestamp, b boolean) " +
+        "using parquet")
+      sql("insert into t(b) values(false)")
+      val result = spark.table("t").collect()
+      assert(result.length == 1)
+      assert(result(0).getTimestamp(0) != null)
+      assert(!result(0).getBoolean(1))
+    }
+    // There is a default value with special column name same as another column name.
+    withTable("t") {
+      sql("create table t(current_date boolean, s date default current_date) " +
+        "using parquet")
+      sql("insert into t(current_date) values(false)")
+      val result = spark.table("t").collect()
+      assert(result.length == 1)
+      assert(!result(0).getBoolean(0))
+      assert(result(0).getDate(1) != null)
+    }
    // There is a complex query plan in the SELECT query in the INSERT INTO statement.
    withTable("t") {
      sql("create table t(i boolean default false, s bigint default 42) using parquet")
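
Illustration (not part of the patch): a minimal, self-contained Scala sketch of the resolution order the change above establishes. Every name here (LiteralFunctionOrderSketch, resolveLiteralFunction, resolveAttribute, the toy maps) is a hypothetical stand-in for LiteralFunctionResolution and the analyzer's column resolution; the point is only that the dedicated UnresolvedLiteralFunction node consults literal functions directly, while UnresolvedAttribute no longer falls back to them.

import java.util.Locale

// Hypothetical stand-in for LiteralFunctionResolution plus column lookup.
object LiteralFunctionOrderSketch {
  // Literal functions known by (lower-cased) name, mapped to a display form.
  private val literalFunctions: Map[String, String] = Map(
    "current_date" -> "current_date()",
    "current_timestamp" -> "current_timestamp()",
    "current_user" -> "current_user()")

  // Models resolving UnresolvedLiteralFunction: only literal functions are
  // consulted, so a same-named column can no longer capture the name; a miss
  // leaves the node unresolved (here: None), mirroring getOrElse(u).
  def resolveLiteralFunction(name: String): Option[String] =
    literalFunctions.get(name.toLowerCase(Locale.ROOT))

  // Models resolving UnresolvedAttribute after this change: columns only,
  // with no literal-function fallback (that moved to the node above).
  def resolveAttribute(name: String, columns: Set[String]): Option[String] =
    if (columns.contains(name.toLowerCase(Locale.ROOT))) Some(s"column `$name`")
    else None

  def main(args: Array[String]): Unit = {
    val columns = Set("current_date", "d")
    // In a DEFAULT expression the parser now emits UnresolvedLiteralFunction,
    // so the function wins even though a column shares the name:
    println(resolveLiteralFunction("CURRENT_DATE")) // Some(current_date())
    // An ordinary reference still resolves against the table's columns:
    println(resolveAttribute("current_date", columns)) // Some(column `current_date`)
  }
}

This mirrors the tests above: `s date DEFAULT current_date` picks up the CURRENT_DATE function even when a boolean column named current_date exists, while `DEFAULT c1` still fails with INVALID_DEFAULT_VALUE.NOT_CONSTANT because plain column references never qualify as constant defaults.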