Skip to content

Commit 64442f6

Browse files
authored
Merge pull request #3191 from dolthub/angela/hashjoin
Use Decimal.String() for key in hashjoin lookups
2 parents d34f9de + cac1486 commit 64442f6

File tree

5 files changed

+150
-108
lines changed

5 files changed

+150
-108
lines changed

enginetest/join_op_tests.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,6 +2046,44 @@ WHERE
20462046
},
20472047
},
20482048
},
2049+
{
2050+
name: "joining on decimals",
2051+
setup: [][]string{
2052+
{
2053+
"create table t1(c0 decimal(6,3))",
2054+
"create table t2(c0 decimal(5,2))",
2055+
"insert into t1 values (10.000),(20.505),(30.000)",
2056+
"insert into t2 values (20.5), (25.0), (30.0)",
2057+
},
2058+
},
2059+
tests: []JoinOpTests{
2060+
{
2061+
Query: "select * from t1 join t2 on t1.c0 = t2.c0",
2062+
Expected: []sql.Row{{"30.000", "30.00"}},
2063+
},
2064+
},
2065+
},
2066+
{
2067+
// https://github.com/dolthub/dolt/issues/9777
2068+
name: "join with % condition",
2069+
setup: [][]string{
2070+
{
2071+
"create table t1(c0 int)",
2072+
"create table t2(c0 int)",
2073+
"insert into t1 values (1),(2)",
2074+
"insert into t2 values (3),(4)",
2075+
},
2076+
},
2077+
tests: []JoinOpTests{
2078+
{
2079+
Query: "select * from t1 join t2 on (t1.c0 % 2) = (t2.c0 % 2)",
2080+
Expected: []sql.Row{
2081+
{1, 3},
2082+
{2, 4},
2083+
},
2084+
},
2085+
},
2086+
},
20492087
}
20502088

20512089
var rangeJoinOpTests = []JoinOpTests{

sql/expression/in.go

Lines changed: 7 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ package expression
1616

1717
import (
1818
"fmt"
19-
"strconv"
2019

2120
"github.com/dolthub/go-mysql-server/sql"
21+
"github.com/dolthub/go-mysql-server/sql/hash"
2222
"github.com/dolthub/go-mysql-server/sql/types"
2323
)
2424

@@ -106,11 +106,11 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
106106
elType := el.Type()
107107
if types.IsDecimal(elType) || types.IsFloat(elType) {
108108
rtyp := el.Type().Promote()
109-
left, err := convertOrTruncate(ctx, left, rtyp)
109+
left, err := types.ConvertOrTruncate(ctx, left, rtyp)
110110
if err != nil {
111111
return nil, err
112112
}
113-
right, err := convertOrTruncate(ctx, originalRight, rtyp)
113+
right, err := types.ConvertOrTruncate(ctx, originalRight, rtyp)
114114
if err != nil {
115115
return nil, err
116116
}
@@ -119,7 +119,7 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
119119
return nil, err
120120
}
121121
} else {
122-
right, err := convertOrTruncate(ctx, originalRight, typ)
122+
right, err := types.ConvertOrTruncate(ctx, originalRight, typ)
123123
if err != nil {
124124
return nil, err
125125
}
@@ -233,9 +233,9 @@ func newInMap(ctx *sql.Context, right Tuple, lType sql.Type) (map[uint64]sql.Exp
233233

234234
var key uint64
235235
if types.IsDecimal(rType) || types.IsFloat(rType) {
236-
key, err = hashOfSimple(ctx, i, rType)
236+
key, err = hash.HashOfSimple(ctx, i, rType)
237237
} else {
238-
key, err = hashOfSimple(ctx, i, lType)
238+
key, err = hash.HashOfSimple(ctx, i, lType)
239239
}
240240
if err != nil {
241241
return nil, false, err
@@ -246,66 +246,6 @@ func newInMap(ctx *sql.Context, right Tuple, lType sql.Type) (map[uint64]sql.Exp
246246
return elements, hasNull, nil
247247
}
248248

249-
func hashOfSimple(ctx *sql.Context, i interface{}, t sql.Type) (uint64, error) {
250-
if i == nil {
251-
return 0, nil
252-
}
253-
254-
var str string
255-
coll := sql.Collation_Default
256-
if types.IsTuple(t) {
257-
tup := i.([]interface{})
258-
tupType := t.(types.TupleType)
259-
hashes := make([]uint64, len(tup))
260-
for idx, v := range tup {
261-
h, err := hashOfSimple(ctx, v, tupType[idx])
262-
if err != nil {
263-
return 0, err
264-
}
265-
hashes[idx] = h
266-
}
267-
str = fmt.Sprintf("%v", hashes)
268-
} else if types.IsTextOnly(t) {
269-
coll = t.(sql.StringType).Collation()
270-
if s, ok := i.(string); ok {
271-
str = s
272-
} else {
273-
converted, err := convertOrTruncate(ctx, i, t)
274-
if err != nil {
275-
return 0, err
276-
}
277-
str, _, err = sql.Unwrap[string](ctx, converted)
278-
if err != nil {
279-
return 0, err
280-
}
281-
}
282-
} else {
283-
x, err := convertOrTruncate(ctx, i, t.Promote())
284-
if err != nil {
285-
return 0, err
286-
}
287-
288-
// Remove trailing 0s from floats
289-
switch v := x.(type) {
290-
case float32:
291-
str = strconv.FormatFloat(float64(v), 'f', -1, 32)
292-
if str == "-0" {
293-
str = "0"
294-
}
295-
case float64:
296-
str = strconv.FormatFloat(v, 'f', -1, 64)
297-
if str == "-0" {
298-
str = "0"
299-
}
300-
default:
301-
str = fmt.Sprintf("%v", v)
302-
}
303-
}
304-
305-
// Collated strings that are equivalent may have different runes, so we must make them hash to the same value
306-
return coll.HashToUint(str)
307-
}
308-
309249
// Eval implements the Expression interface.
310250
func (hit *HashInTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
311251
leftElems := types.NumColumns(hit.in.Left().Type().Promote())
@@ -319,7 +259,7 @@ func (hit *HashInTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
319259
return nil, nil
320260
}
321261

322-
key, err := hashOfSimple(ctx, leftVal, hit.in.Left().Type())
262+
key, err := hash.HashOfSimple(ctx, leftVal, hit.in.Left().Type())
323263
if err != nil {
324264
return nil, err
325265
}
@@ -339,43 +279,6 @@ func (hit *HashInTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
339279
return true, nil
340280
}
341281

342-
// convertOrTruncate converts the value |i| to type |t| and returns the converted value; if the value does not convert
343-
// cleanly and the type is automatically coerced (i.e. string and numeric types), then a warning is logged and the
344-
// value is truncated to the Zero value for type |t|. If the value does not convert and the type is not automatically
345-
// coerced, then an error is returned.
346-
func convertOrTruncate(ctx *sql.Context, i interface{}, t sql.Type) (interface{}, error) {
347-
converted, _, err := t.Convert(ctx, i)
348-
if err == nil {
349-
return converted, nil
350-
}
351-
352-
// If a value can't be converted to an enum or set type, truncate it to a value that is guaranteed
353-
// to not match any enum value.
354-
if types.IsEnum(t) || types.IsSet(t) {
355-
return nil, nil
356-
}
357-
358-
// Values for numeric and string types are automatically coerced. For all other types, if they
359-
// don't convert cleanly, it's an error.
360-
if err != nil && !(types.IsNumber(t) || types.IsTextOnly(t)) {
361-
return nil, err
362-
}
363-
364-
// For numeric and string types, if the value can't be cleanly converted, truncate to the zero value for
365-
// the type and log a warning in the session.
366-
warning := sql.Warning{
367-
Level: "Warning",
368-
Message: fmt.Sprintf("Truncated incorrect %s value: %v", t.String(), i),
369-
Code: 1292,
370-
}
371-
372-
if ctx != nil && ctx.Session != nil {
373-
ctx.Session.Warn(&warning)
374-
}
375-
376-
return t.Zero(), nil
377-
}
378-
379282
func (hit *HashInTuple) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
380283
return hit.in.CollationCoercibility(ctx)
381284
}

sql/hash/hash.go

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package hash
1616

1717
import (
1818
"fmt"
19+
"strconv"
1920
"sync"
2021

2122
"github.com/cespare/xxhash/v2"
@@ -41,7 +42,7 @@ func ExprsToSchema(exprs ...sql.Expression) sql.Schema {
4142
return sch
4243
}
4344

44-
// HashOf returns a hash of the given value to be used as key in a cache.
45+
// HashOf returns a hash of the given row to be used as key in a cache.
4546
func HashOf(ctx *sql.Context, sch sql.Schema, row sql.Row) (uint64, error) {
4647
hash := digestPool.Get().(*xxhash.Digest)
4748
hash.Reset()
@@ -97,3 +98,64 @@ func HashOf(ctx *sql.Context, sch sql.Schema, row sql.Row) (uint64, error) {
9798
}
9899
return hash.Sum64(), nil
99100
}
101+
102+
// HashOfSimple returns a hash for a single interface value
103+
func HashOfSimple(ctx *sql.Context, i interface{}, t sql.Type) (uint64, error) {
104+
if i == nil {
105+
return 0, nil
106+
}
107+
108+
var str string
109+
coll := sql.Collation_Default
110+
if types.IsTuple(t) {
111+
tup := i.([]interface{})
112+
tupType := t.(types.TupleType)
113+
hashes := make([]uint64, len(tup))
114+
for idx, v := range tup {
115+
h, err := HashOfSimple(ctx, v, tupType[idx])
116+
if err != nil {
117+
return 0, err
118+
}
119+
hashes[idx] = h
120+
}
121+
str = fmt.Sprintf("%v", hashes)
122+
} else if types.IsTextOnly(t) {
123+
coll = t.(sql.StringType).Collation()
124+
if s, ok := i.(string); ok {
125+
str = s
126+
} else {
127+
converted, err := types.ConvertOrTruncate(ctx, i, t)
128+
if err != nil {
129+
return 0, err
130+
}
131+
str, _, err = sql.Unwrap[string](ctx, converted)
132+
if err != nil {
133+
return 0, err
134+
}
135+
}
136+
} else {
137+
x, err := types.ConvertOrTruncate(ctx, i, t.Promote())
138+
if err != nil {
139+
return 0, err
140+
}
141+
142+
// Remove trailing 0s from floats
143+
switch v := x.(type) {
144+
case float32:
145+
str = strconv.FormatFloat(float64(v), 'f', -1, 32)
146+
if str == "-0" {
147+
str = "0"
148+
}
149+
case float64:
150+
str = strconv.FormatFloat(v, 'f', -1, 64)
151+
if str == "-0" {
152+
str = "0"
153+
}
154+
default:
155+
str = fmt.Sprintf("%v", v)
156+
}
157+
}
158+
159+
// Collated strings that are equivalent may have different runes, so we must make them hash to the same value
160+
return coll.HashToUint(str)
161+
}

sql/plan/hash_lookup.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ func (n *HashLookup) GetHashKey(ctx *sql.Context, e sql.Expression, row sql.Row)
122122
if err != nil {
123123
return nil, err
124124
}
125-
key, _, err = n.LeftProbeKey.Type().Convert(ctx, key)
125+
typ := n.LeftProbeKey.Type()
126+
key, _, err = typ.Convert(ctx, key)
126127
if types.ErrValueNotNil.Is(err) {
127128
// The LHS expression was NullType. This is allowed.
128129
return nil, nil
@@ -135,9 +136,10 @@ func (n *HashLookup) GetHashKey(ctx *sql.Context, e sql.Expression, row sql.Row)
135136
}
136137
// byte slices are not hashable
137138
if k, ok := key.([]byte); ok {
138-
key = string(k)
139+
return string(k), nil
139140
}
140-
return key, nil
141+
142+
return hash.HashOfSimple(ctx, key, typ)
141143
}
142144

143145
func (n *HashLookup) Dispose() {

sql/types/conversion.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,3 +762,40 @@ func TypeAwareConversion(ctx *sql.Context, val interface{}, originalType sql.Typ
762762
}
763763
return convertedType.Convert(ctx, val)
764764
}
765+
766+
// ConvertOrTruncate converts the value |i| to type |t| and returns the converted value; if the value does not convert
767+
// cleanly and the type is automatically coerced (i.e. string and numeric types), then a warning is logged and the
768+
// value is truncated to the Zero value for type |t|. If the value does not convert and the type is not automatically
769+
// coerced, then an error is returned.
770+
func ConvertOrTruncate(ctx *sql.Context, i interface{}, t sql.Type) (interface{}, error) {
771+
converted, _, err := t.Convert(ctx, i)
772+
if err == nil {
773+
return converted, nil
774+
}
775+
776+
// If a value can't be converted to an enum or set type, truncate it to a value that is guaranteed
777+
// to not match any enum value.
778+
if IsEnum(t) || IsSet(t) {
779+
return nil, nil
780+
}
781+
782+
// Values for numeric and string types are automatically coerced. For all other types, if they
783+
// don't convert cleanly, it's an error.
784+
if err != nil && !(IsNumber(t) || IsTextOnly(t)) {
785+
return nil, err
786+
}
787+
788+
// For numeric and string types, if the value can't be cleanly converted, truncate to the zero value for
789+
// the type and log a warning in the session.
790+
warning := sql.Warning{
791+
Level: "Warning",
792+
Message: fmt.Sprintf("Truncated incorrect %s value: %v", t.String(), i),
793+
Code: 1292,
794+
}
795+
796+
if ctx != nil && ctx.Session != nil {
797+
ctx.Session.Warn(&warning)
798+
}
799+
800+
return t.Zero(), nil
801+
}

0 commit comments

Comments
 (0)