Skip to content
Merged
242 changes: 124 additions & 118 deletions sql/analyzer/costed_index_scan.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion sql/analyzer/resolve_column_defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ func normalizeDefault(ctx *sql.Context, colDefault *sql.ColumnDefaultValue) (sql
// serialization before being passed to the integrator for table creation
func skipDefaultNormalizationForType(typ sql.Type) bool {
// Extended types handle their own serialization concerns
if _, ok := typ.(types.ExtendedType); ok {
if _, ok := typ.(sql.ExtendedType); ok {
return true
}
return types.IsTime(typ) || types.IsTimespan(typ) || types.IsEnum(typ) || types.IsSet(typ) || types.IsJSON(typ)
Expand Down
2 changes: 1 addition & 1 deletion sql/hash/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func HashOf(ctx *sql.Context, sch sql.Schema, row sql.Row) (uint64, error) {
}

switch typ := sch[i].Type.(type) {
case types.ExtendedType:
case sql.ExtendedType:
// TODO: Doltgres follows Postgres conventions which don't align with the expectations of MySQL,
// so we're using the old (probably incorrect) behavior for now
_, err = fmt.Fprintf(hash, "%v", v)
Expand Down
43 changes: 43 additions & 0 deletions sql/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,49 @@ type IndexLookup struct {

var emptyLookup = IndexLookup{}

type IndexComparisonExpression interface {
// TODO: IndexScanOp probably needs to be moved into this package as well
IndexScanOperation() (IndexScanOp, Expression, Expression, bool)
}

type IndexScanOp uint8

//go:generate stringer -type=IndexScanOp -linecomment

const (
IndexScanOpEq IndexScanOp = iota // =
IndexScanOpNullSafeEq // <=>
IndexScanOpInSet // =
IndexScanOpNotInSet // !=
IndexScanOpNotEq // !=
IndexScanOpGt // >
IndexScanOpGte // >=
IndexScanOpLt // <
IndexScanOpLte // <=
IndexScanOpAnd // &&
IndexScanOpOr // ||
IndexScanOpIsNull // IS NULL
IndexScanOpIsNotNull // IS NOT NULL
IndexScanOpSpatialEq // SpatialEq
IndexScanOpFulltextEq // FulltextEq
)

// Swap returns the identity op for swapping a comparison's LHS and RHS
func (o IndexScanOp) Swap() IndexScanOp {
switch o {
case IndexScanOpGt:
return IndexScanOpLt
case IndexScanOpGte:
return IndexScanOpLte
case IndexScanOpLt:
return IndexScanOpGt
case IndexScanOpLte:
return IndexScanOpGte
default:
return o
}
}

func NewIndexLookup(idx Index, ranges MySQLRangeCollection, isPointLookup, isEmptyRange, isSpatialLookup, isReverse bool) IndexLookup {
if isReverse {
for i, j := 0, len(ranges)-1; i < j; i, j = i+1, j-1 {
Expand Down
46 changes: 29 additions & 17 deletions sql/index_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ func floor(val interface{}) interface{} {
}

// Equals represents colExpr = key. For IN expressions, pass all of them in the same Equals call.
func (b *MySQLIndexBuilder) Equals(ctx *Context, colExpr string, keys ...interface{}) *MySQLIndexBuilder {
func (b *MySQLIndexBuilder) Equals(ctx *Context, colExpr string, keyType Type, keys ...interface{}) *MySQLIndexBuilder {
if b.isInvalid {
return b
}
typ, ok := b.colExprTypes[colExpr]
colTyp, ok := b.colExprTypes[colExpr]
if !ok {
b.isInvalid = true
b.err = ErrInvalidColExpr.New(colExpr, b.idx.ID())
Expand All @@ -117,37 +117,38 @@ func (b *MySQLIndexBuilder) Equals(ctx *Context, colExpr string, keys ...interfa
potentialRanges := make([]MySQLRangeColumnExpr, len(keys))
for i, k := range keys {
// if converting from float to int results in rounding, then it's empty range
if t, ok := typ.(NumberType); ok && !t.IsFloat() {
if t, ok := colTyp.(NumberType); ok && !t.IsFloat() {
f, c := floor(k), ceil(k)
switch k.(type) {
case float32, float64:
if f != c {
potentialRanges[i] = EmptyRangeColumnExpr(typ)
potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
continue
}
case decimal.Decimal:
if !f.(decimal.Decimal).Equals(c.(decimal.Decimal)) {
potentialRanges[i] = EmptyRangeColumnExpr(typ)
potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
continue
}
}
}

var err error
k, _, err = typ.Convert(ctx, k)
k, err = b.convertKey(ctx, colTyp, keyType, k)

if err != nil {
b.isInvalid = true
b.err = err
return b
}
potentialRanges[i] = ClosedRangeColumnExpr(k, k, typ)
potentialRanges[i] = ClosedRangeColumnExpr(k, k, colTyp)
}
b.updateCol(ctx, colExpr, potentialRanges...)
return b
}

// NotEquals represents colExpr <> key.
func (b *MySQLIndexBuilder) NotEquals(ctx *Context, colExpr string, key interface{}) *MySQLIndexBuilder {
func (b *MySQLIndexBuilder) NotEquals(ctx *Context, colExpr string, keyType Type, key interface{}) *MySQLIndexBuilder {
if b.isInvalid {
return b
}
Expand All @@ -172,7 +173,7 @@ func (b *MySQLIndexBuilder) NotEquals(ctx *Context, colExpr string, key interfac
}
}

key, _, err := typ.Convert(ctx, key)
key, err := b.convertKey(ctx, typ, keyType, key)
if err != nil {
b.isInvalid = true
b.err = err
Expand All @@ -197,7 +198,7 @@ func (b *MySQLIndexBuilder) NotEquals(ctx *Context, colExpr string, key interfac
}

// GreaterThan represents colExpr > key.
func (b *MySQLIndexBuilder) GreaterThan(ctx *Context, colExpr string, key interface{}) *MySQLIndexBuilder {
func (b *MySQLIndexBuilder) GreaterThan(ctx *Context, colExpr string, keyType Type, key interface{}) *MySQLIndexBuilder {
if b.isInvalid {
return b
}
Expand All @@ -212,7 +213,7 @@ func (b *MySQLIndexBuilder) GreaterThan(ctx *Context, colExpr string, key interf
key = floor(key)
}

key, _, err := typ.Convert(ctx, key)
key, err := b.convertKey(ctx, typ, keyType, key)
if err != nil {
b.isInvalid = true
b.err = err
Expand All @@ -223,8 +224,18 @@ func (b *MySQLIndexBuilder) GreaterThan(ctx *Context, colExpr string, key interf
return b
}

// convertKey converts the given key from keyType to colType, returning an error if the conversion fails.
func (b *MySQLIndexBuilder) convertKey(ctx *Context, colType Type, keyType Type, key interface{}) (interface{}, error) {
if et, ok := colType.(ExtendedType); ok {
return et.ConvertToType(ctx, keyType.(ExtendedType), key)
} else {
key, _, err := colType.Convert(ctx, key)
return key, err
}
}

// GreaterOrEqual represents colExpr >= key.
func (b *MySQLIndexBuilder) GreaterOrEqual(ctx *Context, colExpr string, key interface{}) *MySQLIndexBuilder {
func (b *MySQLIndexBuilder) GreaterOrEqual(ctx *Context, colExpr string, keyType Type, key interface{}) *MySQLIndexBuilder {
if b.isInvalid {
return b
}
Expand All @@ -247,7 +258,7 @@ func (b *MySQLIndexBuilder) GreaterOrEqual(ctx *Context, colExpr string, key int
key = newKey
}

key, _, err := typ.Convert(ctx, key)
key, err := b.convertKey(ctx, typ, keyType, key)
if err != nil {
b.isInvalid = true
b.err = err
Expand All @@ -266,7 +277,7 @@ func (b *MySQLIndexBuilder) GreaterOrEqual(ctx *Context, colExpr string, key int
}

// LessThan represents colExpr < key.
func (b *MySQLIndexBuilder) LessThan(ctx *Context, colExpr string, key interface{}) *MySQLIndexBuilder {
func (b *MySQLIndexBuilder) LessThan(ctx *Context, colExpr string, keyType Type, key interface{}) *MySQLIndexBuilder {
if b.isInvalid {
return b
}
Expand All @@ -280,7 +291,8 @@ func (b *MySQLIndexBuilder) LessThan(ctx *Context, colExpr string, key interface
if t, ok := typ.(NumberType); ok && !t.IsFloat() {
key = ceil(key)
}
key, _, err := typ.Convert(ctx, key)

key, err := b.convertKey(ctx, typ, keyType, key)
if err != nil {
b.isInvalid = true
b.err = err
Expand All @@ -292,7 +304,7 @@ func (b *MySQLIndexBuilder) LessThan(ctx *Context, colExpr string, key interface
}

// LessOrEqual represents colExpr <= key.
func (b *MySQLIndexBuilder) LessOrEqual(ctx *Context, colExpr string, key interface{}) *MySQLIndexBuilder {
func (b *MySQLIndexBuilder) LessOrEqual(ctx *Context, colExpr string, keyType Type, key interface{}) *MySQLIndexBuilder {
if b.isInvalid {
return b
}
Expand All @@ -315,7 +327,7 @@ func (b *MySQLIndexBuilder) LessOrEqual(ctx *Context, colExpr string, key interf
key = newKey
}

key, _, err := typ.Convert(ctx, key)
key, err := b.convertKey(ctx, typ, keyType, key)
if err != nil {
b.isInvalid = true
b.err = err
Expand Down
38 changes: 19 additions & 19 deletions sql/index_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,88 +46,88 @@ func TestIndexBuilderRanges(t *testing.T) {
t.Run("IsNull,Equals2=EmptyRange", func(t *testing.T) {
builder := sql.NewMySQLIndexBuilder(testIndex{1})
builder = builder.IsNull(ctx, "column_0")
builder = builder.Equals(ctx, "column_0", 2)
builder = builder.Equals(ctx, "column_0", nil, 2)
ranges := builder.Ranges(ctx)
assert.NotNil(t, ranges)
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.EmptyRangeColumnExpr(types.Int8)}}, ranges)
})

t.Run("NotEquals2=(NULL,2),(2,Inf)", func(t *testing.T) {
builder := sql.NewMySQLIndexBuilder(testIndex{1})
builder = builder.NotEquals(ctx, "column_0", 2)
builder = builder.NotEquals(ctx, "column_0", nil, 2)
ranges := builder.Ranges(ctx)
assert.NotNil(t, ranges)
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.GreaterThanRangeColumnExpr(int8(2), types.Int8)}, sql.MySQLRange{sql.LessThanRangeColumnExpr(int8(2), types.Int8)}}, ranges)
})

t.Run("NotEquals2,Equals2=(Inf,Inf)", func(t *testing.T) {
builder := sql.NewMySQLIndexBuilder(testIndex{1})
builder = builder.NotEquals(ctx, "column_0", 2)
builder = builder.Equals(ctx, "column_0", 2)
builder = builder.NotEquals(ctx, "column_0", nil, 2)
builder = builder.Equals(ctx, "column_0", nil, 2)
ranges := builder.Ranges(ctx)
assert.NotNil(t, ranges)
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.EmptyRangeColumnExpr(types.Int8)}}, ranges)
})

t.Run("Equals2,NotEquals2=(Inf,Inf)", func(t *testing.T) {
builder := sql.NewMySQLIndexBuilder(testIndex{1})
builder = builder.Equals(ctx, "column_0", 2)
builder = builder.NotEquals(ctx, "column_0", 2)
builder = builder.Equals(ctx, "column_0", nil, 2)
builder = builder.NotEquals(ctx, "column_0", nil, 2)
ranges := builder.Ranges(ctx)
assert.NotNil(t, ranges)
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.EmptyRangeColumnExpr(types.Int8)}}, ranges)
})

t.Run("LT4=(NULL,4)", func(t *testing.T) {
builder := sql.NewMySQLIndexBuilder(testIndex{1})
builder = builder.LessThan(ctx, "column_0", 4)
builder = builder.LessThan(ctx, "column_0", nil, 4)
ranges := builder.Ranges(ctx)
assert.NotNil(t, ranges)
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.LessThanRangeColumnExpr(int8(4), types.Int8)}}, ranges)
})

t.Run("GT2,LT4=(2,4)", func(t *testing.T) {
builder := sql.NewMySQLIndexBuilder(testIndex{1})
builder = builder.GreaterThan(ctx, "column_0", 2)
builder = builder.LessThan(ctx, "column_0", 4)
builder = builder.GreaterThan(ctx, "column_0", nil, 2)
builder = builder.LessThan(ctx, "column_0", nil, 4)
ranges := builder.Ranges(ctx)
assert.NotNil(t, ranges)
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.OpenRangeColumnExpr(int8(2), int8(4), types.Int8)}}, ranges)
})

t.Run("GT2,GT6=(4,Inf)", func(t *testing.T) {
builder := sql.NewMySQLIndexBuilder(testIndex{1})
builder = builder.GreaterThan(ctx, "column_0", 2)
builder = builder.GreaterThan(ctx, "column_0", 6)
builder = builder.GreaterThan(ctx, "column_0", nil, 2)
builder = builder.GreaterThan(ctx, "column_0", nil, 6)
ranges := builder.Ranges(ctx)
assert.NotNil(t, ranges)
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.GreaterThanRangeColumnExpr(int8(6), types.Int8)}}, ranges)
})

t.Run("GT2,LT4,GT6=(Inf,Inf)", func(t *testing.T) {
builder := sql.NewMySQLIndexBuilder(testIndex{1})
builder = builder.GreaterThan(ctx, "column_0", 2)
builder = builder.LessThan(ctx, "column_0", 4)
builder = builder.GreaterThan(ctx, "column_0", 6)
builder = builder.GreaterThan(ctx, "column_0", nil, 2)
builder = builder.LessThan(ctx, "column_0", nil, 4)
builder = builder.GreaterThan(ctx, "column_0", nil, 6)
ranges := builder.Ranges(ctx)
assert.NotNil(t, ranges)
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.EmptyRangeColumnExpr(types.Int8)}}, ranges)
})

t.Run("NotEqual2,NotEquals4=(2,4),(4,Inf),(NULL,2)", func(t *testing.T) {
builder := sql.NewMySQLIndexBuilder(testIndex{1})
builder = builder.NotEquals(ctx, "column_0", 2)
builder = builder.NotEquals(ctx, "column_0", 4)
builder = builder.NotEquals(ctx, "column_0", nil, 2)
builder = builder.NotEquals(ctx, "column_0", nil, 4)
ranges := builder.Ranges(ctx)
assert.NotNil(t, ranges)
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.OpenRangeColumnExpr(int8(2), int8(4), types.Int8)}, sql.MySQLRange{sql.GreaterThanRangeColumnExpr(int8(4), types.Int8)}, sql.MySQLRange{sql.LessThanRangeColumnExpr(int8(2), types.Int8)}}, ranges)
})

t.Run("ThreeColumnCombine", func(t *testing.T) {
clauses := make([]sql.MySQLRangeCollection, 3)
clauses[0] = sql.NewMySQLIndexBuilder(testIndex{3}).GreaterOrEqual(ctx, "column_0", 99).LessThan(ctx, "column_1", 66).Ranges(ctx)
clauses[1] = sql.NewMySQLIndexBuilder(testIndex{3}).GreaterOrEqual(ctx, "column_0", 1).LessOrEqual(ctx, "column_0", 47).Ranges(ctx)
clauses[2] = sql.NewMySQLIndexBuilder(testIndex{3}).NotEquals(ctx, "column_0", 2).LessThan(ctx, "column_1", 30).Ranges(ctx)
clauses[0] = sql.NewMySQLIndexBuilder(testIndex{3}).GreaterOrEqual(ctx, "column_0", nil, 99).LessThan(ctx, "column_1", nil, 66).Ranges(ctx)
clauses[1] = sql.NewMySQLIndexBuilder(testIndex{3}).GreaterOrEqual(ctx, "column_0", nil, 1).LessOrEqual(ctx, "column_0", nil, 47).Ranges(ctx)
clauses[2] = sql.NewMySQLIndexBuilder(testIndex{3}).NotEquals(ctx, "column_0", nil, 2).LessThan(ctx, "column_1", nil, 30).Ranges(ctx)
assert.Len(t, clauses[0], 1)
assert.Len(t, clauses[1], 1)
assert.Len(t, clauses[2], 2)
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/indexscanop_string.go → sql/indexscanop_string.go
100644 → 100755

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions sql/plan/foreign_key_editor.go
Original file line number Diff line number Diff line change
Expand Up @@ -763,14 +763,14 @@ func GetForeignKeyTypeConversions(
childType := childSch[childIndex].Type
parentType := parentSch[parentIndex].Type

childExtendedType, ok := childType.(types.ExtendedType)
childExtendedType, ok := childType.(sql.ExtendedType)
// if even one of the types is not an extended type, then we can't transform any values
if !ok {
return nil, nil
}

if !childType.Equals(parentType) {
parentExtendedType, ok := parentType.(types.ExtendedType)
parentExtendedType, ok := parentType.(sql.ExtendedType)
if !ok {
// this should be impossible (child and parent should both be extended types), but just in case
return nil, nil
Expand Down
4 changes: 1 addition & 3 deletions sql/rowexec/agg.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ import (
"errors"
"io"

"github.com/dolthub/go-mysql-server/sql/types"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
"github.com/dolthub/go-mysql-server/sql/hash"
Expand Down Expand Up @@ -251,7 +249,7 @@ func (i *groupByGroupingIter) groupingKey(ctx *sql.Context, exprs []sql.Expressi

// TODO: this should be moved into hash.HashOf
typ := expr.Type()
if extTyp, isExtTyp := typ.(types.ExtendedType); isExtTyp {
if extTyp, isExtTyp := typ.(sql.ExtendedType); isExtTyp {
val, vErr := extTyp.SerializeValue(ctx, v)
if vErr != nil {
return 0, vErr
Expand Down
Loading
Loading