Skip to content

Commit 63f2fc6

Browse files
author
James Cor
committed
refactor compare logic into types package and fix HashInTuple logic
1 parent 8ea72b0 commit 63f2fc6

File tree

3 files changed

+123
-83
lines changed

3 files changed

+123
-83
lines changed

sql/expression/in.go

Lines changed: 63 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ func NewNotInTuple(left sql.Expression, right sql.Expression) sql.Expression {
154154
// HashInTuple is an expression that checks an expression is inside a list of expressions using a hashmap.
155155
type HashInTuple struct {
156156
in *InTuple
157-
cmp map[uint64]sql.Expression
157+
cmp map[uint64]struct{}
158+
cmpType sql.Type
158159
hasNull bool
159160
}
160161

@@ -169,90 +170,110 @@ func NewHashInTuple(ctx *sql.Context, left, right sql.Expression) (*HashInTuple,
169170
return nil, ErrUnsupportedInOperand.New(right)
170171
}
171172

172-
cmp, hasNull, err := newInMap(ctx, rightTup, left.Type())
173+
cmp, cmpType, hasNull, err := newInMap(ctx, left.Type(), rightTup)
173174
if err != nil {
174175
return nil, err
175176
}
176177

177-
return &HashInTuple{in: NewInTuple(left, right), cmp: cmp, hasNull: hasNull}, nil
178+
return &HashInTuple{
179+
in: NewInTuple(left, right),
180+
cmp: cmp,
181+
cmpType: cmpType,
182+
hasNull: hasNull,
183+
}, nil
178184
}
179185

180186
// newInMap hashes static expressions in the right child Tuple of an InTuple node
181-
func newInMap(ctx *sql.Context, right Tuple, lType sql.Type) (map[uint64]sql.Expression, bool, error) {
187+
func newInMap(ctx *sql.Context, lType sql.Type, right Tuple) (map[uint64]struct{}, sql.Type, bool, error) {
182188
if lType == types.Null {
183-
return nil, true, nil
189+
return nil, nil, true, nil
190+
}
191+
if len(right) == 0 {
192+
return nil, nil, false, nil
184193
}
185194

186-
elements := make(map[uint64]sql.Expression)
187-
hasNull := false
195+
// If left is StringType and ANY of the right is NumberType, then we should use Double Type for comparison
196+
// If left is NumberType and ANY of the right is StringType, then we should use Double Type for comparison
188197
lColumnCount := types.NumColumns(lType)
189-
190-
for _, el := range right {
191-
rType := el.Type().Promote()
198+
lIsNumType := types.IsNumber(lType)
199+
lIsStrType := types.IsText(lType)
200+
var rHasNumType, rHasStrType, rHasNull bool
201+
rVals := make([]any, len(right))
202+
for i, el := range right {
203+
rType := el.Type()
204+
205+
// Nested tuples must have the same number of columns
192206
rColumnCount := types.NumColumns(rType)
193207
if rColumnCount != lColumnCount {
194-
return nil, false, sql.ErrInvalidOperandColumns.New(lColumnCount, rColumnCount)
208+
return nil, nil, false, sql.ErrInvalidOperandColumns.New(lColumnCount, rColumnCount)
195209
}
196210

197-
if rType == types.Null {
198-
hasNull = true
211+
if types.IsNumber(rType) {
212+
rHasNumType = true
213+
} else if types.IsText(rType) {
214+
rHasStrType = true
215+
}
216+
217+
// Null elements are not hashed into the Tuple Map
218+
if types.IsNullType(rType) {
219+
rHasNull = true
199220
continue
200221
}
201-
i, err := el.Eval(ctx, sql.Row{})
222+
v, err := el.Eval(ctx, sql.Row{})
202223
if err != nil {
203-
return nil, hasNull, err
224+
return nil, nil, false, err
204225
}
205-
if i == nil {
206-
hasNull = true
226+
if v == nil {
227+
rHasNull = true
207228
continue
208229
}
209230

210-
var key uint64
211-
if types.IsDecimal(rType) || types.IsFloat(rType) {
212-
key, err = hash.HashOfSimple(ctx, i, rType)
213-
} else {
214-
key, err = hash.HashOfSimple(ctx, i, lType)
215-
}
231+
rVals[i] = v
232+
}
233+
234+
var cmpType sql.Type
235+
if (lIsStrType && rHasNumType) || (lIsNumType && rHasStrType) {
236+
cmpType = types.Float64
237+
} else if types.IsEnum(lType) || types.IsSet(lType) || types.IsText(lType) {
238+
cmpType = lType
239+
} else {
240+
cmpType = types.GetCompareType(lType, right[0].Type())
241+
}
242+
243+
elements := make(map[uint64]struct{})
244+
for _, v := range rVals {
245+
key, err := hash.HashOfSimple(ctx, v, cmpType)
216246
if err != nil {
217-
return nil, false, err
247+
return nil, nil, false, err
218248
}
219-
elements[key] = el
249+
elements[key] = struct{}{}
220250
}
221-
222-
return elements, hasNull, nil
251+
return elements, cmpType, rHasNull, nil
223252
}
224253

225254
// Eval implements the Expression interface.
226255
func (hit *HashInTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
227-
leftElems := types.NumColumns(hit.in.Left().Type().Promote())
228-
229256
leftVal, err := hit.in.Left().Eval(ctx, row)
230257
if err != nil {
231258
return nil, err
232259
}
233-
234260
if leftVal == nil {
235261
return nil, nil
236262
}
237263

238-
key, err := hash.HashOfSimple(ctx, leftVal, hit.in.Left().Type())
264+
// TODO: this needs to pick the same type as right... but there are multiple possibilities??
265+
key, err := hash.HashOfSimple(ctx, leftVal, hit.cmpType)
239266
if err != nil {
240267
return nil, err
241268
}
242269

243-
right, ok := hit.cmp[key]
244-
if !ok {
245-
if hit.hasNull {
246-
return nil, nil
247-
}
248-
return false, nil
270+
if _, ok := hit.cmp[key]; ok {
271+
return true, nil
249272
}
250-
251-
if types.NumColumns(right.Type().Promote()) != leftElems {
252-
return nil, sql.ErrInvalidOperandColumns.New(leftElems, types.NumColumns(right.Type().Promote()))
273+
if hit.hasNull {
274+
return nil, nil
253275
}
254-
255-
return true, nil
276+
return false, nil
256277
}
257278

258279
func (hit *HashInTuple) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {

sql/plan/hash_lookup.go

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import (
3434
// simply delegates to the child.
3535
func NewHashLookup(n sql.Node, rightEntryKey sql.Expression, leftProbeKey sql.Expression, joinType JoinType) *HashLookup {
3636
leftKeySch := hash.ExprsToSchema(leftProbeKey)
37-
compareType := GetCompareType(leftProbeKey.Type(), rightEntryKey.Type())
37+
compareType := types.GetCompareType(leftProbeKey.Type(), rightEntryKey.Type())
3838
return &HashLookup{
3939
UnaryNode: UnaryNode{n},
4040
RightEntryKey: rightEntryKey,
@@ -61,46 +61,6 @@ var _ sql.Node = (*HashLookup)(nil)
6161
var _ sql.Expressioner = (*HashLookup)(nil)
6262
var _ sql.CollationCoercible = (*HashLookup)(nil)
6363

64-
// GetCompareType returns the type to use when comparing values of types left and right.
65-
func GetCompareType(left, right sql.Type) sql.Type {
66-
// TODO: much of this logic is very similar to castLeftAndRight() from sql/expression/comparison.go
67-
// consider consolidating
68-
if left.Equals(right) {
69-
return left
70-
}
71-
if types.IsTuple(left) && types.IsTuple(right) {
72-
return left
73-
}
74-
if types.IsTime(left) || types.IsTime(right) {
75-
return types.DatetimeMaxPrecision
76-
}
77-
if types.IsJSON(left) || types.IsJSON(right) {
78-
return types.JSON
79-
}
80-
if types.IsBinaryType(left) || types.IsBinaryType(right) {
81-
return types.LongBlob
82-
}
83-
if types.IsNumber(left) || types.IsNumber(right) {
84-
if types.IsDecimal(left) {
85-
return left
86-
}
87-
if types.IsDecimal(right) {
88-
return right
89-
}
90-
if types.IsFloat(left) || types.IsFloat(right) {
91-
return types.Float64
92-
}
93-
if types.IsSigned(left) && types.IsSigned(right) {
94-
return types.Int64
95-
}
96-
if types.IsUnsigned(left) && types.IsUnsigned(right) {
97-
return types.Uint64
98-
}
99-
return types.Float64
100-
}
101-
return types.LongText
102-
}
103-
10464
func (n *HashLookup) Expressions() []sql.Expression {
10565
return []sql.Expression{n.RightEntryKey, n.LeftProbeKey}
10666
}

sql/types/utils.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package types
16+
17+
import (
18+
"github.com/dolthub/go-mysql-server/sql"
19+
)
20+
21+
// GetCompareType returns the type to use when comparing values of types left and right.
22+
func GetCompareType(left, right sql.Type) sql.Type {
23+
// TODO: much of this logic is very similar to castLeftAndRight() from sql/expression/comparison.go
24+
// consider consolidating
25+
if left.Equals(right) {
26+
return left
27+
}
28+
if IsTuple(left) && IsTuple(right) {
29+
return left
30+
}
31+
if IsTime(left) || IsTime(right) {
32+
return DatetimeMaxPrecision
33+
}
34+
if IsJSON(left) || IsJSON(right) {
35+
return JSON
36+
}
37+
if IsBinaryType(left) || IsBinaryType(right) {
38+
return LongBlob
39+
}
40+
if IsNumber(left) || IsNumber(right) {
41+
if IsDecimal(left) {
42+
return left
43+
}
44+
if IsDecimal(right) {
45+
return right
46+
}
47+
if IsFloat(left) || IsFloat(right) {
48+
return Float64
49+
}
50+
if IsSigned(left) && IsSigned(right) {
51+
return Int64
52+
}
53+
if IsUnsigned(left) && IsUnsigned(right) {
54+
return Uint64
55+
}
56+
return Float64
57+
}
58+
return LongText
59+
}

0 commit comments

Comments
 (0)