Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,14 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
}
matched(ordinal)

// Try to resolve literal functions first (for DefaultValueExpression)
case u @ UnresolvedLiteralFunction(nameParts) => withPosition(u) {
LiteralFunctionResolution.resolve(nameParts).getOrElse(u)
}

case u @ UnresolvedAttribute(nameParts) =>
val result = withPosition(u) {
resolveColumnByName(nameParts)
.orElse(LiteralFunctionResolution.resolve(nameParts))
.map {
// We trim unnecessary alias here. Note that, we cannot trim the alias at top-level,
// as we should resolve `UnresolvedAttribute` to a named expression. The caller side
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,30 @@ object UnresolvedTVFAliases {
}
}

/**
* Holds the name of a literal function that has yet to be resolved.
* Similar to UnresolvedAttribute but specifically for literal functions.
*/
case class UnresolvedLiteralFunction(nameParts: Seq[String]) extends Expression with Unevaluable {

def name: String =
nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")

override def dataType: DataType = throw new UnresolvedException("dataType")
override def nullable: Boolean = throw new UnresolvedException("nullable")
override lazy val resolved = false
override def children: Seq[Expression] = Nil

final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_LITERAL_FUNCTION)

override def toString: String = s"'$name"

override def sql: String = nameParts.map(quoteIfNeeded).mkString(".")

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): Expression = this
}

/**
* Holds the name of an attribute that has yet to be resolved.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2932,9 +2932,9 @@ class AstBuilder extends DataTypeAstBuilder
CurrentUser()
}
} else {
// If the parser is not in ansi mode, we should return `UnresolvedAttribute`, in case there
// are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP` or `CURRENT_TIME`
UnresolvedAttribute.quoted(ctx.name.getText)
// If the parser is not in ansi mode, we should return `UnresolvedLiteralFunction`,
// in case there are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP` or `CURRENT_TIME`
UnresolvedLiteralFunction(Seq(ctx.name.getText))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ object TreePattern extends Enumeration {
val UNRESOLVED_DF_STAR: Value = Value
val UNRESOLVED_FUNCTION: Value = Value
val UNRESOLVED_IDENTIFIER: Value = Value
val UNRESOLVED_LITERAL_FUNCTION: Value = Value
val UNRESOLVED_ORDINAL: Value = Value
val UNRESOLVED_PLAN_ID: Value = Value
val UNRESOLVED_WINDOW_EXPRESSION: Value = Value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1590,4 +1590,79 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase {
Row(1, Map(Row(20, null) -> Row("d", null)), "sales")))
}
}

private def checkConflictSpecialColNameResult(table: String): Unit = {
val result = sql(s"SELECT * FROM $table").collect()
assert(result.length == 1)
assert(!result(0).getBoolean(0))
assert(result(0).get(1) != null)
}

test("Add column with special column name default value conflicting with column name") {
withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> s"$v2Format, ") {
val t = fullTableName("table_name")
// There is a default value that is a special column name 'current_timestamp'.
withTable(t) {
sql(s"CREATE TABLE $t (i boolean) USING $v2Format")
sql(s"ALTER TABLE $t ADD COLUMN s timestamp DEFAULT current_timestamp")
sql(s"INSERT INTO $t(i) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
// There is a default value with special column name 'current_user' but in uppercase.
withTable(t) {
sql(s"CREATE TABLE $t (i boolean) USING $v2Format")
sql(s"ALTER TABLE $t ADD COLUMN s string DEFAULT CURRENT_USER")
sql(s"INSERT INTO $t(i) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
// There is a default value with special column name same as current column name
withTable(t) {
sql(s"CREATE TABLE $t (b boolean) USING $v2Format")
sql(s"ALTER TABLE $t ADD COLUMN current_timestamp timestamp DEFAULT current_timestamp")
sql(s"INSERT INTO $t(b) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
// There is a default value with special column name same as another column name
withTable(t) {
sql(s"CREATE TABLE $t (current_date boolean) USING $v2Format")
sql(s"ALTER TABLE $t ADD COLUMN s date DEFAULT current_date")
sql(s"INSERT INTO $t(current_date) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
}
}

test("Set default value for existing column conflicting with special column names") {
withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> s"$v2Format, ") {
val t = fullTableName("table_name")
// There is a default value that is a special column name 'current_timestamp'.
withTable(t) {
sql(s"CREATE TABLE $t (i boolean, s timestamp) USING $v2Format")
sql(s"ALTER TABLE $t ALTER COLUMN s SET DEFAULT current_timestamp")
sql(s"INSERT INTO $t(i) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
// There is a default value with special column name 'current_user' but in uppercase.
withTable(t) {
sql(s"CREATE TABLE $t (i boolean, s string) USING $v2Format")
sql(s"ALTER TABLE $t ALTER COLUMN s SET DEFAULT CURRENT_USER")
sql(s"INSERT INTO $t(i) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
// There is a default value with special column name same as current column name
withTable(t) {
sql(s"CREATE TABLE $t (b boolean, current_timestamp timestamp) USING $v2Format")
sql(s"ALTER TABLE $t ALTER COLUMN current_timestamp SET DEFAULT current_timestamp")
sql(s"INSERT INTO $t(b) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
// There is a default value with special column name same as another column name
withTable(t) {
sql(s"CREATE TABLE $t (current_date boolean, s date) USING $v2Format")
sql(s"ALTER TABLE $t ALTER COLUMN s SET DEFAULT current_date")
sql(s"INSERT INTO $t(current_date) VALUES(false)")
checkConflictSpecialColNameResult(t)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3870,6 +3870,60 @@ class DataSourceV2SQLSuiteV1Filter
}
}

test("test default value special column name conflicting with real column name") {
val t = "testcat.ns.t"
withTable("t") {
sql(s"""CREATE table $t (
c1 STRING,
current_date DATE DEFAULT CURRENT_DATE,
current_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
current_time time DEFAULT CURRENT_TIME,
current_user STRING DEFAULT CURRENT_USER,
session_user STRING DEFAULT SESSION_USER,
user STRING DEFAULT USER,
current_database STRING DEFAULT CURRENT_DATABASE(),
current_catalog STRING DEFAULT CURRENT_CATALOG())""")
sql(s"INSERT INTO $t (c1) VALUES ('a')")
val result = sql(s"SELECT * FROM $t").collect()
assert(result.length == 1)
assert(result(0).getString(0) == "a")
Seq(1 to 8: _*).foreach(i => assert(result(0).get(i) != null))
}
}

test("test default value special column name nested in function") {
val t = "testcat.ns.t"
withTable("t") {
sql(s"""CREATE table $t (
c1 STRING,
current_date DATE DEFAULT DATE_ADD(current_date, 7))""")
sql(s"INSERT INTO $t (c1) VALUES ('a')")
val result = sql(s"SELECT * FROM $t").collect()
assert(result.length == 1)
assert(result(0).getString(0) == "a")
}
}

test("test default value should not refer to real column") {
val t = "testcat.ns.t"
withTable("t") {
checkError(
exception = intercept[AnalysisException] {
sql(s"""CREATE table $t (
c1 timestamp,
current_timestamp TIMESTAMP DEFAULT c1)""")
},
condition = "INVALID_DEFAULT_VALUE.NOT_CONSTANT",
parameters = Map(
"statement" -> "CREATE TABLE",
"colName" -> "`current_timestamp`",
"defaultValue" -> "c1"
),
sqlState = Some("42623")
)
}
}

private def testNotSupportedV2Command(
sqlCommand: String,
sqlParams: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,33 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase {
Row(1, 42, "hr") :: Row(2, 2, "software") :: Row(3, 42, "hr") :: Nil)
}

test("update with current_timestamp default value using DEFAULT keyword") {
sql(s"""CREATE TABLE $tableNameAsString
| (pk INT NOT NULL, current_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP)""".stripMargin)
append("pk INT NOT NULL, current_timestamp TIMESTAMP",
"""{ "pk": 1, "i": false, "current_timestamp": "2023-01-01 10:00:00" }
|{ "pk": 2, "i": true, "current_timestamp": "2023-01-01 11:00:00" }
|""".stripMargin)

val initialResult = sql(s"SELECT * FROM $tableNameAsString").collect()
assert(initialResult.length == 2)
val initialTimestamp1 = initialResult(0).getTimestamp(1)
val initialTimestamp2 = initialResult(1).getTimestamp(1)

sql(s"UPDATE $tableNameAsString SET current_timestamp = DEFAULT WHERE pk = 1")

val updatedResult = sql(s"SELECT * FROM $tableNameAsString").collect()
assert(updatedResult.length == 2)

val updatedRow = updatedResult.find(_.getInt(0) == 1).get
val unchangedRow = updatedResult.find(_.getInt(0) == 2).get

// The timestamp should be different (newer) after the update for pk=1
assert(updatedRow.getTimestamp(1).getTime > initialTimestamp1.getTime)
// The timestamp should remain unchanged for pk=2
assert(unchangedRow.getTimestamp(1).getTime == initialTimestamp2.getTime)
}

test("update char/varchar columns") {
createTable("pk INT NOT NULL, s STRUCT<n_c: CHAR(3), n_vc: VARCHAR(5)>, dep STRING")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,44 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
sql("insert into t select false, default")
checkAnswer(spark.table("t"), Row(false, 42L))
}
// There is a default value that is a special column name 'current_timestamp'.
withTable("t") {
sql("create table t(i boolean, s timestamp default current_timestamp) using parquet")
sql("insert into t(i) values(false)")
val result = spark.table("t").collect()
assert(result.length == 1)
assert(!result(0).getBoolean(0))
assert(result(0).getTimestamp(1) != null)
}
// There is a default value with special column name 'current_user' but in uppercase.
withTable("t") {
sql("create table t(i boolean, s string default CURRENT_USER) using parquet")
sql("insert into t(i) values(false)")
val result = spark.table("t").collect()
assert(result.length == 1)
assert(!result(0).getBoolean(0))
assert(result(0).getString(1) != null)
}
// There is a default value with special column name same as current column name
withTable("t") {
sql("create table t(current_timestamp timestamp default current_timestamp, b boolean) " +
"using parquet")
sql("insert into t(b) values(false)")
val result = spark.table("t").collect()
assert(result.length == 1)
assert(result(0).getTimestamp(0) != null)
assert(!result(0).getBoolean(1))
}
// There is a default value with special column name same as another column name
withTable("t") {
sql("create table t(current_date boolean, s date default current_date) " +
"using parquet")
sql("insert into t(current_date) values(false)")
val result = spark.table("t").collect()
assert(result.length == 1)
assert(!result(0).getBoolean(0))
assert(result(0).getDate(1) != null)
}
// There is a complex query plan in the SELECT query in the INSERT INTO statement.
withTable("t") {
sql("create table t(i boolean default false, s bigint default 42) using parquet")
Expand Down