Skip to content

Commit 7db5799

Browse files
authored
Merge pull request #3190 from dolthub/zachmu/index-scans-pgcatalog
Refactorings to support index scans for pg catalog tables
2 parents db805db + 311f976 commit 7db5799

File tree

12 files changed

+257
-199
lines changed

12 files changed

+257
-199
lines changed

sql/analyzer/costed_index_scan.go

Lines changed: 124 additions & 118 deletions
Large diffs are not rendered by default.

sql/analyzer/resolve_column_defaults.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ func normalizeDefault(ctx *sql.Context, colDefault *sql.ColumnDefaultValue) (sql
479479
// serialization before being passed to the integrator for table creation
480480
func skipDefaultNormalizationForType(typ sql.Type) bool {
481481
// Extended types handle their own serialization concerns
482-
if _, ok := typ.(types.ExtendedType); ok {
482+
if _, ok := typ.(sql.ExtendedType); ok {
483483
return true
484484
}
485485
return types.IsTime(typ) || types.IsTimespan(typ) || types.IsEnum(typ) || types.IsSet(typ) || types.IsJSON(typ)

sql/hash/hash.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func HashOf(ctx *sql.Context, sch sql.Schema, row sql.Row) (uint64, error) {
7171
}
7272

7373
switch typ := sch[i].Type.(type) {
74-
case types.ExtendedType:
74+
case sql.ExtendedType:
7575
// TODO: Doltgres follows Postgres conventions which don't align with the expectations of MySQL,
7676
// so we're using the old (probably incorrect) behavior for now
7777
_, err = fmt.Fprintf(hash, "%v", v)

sql/index.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,49 @@ type IndexLookup struct {
166166

167167
var emptyLookup = IndexLookup{}
168168

169+
type IndexComparisonExpression interface {
170+
// TODO: IndexScanOp probably needs to be moved into this package as well
171+
IndexScanOperation() (IndexScanOp, Expression, Expression, bool)
172+
}
173+
174+
type IndexScanOp uint8
175+
176+
//go:generate stringer -type=IndexScanOp -linecomment
177+
178+
const (
179+
IndexScanOpEq IndexScanOp = iota // =
180+
IndexScanOpNullSafeEq // <=>
181+
IndexScanOpInSet // =
182+
IndexScanOpNotInSet // !=
183+
IndexScanOpNotEq // !=
184+
IndexScanOpGt // >
185+
IndexScanOpGte // >=
186+
IndexScanOpLt // <
187+
IndexScanOpLte // <=
188+
IndexScanOpAnd // &&
189+
IndexScanOpOr // ||
190+
IndexScanOpIsNull // IS NULL
191+
IndexScanOpIsNotNull // IS NOT NULL
192+
IndexScanOpSpatialEq // SpatialEq
193+
IndexScanOpFulltextEq // FulltextEq
194+
)
195+
196+
// Swap returns the identity op for swapping a comparison's LHS and RHS
197+
func (o IndexScanOp) Swap() IndexScanOp {
198+
switch o {
199+
case IndexScanOpGt:
200+
return IndexScanOpLt
201+
case IndexScanOpGte:
202+
return IndexScanOpLte
203+
case IndexScanOpLt:
204+
return IndexScanOpGt
205+
case IndexScanOpLte:
206+
return IndexScanOpGte
207+
default:
208+
return o
209+
}
210+
}
211+
169212
func NewIndexLookup(idx Index, ranges MySQLRangeCollection, isPointLookup, isEmptyRange, isSpatialLookup, isReverse bool) IndexLookup {
170213
if isReverse {
171214
for i, j := 0, len(ranges)-1; i < j; i, j = i+1, j-1 {

sql/index_builder.go

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,11 @@ func floor(val interface{}) interface{} {
104104
}
105105

106106
// Equals represents colExpr = key. For IN expressions, pass all of them in the same Equals call.
107-
func (b *MySQLIndexBuilder) Equals(ctx *Context, colExpr string, keys ...interface{}) *MySQLIndexBuilder {
107+
func (b *MySQLIndexBuilder) Equals(ctx *Context, colExpr string, keyType Type, keys ...interface{}) *MySQLIndexBuilder {
108108
if b.isInvalid {
109109
return b
110110
}
111-
typ, ok := b.colExprTypes[colExpr]
111+
colTyp, ok := b.colExprTypes[colExpr]
112112
if !ok {
113113
b.isInvalid = true
114114
b.err = ErrInvalidColExpr.New(colExpr, b.idx.ID())
@@ -117,37 +117,38 @@ func (b *MySQLIndexBuilder) Equals(ctx *Context, colExpr string, keys ...interfa
117117
potentialRanges := make([]MySQLRangeColumnExpr, len(keys))
118118
for i, k := range keys {
119119
// if converting from float to int results in rounding, then it's empty range
120-
if t, ok := typ.(NumberType); ok && !t.IsFloat() {
120+
if t, ok := colTyp.(NumberType); ok && !t.IsFloat() {
121121
f, c := floor(k), ceil(k)
122122
switch k.(type) {
123123
case float32, float64:
124124
if f != c {
125-
potentialRanges[i] = EmptyRangeColumnExpr(typ)
125+
potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
126126
continue
127127
}
128128
case decimal.Decimal:
129129
if !f.(decimal.Decimal).Equals(c.(decimal.Decimal)) {
130-
potentialRanges[i] = EmptyRangeColumnExpr(typ)
130+
potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
131131
continue
132132
}
133133
}
134134
}
135135

136136
var err error
137-
k, _, err = typ.Convert(ctx, k)
137+
k, err = b.convertKey(ctx, colTyp, keyType, k)
138+
138139
if err != nil {
139140
b.isInvalid = true
140141
b.err = err
141142
return b
142143
}
143-
potentialRanges[i] = ClosedRangeColumnExpr(k, k, typ)
144+
potentialRanges[i] = ClosedRangeColumnExpr(k, k, colTyp)
144145
}
145146
b.updateCol(ctx, colExpr, potentialRanges...)
146147
return b
147148
}
148149

149150
// NotEquals represents colExpr <> key.
150-
func (b *MySQLIndexBuilder) NotEquals(ctx *Context, colExpr string, key interface{}) *MySQLIndexBuilder {
151+
func (b *MySQLIndexBuilder) NotEquals(ctx *Context, colExpr string, keyType Type, key interface{}) *MySQLIndexBuilder {
151152
if b.isInvalid {
152153
return b
153154
}
@@ -172,7 +173,7 @@ func (b *MySQLIndexBuilder) NotEquals(ctx *Context, colExpr string, key interfac
172173
}
173174
}
174175

175-
key, _, err := typ.Convert(ctx, key)
176+
key, err := b.convertKey(ctx, typ, keyType, key)
176177
if err != nil {
177178
b.isInvalid = true
178179
b.err = err
@@ -197,7 +198,7 @@ func (b *MySQLIndexBuilder) NotEquals(ctx *Context, colExpr string, key interfac
197198
}
198199

199200
// GreaterThan represents colExpr > key.
200-
func (b *MySQLIndexBuilder) GreaterThan(ctx *Context, colExpr string, key interface{}) *MySQLIndexBuilder {
201+
func (b *MySQLIndexBuilder) GreaterThan(ctx *Context, colExpr string, keyType Type, key interface{}) *MySQLIndexBuilder {
201202
if b.isInvalid {
202203
return b
203204
}
@@ -212,7 +213,7 @@ func (b *MySQLIndexBuilder) GreaterThan(ctx *Context, colExpr string, key interf
212213
key = floor(key)
213214
}
214215

215-
key, _, err := typ.Convert(ctx, key)
216+
key, err := b.convertKey(ctx, typ, keyType, key)
216217
if err != nil {
217218
b.isInvalid = true
218219
b.err = err
@@ -223,8 +224,18 @@ func (b *MySQLIndexBuilder) GreaterThan(ctx *Context, colExpr string, key interf
223224
return b
224225
}
225226

227+
// convertKey converts the given key from keyType to colType, returning an error if the conversion fails.
228+
func (b *MySQLIndexBuilder) convertKey(ctx *Context, colType Type, keyType Type, key interface{}) (interface{}, error) {
229+
if et, ok := colType.(ExtendedType); ok {
230+
return et.ConvertToType(ctx, keyType.(ExtendedType), key)
231+
} else {
232+
key, _, err := colType.Convert(ctx, key)
233+
return key, err
234+
}
235+
}
236+
226237
// GreaterOrEqual represents colExpr >= key.
227-
func (b *MySQLIndexBuilder) GreaterOrEqual(ctx *Context, colExpr string, key interface{}) *MySQLIndexBuilder {
238+
func (b *MySQLIndexBuilder) GreaterOrEqual(ctx *Context, colExpr string, keyType Type, key interface{}) *MySQLIndexBuilder {
228239
if b.isInvalid {
229240
return b
230241
}
@@ -247,7 +258,7 @@ func (b *MySQLIndexBuilder) GreaterOrEqual(ctx *Context, colExpr string, key int
247258
key = newKey
248259
}
249260

250-
key, _, err := typ.Convert(ctx, key)
261+
key, err := b.convertKey(ctx, typ, keyType, key)
251262
if err != nil {
252263
b.isInvalid = true
253264
b.err = err
@@ -266,7 +277,7 @@ func (b *MySQLIndexBuilder) GreaterOrEqual(ctx *Context, colExpr string, key int
266277
}
267278

268279
// LessThan represents colExpr < key.
269-
func (b *MySQLIndexBuilder) LessThan(ctx *Context, colExpr string, key interface{}) *MySQLIndexBuilder {
280+
func (b *MySQLIndexBuilder) LessThan(ctx *Context, colExpr string, keyType Type, key interface{}) *MySQLIndexBuilder {
270281
if b.isInvalid {
271282
return b
272283
}
@@ -280,7 +291,8 @@ func (b *MySQLIndexBuilder) LessThan(ctx *Context, colExpr string, key interface
280291
if t, ok := typ.(NumberType); ok && !t.IsFloat() {
281292
key = ceil(key)
282293
}
283-
key, _, err := typ.Convert(ctx, key)
294+
295+
key, err := b.convertKey(ctx, typ, keyType, key)
284296
if err != nil {
285297
b.isInvalid = true
286298
b.err = err
@@ -292,7 +304,7 @@ func (b *MySQLIndexBuilder) LessThan(ctx *Context, colExpr string, key interface
292304
}
293305

294306
// LessOrEqual represents colExpr <= key.
295-
func (b *MySQLIndexBuilder) LessOrEqual(ctx *Context, colExpr string, key interface{}) *MySQLIndexBuilder {
307+
func (b *MySQLIndexBuilder) LessOrEqual(ctx *Context, colExpr string, keyType Type, key interface{}) *MySQLIndexBuilder {
296308
if b.isInvalid {
297309
return b
298310
}
@@ -315,7 +327,7 @@ func (b *MySQLIndexBuilder) LessOrEqual(ctx *Context, colExpr string, key interf
315327
key = newKey
316328
}
317329

318-
key, _, err := typ.Convert(ctx, key)
330+
key, err := b.convertKey(ctx, typ, keyType, key)
319331
if err != nil {
320332
b.isInvalid = true
321333
b.err = err

sql/index_builder_test.go

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -46,88 +46,88 @@ func TestIndexBuilderRanges(t *testing.T) {
4646
t.Run("IsNull,Equals2=EmptyRange", func(t *testing.T) {
4747
builder := sql.NewMySQLIndexBuilder(testIndex{1})
4848
builder = builder.IsNull(ctx, "column_0")
49-
builder = builder.Equals(ctx, "column_0", 2)
49+
builder = builder.Equals(ctx, "column_0", nil, 2)
5050
ranges := builder.Ranges(ctx)
5151
assert.NotNil(t, ranges)
5252
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.EmptyRangeColumnExpr(types.Int8)}}, ranges)
5353
})
5454

5555
t.Run("NotEquals2=(NULL,2),(2,Inf)", func(t *testing.T) {
5656
builder := sql.NewMySQLIndexBuilder(testIndex{1})
57-
builder = builder.NotEquals(ctx, "column_0", 2)
57+
builder = builder.NotEquals(ctx, "column_0", nil, 2)
5858
ranges := builder.Ranges(ctx)
5959
assert.NotNil(t, ranges)
6060
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.GreaterThanRangeColumnExpr(int8(2), types.Int8)}, sql.MySQLRange{sql.LessThanRangeColumnExpr(int8(2), types.Int8)}}, ranges)
6161
})
6262

6363
t.Run("NotEquals2,Equals2=(Inf,Inf)", func(t *testing.T) {
6464
builder := sql.NewMySQLIndexBuilder(testIndex{1})
65-
builder = builder.NotEquals(ctx, "column_0", 2)
66-
builder = builder.Equals(ctx, "column_0", 2)
65+
builder = builder.NotEquals(ctx, "column_0", nil, 2)
66+
builder = builder.Equals(ctx, "column_0", nil, 2)
6767
ranges := builder.Ranges(ctx)
6868
assert.NotNil(t, ranges)
6969
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.EmptyRangeColumnExpr(types.Int8)}}, ranges)
7070
})
7171

7272
t.Run("Equals2,NotEquals2=(Inf,Inf)", func(t *testing.T) {
7373
builder := sql.NewMySQLIndexBuilder(testIndex{1})
74-
builder = builder.Equals(ctx, "column_0", 2)
75-
builder = builder.NotEquals(ctx, "column_0", 2)
74+
builder = builder.Equals(ctx, "column_0", nil, 2)
75+
builder = builder.NotEquals(ctx, "column_0", nil, 2)
7676
ranges := builder.Ranges(ctx)
7777
assert.NotNil(t, ranges)
7878
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.EmptyRangeColumnExpr(types.Int8)}}, ranges)
7979
})
8080

8181
t.Run("LT4=(NULL,4)", func(t *testing.T) {
8282
builder := sql.NewMySQLIndexBuilder(testIndex{1})
83-
builder = builder.LessThan(ctx, "column_0", 4)
83+
builder = builder.LessThan(ctx, "column_0", nil, 4)
8484
ranges := builder.Ranges(ctx)
8585
assert.NotNil(t, ranges)
8686
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.LessThanRangeColumnExpr(int8(4), types.Int8)}}, ranges)
8787
})
8888

8989
t.Run("GT2,LT4=(2,4)", func(t *testing.T) {
9090
builder := sql.NewMySQLIndexBuilder(testIndex{1})
91-
builder = builder.GreaterThan(ctx, "column_0", 2)
92-
builder = builder.LessThan(ctx, "column_0", 4)
91+
builder = builder.GreaterThan(ctx, "column_0", nil, 2)
92+
builder = builder.LessThan(ctx, "column_0", nil, 4)
9393
ranges := builder.Ranges(ctx)
9494
assert.NotNil(t, ranges)
9595
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.OpenRangeColumnExpr(int8(2), int8(4), types.Int8)}}, ranges)
9696
})
9797

9898
t.Run("GT2,GT6=(4,Inf)", func(t *testing.T) {
9999
builder := sql.NewMySQLIndexBuilder(testIndex{1})
100-
builder = builder.GreaterThan(ctx, "column_0", 2)
101-
builder = builder.GreaterThan(ctx, "column_0", 6)
100+
builder = builder.GreaterThan(ctx, "column_0", nil, 2)
101+
builder = builder.GreaterThan(ctx, "column_0", nil, 6)
102102
ranges := builder.Ranges(ctx)
103103
assert.NotNil(t, ranges)
104104
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.GreaterThanRangeColumnExpr(int8(6), types.Int8)}}, ranges)
105105
})
106106

107107
t.Run("GT2,LT4,GT6=(Inf,Inf)", func(t *testing.T) {
108108
builder := sql.NewMySQLIndexBuilder(testIndex{1})
109-
builder = builder.GreaterThan(ctx, "column_0", 2)
110-
builder = builder.LessThan(ctx, "column_0", 4)
111-
builder = builder.GreaterThan(ctx, "column_0", 6)
109+
builder = builder.GreaterThan(ctx, "column_0", nil, 2)
110+
builder = builder.LessThan(ctx, "column_0", nil, 4)
111+
builder = builder.GreaterThan(ctx, "column_0", nil, 6)
112112
ranges := builder.Ranges(ctx)
113113
assert.NotNil(t, ranges)
114114
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.EmptyRangeColumnExpr(types.Int8)}}, ranges)
115115
})
116116

117117
t.Run("NotEqual2,NotEquals4=(2,4),(4,Inf),(NULL,2)", func(t *testing.T) {
118118
builder := sql.NewMySQLIndexBuilder(testIndex{1})
119-
builder = builder.NotEquals(ctx, "column_0", 2)
120-
builder = builder.NotEquals(ctx, "column_0", 4)
119+
builder = builder.NotEquals(ctx, "column_0", nil, 2)
120+
builder = builder.NotEquals(ctx, "column_0", nil, 4)
121121
ranges := builder.Ranges(ctx)
122122
assert.NotNil(t, ranges)
123123
assert.Equal(t, sql.MySQLRangeCollection{sql.MySQLRange{sql.OpenRangeColumnExpr(int8(2), int8(4), types.Int8)}, sql.MySQLRange{sql.GreaterThanRangeColumnExpr(int8(4), types.Int8)}, sql.MySQLRange{sql.LessThanRangeColumnExpr(int8(2), types.Int8)}}, ranges)
124124
})
125125

126126
t.Run("ThreeColumnCombine", func(t *testing.T) {
127127
clauses := make([]sql.MySQLRangeCollection, 3)
128-
clauses[0] = sql.NewMySQLIndexBuilder(testIndex{3}).GreaterOrEqual(ctx, "column_0", 99).LessThan(ctx, "column_1", 66).Ranges(ctx)
129-
clauses[1] = sql.NewMySQLIndexBuilder(testIndex{3}).GreaterOrEqual(ctx, "column_0", 1).LessOrEqual(ctx, "column_0", 47).Ranges(ctx)
130-
clauses[2] = sql.NewMySQLIndexBuilder(testIndex{3}).NotEquals(ctx, "column_0", 2).LessThan(ctx, "column_1", 30).Ranges(ctx)
128+
clauses[0] = sql.NewMySQLIndexBuilder(testIndex{3}).GreaterOrEqual(ctx, "column_0", nil, 99).LessThan(ctx, "column_1", nil, 66).Ranges(ctx)
129+
clauses[1] = sql.NewMySQLIndexBuilder(testIndex{3}).GreaterOrEqual(ctx, "column_0", nil, 1).LessOrEqual(ctx, "column_0", nil, 47).Ranges(ctx)
130+
clauses[2] = sql.NewMySQLIndexBuilder(testIndex{3}).NotEquals(ctx, "column_0", nil, 2).LessThan(ctx, "column_1", nil, 30).Ranges(ctx)
131131
assert.Len(t, clauses[0], 1)
132132
assert.Len(t, clauses[1], 1)
133133
assert.Len(t, clauses[2], 2)

sql/analyzer/indexscanop_string.go renamed to sql/indexscanop_string.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sql/plan/foreign_key_editor.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -763,14 +763,14 @@ func GetForeignKeyTypeConversions(
763763
childType := childSch[childIndex].Type
764764
parentType := parentSch[parentIndex].Type
765765

766-
childExtendedType, ok := childType.(types.ExtendedType)
766+
childExtendedType, ok := childType.(sql.ExtendedType)
767767
// if even one of the types is not an extended type, then we can't transform any values
768768
if !ok {
769769
return nil, nil
770770
}
771771

772772
if !childType.Equals(parentType) {
773-
parentExtendedType, ok := parentType.(types.ExtendedType)
773+
parentExtendedType, ok := parentType.(sql.ExtendedType)
774774
if !ok {
775775
// this should be impossible (child and parent should both be extended types), but just in case
776776
return nil, nil

sql/rowexec/agg.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ import (
1818
"errors"
1919
"io"
2020

21-
"github.com/dolthub/go-mysql-server/sql/types"
22-
2321
"github.com/dolthub/go-mysql-server/sql"
2422
"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
2523
"github.com/dolthub/go-mysql-server/sql/hash"
@@ -251,7 +249,7 @@ func (i *groupByGroupingIter) groupingKey(ctx *sql.Context, exprs []sql.Expressi
251249

252250
// TODO: this should be moved into hash.HashOf
253251
typ := expr.Type()
254-
if extTyp, isExtTyp := typ.(types.ExtendedType); isExtTyp {
252+
if extTyp, isExtTyp := typ.(sql.ExtendedType); isExtTyp {
255253
val, vErr := extTyp.SerializeValue(ctx, v)
256254
if vErr != nil {
257255
return 0, vErr

0 commit comments

Comments
 (0)