Skip to content

Commit d171439

Browse files
author
James Cor
committed
partial fix
1 parent 9688af8 commit d171439

File tree

4 files changed

+57
-33
lines changed

4 files changed

+57
-33
lines changed

enginetest/memory_engine_test.go

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -202,23 +202,18 @@ func TestSingleQueryPrepared(t *testing.T) {
202202

203203
// Convenience test for debugging a single query. Unskip and set to the desired query.
204204
func TestSingleScript(t *testing.T) {
205-
t.Skip()
205+
//t.Skip()
206206
var scripts = []queries.ScriptTest{
207207
{
208-
Name: "AS OF propagates to nested CALLs",
209-
SetUpScript: []string{},
208+
Name: "test script",
209+
SetUpScript: []string{
210+
"CREATE TABLE test (pk INTEGER PRIMARY KEY, name TEXT NOT NULL) COLLATE=utf8mb4_0900_ai_ci;",
211+
"INSERT INTO test VALUES (1, 'aBcDeF');",
212+
},
210213
Assertions: []queries.ScriptTestAssertion{
211214
{
212-
Query: "create procedure create_proc() create table t (i int primary key, j int);",
213-
Expected: []sql.Row{
214-
{types.NewOkResult(0)},
215-
},
216-
},
217-
{
218-
Query: "call create_proc()",
219-
Expected: []sql.Row{
220-
{types.NewOkResult(0)},
221-
},
215+
Query: "select 'abcdef' in (select name from test)",
216+
Expected: []sql.Row{},
222217
},
223218
},
224219
},

sql/cache.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
)
2727

2828
// HashOf returns a hash of the given value to be used as key in a cache.
29-
func HashOf(ctx context.Context, v Row) (uint64, error) {
29+
func HashOf(sch Schema, v Row) (uint64, error) {
3030
hash := digestPool.Get().(*xxhash.Digest)
3131
hash.Reset()
3232
defer digestPool.Put(hash)
@@ -37,10 +37,17 @@ func HashOf(ctx context.Context, v Row) (uint64, error) {
3737
return 0, err
3838
}
3939
}
40-
x, err := UnwrapAny(ctx, x)
41-
if err != nil {
42-
return 0, err
40+
41+
if i < len(sch) {
42+
typ := sch[i].Type
43+
if strType, ok := typ.(StringType); ok {
44+
strType.Convert(nil, )
45+
strType.Collation().WriteWeightString(hash, )
46+
47+
}
48+
continue
4349
}
50+
4451
// TODO: probably much faster to do this with a type switch
4552
// TODO: we don't have the type info necessary to appropriately encode the value of a string with a non-standard
4653
// collation, which means that two strings that differ only in their collations will hash to the same value.

sql/plan/insubquery.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package plan
1616

1717
import (
1818
"fmt"
19+
"github.com/cespare/xxhash/v2"
1920

2021
"github.com/dolthub/go-mysql-server/sql"
2122
"github.com/dolthub/go-mysql-server/sql/expression"
@@ -75,7 +76,7 @@ func (in *InSubquery) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
7576
return nil, sql.ErrInvalidOperandColumns.New(types.NumColumns(typ), types.NumColumns(right.Type()))
7677
}
7778

78-
typ := right.Type()
79+
rTyp := right.Type()
7980

8081
values, err := right.HashMultiple(ctx, row)
8182
if err != nil {
@@ -91,11 +92,24 @@ func (in *InSubquery) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
9192
}
9293

9394
// convert left to right's type
94-
nLeft, _, err := typ.Convert(ctx, left)
95+
nLeft, _, err := rTyp.Convert(ctx, left)
9596
if err != nil {
9697
return false, nil
9798
}
9899

100+
if strTyp, ok := rTyp.(sql.StringType); ok {
101+
weightStr := xxhash.New()
102+
valStr, err := types.ConvertToString(ctx, nLeft, strTyp, nil)
103+
if err != nil {
104+
return nil, err
105+
}
106+
err = strTyp.Collation().WriteWeightString(weightStr, valStr)
107+
if err != nil {
108+
return nil, err
109+
}
110+
nLeft = weightStr
111+
}
112+
99113
key, err := sql.HashOf(ctx, sql.NewRow(nLeft))
100114
if err != nil {
101115
return nil, err
@@ -109,12 +123,12 @@ func (in *InSubquery) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
109123
return false, nil
110124
}
111125

112-
val, _, err = typ.Convert(ctx, val)
126+
val, _, err = rTyp.Convert(ctx, val)
113127
if err != nil {
114128
return false, nil
115129
}
116130

117-
cmp, err := typ.Compare(ctx, left, val)
131+
cmp, err := rTyp.Compare(ctx, left, val)
118132
if err != nil {
119133
return nil, err
120134
}

sql/plan/subquery.go

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ func (m *Max1Row) CollationCoercibility(ctx *sql.Context) (collation sql.Collati
313313
}
314314

315315
// EvalMultiple returns all rows returned by a subquery.
316-
func (s *Subquery) EvalMultiple(ctx *sql.Context, row sql.Row) ([]interface{}, error) {
316+
func (s *Subquery) EvalMultiple(ctx *sql.Context, row sql.Row) (sql.Row, error) {
317317
s.cacheMu.Lock()
318318
cached := s.resultsCached
319319
s.cacheMu.Unlock()
@@ -341,7 +341,7 @@ func (s *Subquery) canCacheResults() bool {
341341
return s.correlated.Empty() && !s.volatile
342342
}
343343

344-
func (s *Subquery) evalMultiple(ctx *sql.Context, row sql.Row) ([]interface{}, error) {
344+
func (s *Subquery) evalMultiple(ctx *sql.Context, row sql.Row) (sql.Row, error) {
345345
// Any source of rows, as well as any node that alters the schema of its children, needs to be wrapped so that its
346346
// result rows are prepended with the scope row.
347347
q, _, err := transform.Node(s.Query, PrependRowInPlan(row, false))
@@ -362,7 +362,7 @@ func (s *Subquery) evalMultiple(ctx *sql.Context, row sql.Row) ([]interface{}, e
362362

363363
// Reduce the result row to the size of the expected schema. This means chopping off the first len(row) columns.
364364
col := len(row)
365-
var result []interface{}
365+
var result sql.Row
366366
for {
367367
row, err := iter.Next(ctx)
368368
if err == io.EOF {
@@ -407,7 +407,7 @@ func (s *Subquery) HashMultiple(ctx *sql.Context, row sql.Row) (sql.KeyValueCach
407407
defer s.cacheMu.Unlock()
408408
if !s.resultsCached || s.hashCache == nil {
409409
hashCache, disposeFn := ctx.Memory.NewHistoryCache()
410-
err = putAllRows(ctx, hashCache, result)
410+
err = putAllRows(ctx, hashCache, s.Query.Schema(), result)
411411
if err != nil {
412412
return nil, err
413413
}
@@ -417,7 +417,11 @@ func (s *Subquery) HashMultiple(ctx *sql.Context, row sql.Row) (sql.KeyValueCach
417417
}
418418

419419
cache := sql.NewMapCache()
420-
return cache, putAllRows(ctx, cache, result)
420+
err = putAllRows(ctx, cache, s.Query.Schema(), result)
421+
if err != nil {
422+
return nil, err
423+
}
424+
return cache, nil
421425
}
422426

423427
// HasResultRow returns whether the subquery has a result set > 0.
@@ -467,21 +471,25 @@ func (s *Subquery) HasResultRow(ctx *sql.Context, row sql.Row) (bool, error) {
467471
// normalizeValue returns a canonical version of a value for use in a sql.KeyValueCache.
468472
// Two values that compare equal should have the same canonical version.
469473
// TODO: Fix https://github.com/dolthub/dolt/issues/9049 by making this function collation-aware
470-
func normalizeForKeyValueCache(ctx *sql.Context, val interface{}) (interface{}, error) {
471-
return sql.UnwrapAny(ctx, val)
474+
func normalizeForKeyValueCache(ctx *sql.Context, typ sql.Type, val interface{}) (interface{}, error) {
475+
val, err := sql.UnwrapAny(ctx, val)
476+
if err != nil {
477+
return nil, err
478+
}
479+
return val, nil
472480
}
473481

474-
func putAllRows(ctx *sql.Context, cache sql.KeyValueCache, vals []interface{}) error {
475-
for _, val := range vals {
476-
val, err := normalizeForKeyValueCache(ctx, val)
482+
func putAllRows(ctx *sql.Context, cache sql.KeyValueCache, sch sql.Schema, vals []interface{}) error {
483+
for i, val := range vals {
484+
normVal, err := normalizeForKeyValueCache(ctx, sch[i].Type, val)
477485
if err != nil {
478486
return err
479487
}
480-
rowKey, err := sql.HashOf(ctx, sql.NewRow(val))
488+
rowKey, err := sql.HashOf(ctx, sql.NewRow(normVal))
481489
if err != nil {
482490
return err
483491
}
484-
err = cache.Put(rowKey, val)
492+
err = cache.Put(rowKey, normVal)
485493
if err != nil {
486494
return err
487495
}

0 commit comments

Comments
 (0)