diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index f8883ba7ec..540016559c 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -8477,6 +8477,56 @@ where }, }, }, + { + // This is a script test here because every table in the harness setup data is in all lowercase + Name: "case insensitive update with insubqueries and update joins", + Dialect: "mysql", + SetUpScript: []string{ + "create table MiXeDcAsE (i int primary key, j int)", + "insert into mixedcase values (1, 1);", + "insert into mixedcase values (2, 2);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "update mixedcase set j = 999 where i in (select 1)", + Expected: []sql.Row{ + {types.OkResult{ + RowsAffected: 1, + Info: plan.UpdateInfo{ + Matched: 1, + Updated: 1, + }, + }}, + }, + }, + { + Query: "select * from mixedcase;", + Expected: []sql.Row{ + {1, 999}, + {2, 2}, + }, + }, + { + Query: " with cte(x) as (select 2) update mixedcase set j = 999 where i in (select x from cte)", + Expected: []sql.Row{ + {types.OkResult{ + RowsAffected: 1, + Info: plan.UpdateInfo{ + Matched: 1, + Updated: 1, + }, + }}, + }, + }, + { + Query: "select * from mixedcase;", + Expected: []sql.Row{ + {1, 999}, + {2, 999}, + }, + }, + }, + }, } var SpatialScriptTests = []ScriptTest{ diff --git a/sql/plan/update_join.go b/sql/plan/update_join.go index f2c6e8ff30..814e953a26 100644 --- a/sql/plan/update_join.go +++ b/sql/plan/update_join.go @@ -15,6 +15,8 @@ package plan import ( + "strings" + "github.com/dolthub/go-mysql-server/sql" ) @@ -232,6 +234,7 @@ func SplitRowIntoTableRowMap(row sql.Row, joinSchema sql.Schema) map[string]sql. } } + currentTable = strings.ToLower(currentTable) ret[currentTable] = currentRow return ret diff --git a/sql/rowexec/update.go b/sql/rowexec/update.go index 91ad67ddb5..2c4cf4eff1 100644 --- a/sql/rowexec/update.go +++ b/sql/rowexec/update.go @@ -17,6 +17,7 @@ package rowexec import ( "errors" "fmt" + "strings" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/plan" @@ -230,7 +231,7 @@ func (u *updateJoinIter) Next(ctx *sql.Context) (sql.Row, error) { tableToNewRowMap := plan.SplitRowIntoTableRowMap(newJoinRow, u.joinSchema) for tableName, _ := range u.updaters { - oldTableRow := tableToOldRowMap[tableName] + oldTableRow := tableToOldRowMap[strings.ToLower(tableName)] // Handle the case of row being ignored due to it not being valid in the join row. if isRightOrLeftJoin(u.joinNode) { @@ -388,15 +389,14 @@ func recreateRowFromMap(rowMap map[string]sql.Row, joinSchema sql.Schema) sql.Ro return ret } - currentTable := joinSchema[0].Source + currentTable := strings.ToLower(joinSchema[0].Source) ret = append(ret, rowMap[currentTable]...) for i := 1; i < len(joinSchema); i++ { - c := joinSchema[i] - - if c.Source != currentTable { - ret = append(ret, rowMap[c.Source]...) - currentTable = c.Source + newTable := strings.ToLower(joinSchema[i].Source) + if !strings.EqualFold(newTable, currentTable) { + ret = append(ret, rowMap[newTable]...) + currentTable = newTable } }