Skip to content

Commit 46cdc10

Browse files
authored
Merge branch 'main' into angela/emptyjoin
2 parents fd14e9a + 64442f6 commit 46cdc10

File tree

10 files changed

+215
-127
lines changed

10 files changed

+215
-127
lines changed

enginetest/join_op_tests.go

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,16 +2047,54 @@ WHERE
20472047
},
20482048
},
20492049
{
2050-
// https://github.com/dolthub/dolt/issues/9782
2051-
name: "joining with subquery on empty table",
2050+
name: "joining on decimals",
20522051
setup: [][]string{
20532052
{
2054-
"CREATE TABLE t(c INT);",
2055-
"INSERT INTO t VALUES (1);",
2053+
"create table t1(c0 decimal(6,3))",
2054+
"create table t2(c0 decimal(5,2))",
2055+
"insert into t1 values (10.000),(20.505),(30.000)",
2056+
"insert into t2 values (20.5), (25.0), (30.0)",
2057+
},
2058+
},
2059+
tests: []JoinOpTests{
2060+
{
2061+
Query: "select * from t1 join t2 on t1.c0 = t2.c0",
2062+
Expected: []sql.Row{{"30.000", "30.00"}},
2063+
},
2064+
},
2065+
},
2066+
{
2067+
// https://github.com/dolthub/dolt/issues/9777
2068+
name: "join with % condition",
2069+
setup: [][]string{
2070+
{
2071+
"create table t1(c0 int)",
2072+
"create table t2(c0 int)",
2073+
"insert into t1 values (1),(2)",
2074+
"insert into t2 values (3),(4)",
20562075
},
20572076
},
20582077
tests: []JoinOpTests{
20592078
{
2079+
Query: "select * from t1 join t2 on (t1.c0 % 2) = (t2.c0 % 2)",
2080+
Expected: []sql.Row{
2081+
{1, 3},
2082+
{2, 4},
2083+
},
2084+
},
2085+
},
2086+
},
2087+
{
2088+
// https://github.com/dolthub/dolt/issues/9782
2089+
name: "joining with subquery on empty table",
2090+
setup: [][]string{
2091+
{
2092+
"CREATE TABLE t(c INT);",
2093+
"INSERT INTO t VALUES (1);",
2094+
},
2095+
},
2096+
tests: []JoinOpTests{
2097+
{
20602098
Query: "SELECT t.c FROM t LEFT JOIN (SELECT t.c FROM t WHERE FALSE) AS subq ON TRUE;",
20612099
Expected: []sql.Row{{1}},
20622100
},
@@ -2071,9 +2109,9 @@ WHERE
20712109
{
20722110
Query: "SELECT t.c FROM (SELECT t.c FROM t WHERE FALSE) AS subq NATURAL RIGHT JOIN t;",
20732111
Expected: []sql.Row{{1}},
2074-
},
2075-
},
2076-
},
2112+
},
2113+
},
2114+
},
20772115
}
20782116

20792117
var rangeJoinOpTests = []JoinOpTests{

enginetest/queries/update_queries.go

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -597,9 +597,9 @@ var UpdateScriptTests = []ScriptTest{
597597
},
598598
},
599599
{
600-
Dialect: "mysql",
601600
// https://github.com/dolthub/dolt/issues/9403
602-
Name: "UPDATE join – multiple tables with same column names with triggers",
601+
Dialect: "mysql",
602+
Name: "UPDATE join – multiple tables with same column names with triggers",
603603
SetUpScript: []string{
604604
"create table customers (id int primary key, name text, tier text)",
605605
"create table orders (id int primary key, customer_id int, status text)",
@@ -632,8 +632,54 @@ var UpdateScriptTests = []ScriptTest{
632632
},
633633
},
634634
{
635-
Name: "UPDATE with subquery in keyless tables",
635+
Dialect: "mysql",
636+
Name: "UPDATE join - conflicting alias in Subquery Alias",
637+
SetUpScript: []string{
638+
"create table parent (id int primary key);",
639+
"insert into parent values (1), (2), (3);",
640+
"create table child (id int primary key, pid int, foreign key (pid) references parent(id), oid int);",
641+
"insert into child values (1, 1, 0), (2, 2, 0), (3, 3, 0);",
642+
},
643+
Assertions: []ScriptTestAssertion{
644+
{
645+
Query: `
646+
update child t1
647+
left join
648+
(
649+
select
650+
t1.id
651+
from
652+
child t1
653+
) sqa
654+
on
655+
t1.id = sqa.id
656+
join
657+
child t2
658+
set
659+
t1.oid = t2.pid;`,
660+
Expected: []sql.Row{
661+
{types.OkResult{
662+
RowsAffected: 3,
663+
Info: plan.UpdateInfo{
664+
Matched: 3,
665+
Updated: 3,
666+
},
667+
}},
668+
},
669+
},
670+
{
671+
Query: "select * from child;",
672+
Expected: []sql.Row{
673+
{1, 1, 1},
674+
{2, 2, 1},
675+
{3, 3, 1},
676+
},
677+
},
678+
},
679+
},
680+
{
636681
// https://github.com/dolthub/dolt/issues/9334
682+
Name: "UPDATE with subquery in keyless tables",
637683
SetUpScript: []string{
638684
"create table t (i int)",
639685
"insert into t values (1)",

sql/analyzer/apply_foreign_keys.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,7 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
127127
fkHandlerMap := make(map[string]sql.Node, len(updateTargets))
128128
for tableName, updateTarget := range updateTargets {
129129
fkHandlerMap[tableName] = updateTarget
130-
fkHandler, err :=
131-
getForeignKeyHandlerFromUpdateTarget(ctx, a, updateTarget, cache, fkChain)
130+
fkHandler, err := getForeignKeyHandlerFromUpdateTarget(ctx, a, updateTarget, cache, fkChain)
132131
if err != nil {
133132
return nil, transform.SameTree, err
134133
}

sql/analyzer/tables.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,28 +98,28 @@ func getResolvedTable(node sql.Node) *plan.ResolvedTable {
9898
}
9999

100100
// getTablesByName takes a node and returns all found resolved tables in a map.
101+
// This function will not look inside sql.OpaqueNodes (like plan.SubqueryAlias).
101102
func getTablesByName(node sql.Node) map[string]*plan.ResolvedTable {
102103
ret := make(map[string]*plan.ResolvedTable)
103-
104-
transform.Inspect(node, func(node sql.Node) bool {
104+
// TODO: We should change transform.Inspect to not walk the children of sql.OpaqueNodes (like transform.Node)
105+
// and add a transform.InspectWithOpaque that does.
106+
// Using transform.Node here achieves the same result without a large refactor.
107+
transform.Node(node, func(node sql.Node) (sql.Node, transform.TreeIdentity, error) {
105108
switch n := node.(type) {
106109
case *plan.ResolvedTable:
107110
ret[strings.ToLower(n.Table.Name())] = n
108111
case *plan.IndexedTableAccess:
109112
rt, ok := n.TableNode.(*plan.ResolvedTable)
110113
if ok {
111114
ret[strings.ToLower(rt.Name())] = rt
112-
return false
113115
}
114116
case *plan.TableAlias:
115117
rt := getResolvedTable(n)
116118
if rt != nil {
117119
ret[n.Name()] = rt
118120
}
119-
default:
120121
}
121-
return true
122+
return nil, transform.SameTree, nil
122123
})
123-
124124
return ret
125125
}

sql/expression/in.go

Lines changed: 7 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ package expression
1616

1717
import (
1818
"fmt"
19-
"strconv"
2019

2120
"github.com/dolthub/go-mysql-server/sql"
21+
"github.com/dolthub/go-mysql-server/sql/hash"
2222
"github.com/dolthub/go-mysql-server/sql/types"
2323
)
2424

@@ -106,11 +106,11 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
106106
elType := el.Type()
107107
if types.IsDecimal(elType) || types.IsFloat(elType) {
108108
rtyp := el.Type().Promote()
109-
left, err := convertOrTruncate(ctx, left, rtyp)
109+
left, err := types.ConvertOrTruncate(ctx, left, rtyp)
110110
if err != nil {
111111
return nil, err
112112
}
113-
right, err := convertOrTruncate(ctx, originalRight, rtyp)
113+
right, err := types.ConvertOrTruncate(ctx, originalRight, rtyp)
114114
if err != nil {
115115
return nil, err
116116
}
@@ -119,7 +119,7 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
119119
return nil, err
120120
}
121121
} else {
122-
right, err := convertOrTruncate(ctx, originalRight, typ)
122+
right, err := types.ConvertOrTruncate(ctx, originalRight, typ)
123123
if err != nil {
124124
return nil, err
125125
}
@@ -233,9 +233,9 @@ func newInMap(ctx *sql.Context, right Tuple, lType sql.Type) (map[uint64]sql.Exp
233233

234234
var key uint64
235235
if types.IsDecimal(rType) || types.IsFloat(rType) {
236-
key, err = hashOfSimple(ctx, i, rType)
236+
key, err = hash.HashOfSimple(ctx, i, rType)
237237
} else {
238-
key, err = hashOfSimple(ctx, i, lType)
238+
key, err = hash.HashOfSimple(ctx, i, lType)
239239
}
240240
if err != nil {
241241
return nil, false, err
@@ -246,66 +246,6 @@ func newInMap(ctx *sql.Context, right Tuple, lType sql.Type) (map[uint64]sql.Exp
246246
return elements, hasNull, nil
247247
}
248248

249-
func hashOfSimple(ctx *sql.Context, i interface{}, t sql.Type) (uint64, error) {
250-
if i == nil {
251-
return 0, nil
252-
}
253-
254-
var str string
255-
coll := sql.Collation_Default
256-
if types.IsTuple(t) {
257-
tup := i.([]interface{})
258-
tupType := t.(types.TupleType)
259-
hashes := make([]uint64, len(tup))
260-
for idx, v := range tup {
261-
h, err := hashOfSimple(ctx, v, tupType[idx])
262-
if err != nil {
263-
return 0, err
264-
}
265-
hashes[idx] = h
266-
}
267-
str = fmt.Sprintf("%v", hashes)
268-
} else if types.IsTextOnly(t) {
269-
coll = t.(sql.StringType).Collation()
270-
if s, ok := i.(string); ok {
271-
str = s
272-
} else {
273-
converted, err := convertOrTruncate(ctx, i, t)
274-
if err != nil {
275-
return 0, err
276-
}
277-
str, _, err = sql.Unwrap[string](ctx, converted)
278-
if err != nil {
279-
return 0, err
280-
}
281-
}
282-
} else {
283-
x, err := convertOrTruncate(ctx, i, t.Promote())
284-
if err != nil {
285-
return 0, err
286-
}
287-
288-
// Remove trailing 0s from floats
289-
switch v := x.(type) {
290-
case float32:
291-
str = strconv.FormatFloat(float64(v), 'f', -1, 32)
292-
if str == "-0" {
293-
str = "0"
294-
}
295-
case float64:
296-
str = strconv.FormatFloat(v, 'f', -1, 64)
297-
if str == "-0" {
298-
str = "0"
299-
}
300-
default:
301-
str = fmt.Sprintf("%v", v)
302-
}
303-
}
304-
305-
// Collated strings that are equivalent may have different runes, so we must make them hash to the same value
306-
return coll.HashToUint(str)
307-
}
308-
309249
// Eval implements the Expression interface.
310250
func (hit *HashInTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
311251
leftElems := types.NumColumns(hit.in.Left().Type().Promote())
@@ -319,7 +259,7 @@ func (hit *HashInTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
319259
return nil, nil
320260
}
321261

322-
key, err := hashOfSimple(ctx, leftVal, hit.in.Left().Type())
262+
key, err := hash.HashOfSimple(ctx, leftVal, hit.in.Left().Type())
323263
if err != nil {
324264
return nil, err
325265
}
@@ -339,43 +279,6 @@ func (hit *HashInTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
339279
return true, nil
340280
}
341281

342-
// convertOrTruncate converts the value |i| to type |t| and returns the converted value; if the value does not convert
343-
// cleanly and the type is automatically coerced (i.e. string and numeric types), then a warning is logged and the
344-
// value is truncated to the Zero value for type |t|. If the value does not convert and the type is not automatically
345-
// coerced, then an error is returned.
346-
func convertOrTruncate(ctx *sql.Context, i interface{}, t sql.Type) (interface{}, error) {
347-
converted, _, err := t.Convert(ctx, i)
348-
if err == nil {
349-
return converted, nil
350-
}
351-
352-
// If a value can't be converted to an enum or set type, truncate it to a value that is guaranteed
353-
// to not match any enum value.
354-
if types.IsEnum(t) || types.IsSet(t) {
355-
return nil, nil
356-
}
357-
358-
// Values for numeric and string types are automatically coerced. For all other types, if they
359-
// don't convert cleanly, it's an error.
360-
if err != nil && !(types.IsNumber(t) || types.IsTextOnly(t)) {
361-
return nil, err
362-
}
363-
364-
// For numeric and string types, if the value can't be cleanly converted, truncate to the zero value for
365-
// the type and log a warning in the session.
366-
warning := sql.Warning{
367-
Level: "Warning",
368-
Message: fmt.Sprintf("Truncated incorrect %s value: %v", t.String(), i),
369-
Code: 1292,
370-
}
371-
372-
if ctx != nil && ctx.Session != nil {
373-
ctx.Session.Warn(&warning)
374-
}
375-
376-
return t.Zero(), nil
377-
}
378-
379282
func (hit *HashInTuple) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
380283
return hit.in.CollationCoercibility(ctx)
381284
}

0 commit comments

Comments
 (0)