diff --git a/memory/stats.go b/memory/stats.go index 94a0a3116c..cfd302d049 100644 --- a/memory/stats.go +++ b/memory/stats.go @@ -60,10 +60,11 @@ func (s *StatsProv) AnalyzeTable(ctx *sql.Context, table sql.Table, db string) e } newStats := make(map[statsKey][]int) - tablePrefix := fmt.Sprintf("%s.", strings.ToLower(table.Name())) + tablePrefix := strings.ToLower(table.Name()) + "." for _, idx := range indexes { - cols := make([]string, len(idx.Expressions())) - for i, c := range idx.Expressions() { + exprs := idx.Expressions() + cols := make([]string, len(exprs)) + for i, c := range exprs { cols[i] = strings.TrimPrefix(strings.ToLower(c), tablePrefix) } for i := 1; i < len(cols)+1; i++ { @@ -244,7 +245,7 @@ func (s *StatsProv) reservoirSample(ctx *sql.Context, table sql.Table) ([]sql.Ro } func (s *StatsProv) GetTableStats(ctx *sql.Context, db string, table sql.Table) ([]sql.Statistic, error) { - pref := fmt.Sprintf("%s.%s", strings.ToLower(db), strings.ToLower(table.Name())) + pref := strings.ToLower(db) + "." + strings.ToLower(table.Name()) var ret []sql.Statistic for key, stats := range s.colStats { if strings.HasPrefix(string(key), pref) { @@ -279,7 +280,7 @@ func (s *StatsProv) DropStats(ctx *sql.Context, qual sql.StatQualifier, cols []s } func (s *StatsProv) RowCount(ctx *sql.Context, db string, table sql.Table) (uint64, error) { - pref := fmt.Sprintf("%s.%s", strings.ToLower(db), strings.ToLower(table.Name())) + pref := strings.ToLower(db) + "." + strings.ToLower(table.Name()) var cnt uint64 for key, stats := range s.colStats { if strings.HasPrefix(string(key), pref) { @@ -292,7 +293,7 @@ func (s *StatsProv) RowCount(ctx *sql.Context, db string, table sql.Table) (uint } func (s *StatsProv) DataLength(ctx *sql.Context, db string, table sql.Table) (uint64, error) { - pref := fmt.Sprintf("%s.%s", db, table) + pref := strings.ToLower(db) + "." + strings.ToLower(table.Name()) var size uint64 for key, stats := range s.colStats { if strings.HasPrefix(string(key), pref) { diff --git a/sql/analyzer/costed_index_scan.go b/sql/analyzer/costed_index_scan.go index b5fa511169..ddfd73b6c8 100644 --- a/sql/analyzer/costed_index_scan.go +++ b/sql/analyzer/costed_index_scan.go @@ -155,7 +155,7 @@ func costedIndexLookup(ctx *sql.Context, n sql.Node, a *Analyzer, iat sql.IndexA } func getCostedIndexScan(ctx *sql.Context, statsProv sql.StatsProvider, rt sql.TableNode, indexes []sql.Index, filters []sql.Expression, qFlags *sql.QueryFlags) (*plan.IndexedTableAccess, sql.Statistic, []sql.Expression, error) { - statistics, err := statsProv.GetTableStats(ctx, strings.ToLower(rt.Database().Name()), rt.UnderlyingTable()) + statistics, err := statsProv.GetTableStats(ctx, rt.Database().Name(), rt.UnderlyingTable()) if err != nil { return nil, nil, nil, err } @@ -182,19 +182,19 @@ func getCostedIndexScan(ctx *sql.Context, statsProv sql.StatsProvider, rt sql.Ta // run each index through coster, save the cheapest var dbName string if dbTab, ok := rt.UnderlyingTable().(sql.Databaseable); ok { - dbName = strings.ToLower(dbTab.Database()) + dbName = dbTab.Database() } table := rt.UnderlyingTable() var schemaName string if schTab, ok := table.(sql.DatabaseSchemaTable); ok { - schemaName = strings.ToLower(schTab.DatabaseSchema().SchemaName()) + schemaName = schTab.DatabaseSchema().SchemaName() } - tableName := strings.ToLower(table.Name()) + tableName := table.Name() if len(qualToStat) > 0 { // don't mix and match real and default stats for _, idx := range indexes { - qual := sql.NewStatQualifier(dbName, schemaName, tableName, strings.ToLower(idx.ID())) + qual := sql.NewStatQualifier(dbName, schemaName, tableName, idx.ID()) _, ok := qualToStat[qual] if !ok { qualToStat = nil @@ -204,15 +204,15 @@ func getCostedIndexScan(ctx *sql.Context, statsProv sql.StatsProvider, rt sql.Ta } for _, idx := range indexes { - qual := sql.NewStatQualifier(dbName, schemaName, tableName, strings.ToLower(idx.ID())) + qual := sql.NewStatQualifier(dbName, schemaName, tableName, idx.ID()) stat, ok := qualToStat[qual] if !ok { stat, err = uniformDistStatisticsForIndex(ctx, statsProv, iat, idx) + if err != nil { + return nil, nil, nil, err + } } - if err != nil { - return nil, nil, nil, err - } - err := c.cost(root, stat, idx) + err = c.cost(root, stat, idx) if err != nil { return nil, nil, nil, err } @@ -492,6 +492,7 @@ func (c *indexCoster) cost(f indexFilter, stat sql.Statistic, idx sql.Index) err if ok { filters.Add(int(f.id)) } + case *iScanLeaf: newHist, newFds, ok, prefix, err = c.costIndexScanLeaf(f, stat, stat.Histogram(), ordinals, idx) if err != nil { @@ -500,6 +501,7 @@ func (c *indexCoster) cost(f indexFilter, stat sql.Statistic, idx sql.Index) err if ok { filters.Add(int(f.id)) } + default: panic("unreachable") } @@ -517,7 +519,7 @@ func (c *indexCoster) updateBest(s sql.Statistic, hist []sql.HistogramBucket, fd if s == nil || filters.Len() == 0 { return } - rowCnt, _, _ := stats.GetNewCounts(hist) + rowCnt := stats.GetNewRowCounts(hist) var update bool defer func() { @@ -534,16 +536,24 @@ func (c *indexCoster) updateBest(s sql.Statistic, hist []sql.HistogramBucket, fd if c.bestStat == nil { update = true return - } else if c.bestStat.FuncDeps().HasMax1Row() { + } + + if c.bestStat.FuncDeps().HasMax1Row() { return - } else if rowCnt < c.bestCnt { + } + + if rowCnt < c.bestCnt { update = true return - } else if c.bestPrefix == 0 || prefix == 0 && c.bestPrefix != prefix { + } + + if c.bestPrefix == 0 || prefix == 0 && c.bestPrefix != prefix { // any prefix is better than no prefix update = prefix > c.bestPrefix return - } else if rowCnt == c.bestCnt { + } + + if rowCnt == c.bestCnt { // hand rules when stats don't exist or match exactly cmp := fds best := c.bestStat.FuncDeps() @@ -554,21 +564,20 @@ func (c *indexCoster) updateBest(s sql.Statistic, hist []sql.HistogramBucket, fd // If one index uses a strict superset of the filters of the other, we should always pick the superset. // This is true even if the index with more filters isn't unique. - if prefix > c.bestPrefix && slices.Equal(c.bestStat.Columns()[:c.bestPrefix], s.Columns()[:c.bestPrefix]) { + bestCols := c.bestStat.Columns() + newCols := s.Columns() + if prefix > c.bestPrefix && slices.Equal(bestCols[:c.bestPrefix], newCols[:c.bestPrefix]) { update = true return } - - if prefix == c.bestPrefix && slices.Equal(c.bestStat.Columns()[:c.bestPrefix], s.Columns()[:c.bestPrefix]) && hasRange && !c.hasRange { + if prefix == c.bestPrefix && slices.Equal(bestCols[:c.bestPrefix], newCols[:c.bestPrefix]) && hasRange && !c.hasRange { update = true return } - - if c.bestPrefix > prefix && slices.Equal(c.bestStat.Columns()[:prefix], s.Columns()[:prefix]) { + if c.bestPrefix > prefix && slices.Equal(bestCols[:prefix], newCols[:prefix]) { return } - - if c.bestPrefix == prefix && slices.Equal(c.bestStat.Columns()[:prefix], s.Columns()[:prefix]) && !hasRange && c.hasRange { + if c.bestPrefix == prefix && slices.Equal(bestCols[:prefix], newCols[:prefix]) && !hasRange && c.hasRange { return } @@ -600,7 +609,8 @@ func (c *indexCoster) updateBest(s sql.Statistic, hist []sql.HistogramBucket, fd } update = true return - } else if cmp.Constants().Len() < best.Constants().Len() { + } + if cmp.Constants().Len() < best.Constants().Len() { if cmpHasLax && !bestHasLax { // keep unique key update = true @@ -612,7 +622,6 @@ func (c *indexCoster) updateBest(s sql.Statistic, hist []sql.HistogramBucket, fd update = true return } - if filters.Len() < c.bestFilters.Len() { return } @@ -624,32 +633,29 @@ func (c *indexCoster) updateBest(s sql.Statistic, hist []sql.HistogramBucket, fd return } - { - // if no unique keys, prefer equality over ranges - bestConst, bestIsNull := c.getConstAndNullFilters(c.bestFilters) - cmpConst, cmpIsNull := c.getConstAndNullFilters(filters) - if cmpConst.Len() > bestConst.Len() { - update = true - return - } - if cmpIsNull.Len() > bestIsNull.Len() { - update = true - return - } + // if no unique keys, prefer equality over ranges + bestConst, bestIsNull := c.getConstAndNullFilters(c.bestFilters) + cmpConst, cmpIsNull := c.getConstAndNullFilters(filters) + if cmpConst.Len() > bestConst.Len() { + update = true + return + } + if cmpIsNull.Len() > bestIsNull.Len() { + update = true + return } - { - if strings.EqualFold(s.Qualifier().Index(), "primary") { - update = true - return - } else if strings.EqualFold(c.bestStat.Qualifier().Index(), "primary") { - return - } - if strings.Compare(s.Qualifier().Index(), c.bestStat.Qualifier().Index()) < 0 { - // if they are still equal, use index name to make deterministic - update = true - return - } + if strings.EqualFold(s.Qualifier().Index(), "primary") { + update = true + return + } + if strings.EqualFold(c.bestStat.Qualifier().Index(), "primary") { + return + } + if strings.Compare(s.Qualifier().Index(), c.bestStat.Qualifier().Index()) < 0 { + // if they are still equal, use index name to make deterministic + update = true + return } } } diff --git a/sql/convert_value.go b/sql/convert_value.go index 880b9f2f58..a759b8c529 100644 --- a/sql/convert_value.go +++ b/sql/convert_value.go @@ -1,92 +1,90 @@ package sql import ( - "fmt" - "github.com/dolthub/go-mysql-server/sql/values" "github.com/dolthub/vitess/go/vt/proto/query" ) // ConvertToValue converts the interface to a sql value. -func ConvertToValue(v interface{}) (Value, error) { +func ConvertToValue(v interface{}) Value { switch v := v.(type) { case nil: return Value{ Typ: query.Type_NULL_TYPE, Val: nil, - }, nil + } case int: return Value{ Typ: query.Type_INT64, Val: values.WriteInt64(make([]byte, values.Int64Size), int64(v)), - }, nil + } case int8: return Value{ Typ: query.Type_INT8, Val: values.WriteInt8(make([]byte, values.Int8Size), v), - }, nil + } case int16: return Value{ Typ: query.Type_INT16, Val: values.WriteInt16(make([]byte, values.Int16Size), v), - }, nil + } case int32: return Value{ Typ: query.Type_INT32, Val: values.WriteInt32(make([]byte, values.Int32Size), v), - }, nil + } case int64: return Value{ Typ: query.Type_INT64, Val: values.WriteInt64(make([]byte, values.Int64Size), v), - }, nil + } case uint: return Value{ Typ: query.Type_UINT64, Val: values.WriteUint64(make([]byte, values.Uint64Size), uint64(v)), - }, nil + } case uint8: return Value{ Typ: query.Type_UINT8, Val: values.WriteUint8(make([]byte, values.Uint8Size), v), - }, nil + } case uint16: return Value{ Typ: query.Type_UINT16, Val: values.WriteUint16(make([]byte, values.Uint16Size), v), - }, nil + } case uint32: return Value{ Typ: query.Type_UINT32, Val: values.WriteUint32(make([]byte, values.Uint32Size), v), - }, nil + } case uint64: return Value{ Typ: query.Type_UINT64, Val: values.WriteUint64(make([]byte, values.Uint64Size), v), - }, nil + } case float32: return Value{ Typ: query.Type_FLOAT32, Val: values.WriteFloat32(make([]byte, values.Float32Size), v), - }, nil + } case float64: return Value{ Typ: query.Type_FLOAT64, Val: values.WriteFloat64(make([]byte, values.Float64Size), v), - }, nil + } case string: return Value{ Typ: query.Type_VARCHAR, Val: values.WriteString(make([]byte, len(v)), v, values.ByteOrderCollation), - }, nil + } case []byte: return Value{ Typ: query.Type_BLOB, Val: values.WriteBytes(make([]byte, len(v)), v, values.ByteOrderCollation), - }, nil + } default: - return Value{}, fmt.Errorf("type %T not implemented", v) + return Value{} } } diff --git a/sql/expression/alias.go b/sql/expression/alias.go index ea587555c9..3c9d7283a2 100644 --- a/sql/expression/alias.go +++ b/sql/expression/alias.go @@ -140,7 +140,7 @@ func (e *Alias) Describe(options sql.DescribeOptions) string { return fmt.Sprintf("%s->%s:%d", sql.Describe(e.Child, options), e.name, e.id) } } - return fmt.Sprintf("%s as %s", sql.Describe(e.Child, options), e.name) + return sql.Describe(e.Child, options) + " as " + e.name } func (e *Alias) String() string { diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index dc42d6a51d..1cefd09961 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -798,7 +798,7 @@ func (*UnaryMinus) CollationCoercibility(ctx *sql.Context) (collation sql.Collat } func (e *UnaryMinus) String() string { - return fmt.Sprintf("-%s", e.Child) + return "-" + e.Child.String() } // WithChildren implements the Expression interface. diff --git a/sql/expression/function/aggregation/unary_aggs.og.go b/sql/expression/function/aggregation/unary_aggs.og.go index a5094cc975..eb0509115b 100644 --- a/sql/expression/function/aggregation/unary_aggs.og.go +++ b/sql/expression/function/aggregation/unary_aggs.og.go @@ -440,7 +440,7 @@ func (a *Count) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("COUNT(%s)", a.Child) + return "COUNT(" + a.Child.String() + ")" } func (a *Count) DebugString() string { diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 104c04fd97..8757bb8c31 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -16,6 +16,7 @@ package expression import ( "fmt" + "strconv" "strings" "github.com/dolthub/vitess/go/vt/proto/query" @@ -40,7 +41,7 @@ var _ sqlparser.Injectable = &Literal{} // NewLiteral creates a new Literal expression. func NewLiteral(value interface{}, fieldType sql.Type) *Literal { - val2, _ := sql.ConvertToValue(value) + val2 := sql.ConvertToValue(value) return &Literal{ Val: value, val2: val2, @@ -79,8 +80,26 @@ func (lit *Literal) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { func (lit *Literal) String() string { switch litVal := lit.Val.(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - return fmt.Sprintf("%d", litVal) + case int: + return strconv.FormatInt(int64(litVal), 10) + case int8: + return strconv.FormatInt(int64(litVal), 10) + case int16: + return strconv.FormatInt(int64(litVal), 10) + case int32: + return strconv.FormatInt(int64(litVal), 10) + case int64: + return strconv.FormatInt(litVal, 10) + case uint: + return strconv.FormatUint(uint64(litVal), 10) + case uint8: + return strconv.FormatUint(uint64(litVal), 10) + case uint16: + return strconv.FormatUint(uint64(litVal), 10) + case uint32: + return strconv.FormatUint(uint64(litVal), 10) + case uint64: + return strconv.FormatUint(litVal, 10) case string: switch lit.Typ.Type() { // utf8 charset cannot encode binary string @@ -91,7 +110,7 @@ func (lit *Literal) String() string { // Backslash chars also need to be replaced. escaped := strings.ReplaceAll(litVal, "'", "''") escaped = strings.ReplaceAll(escaped, "\\", "\\\\") - return fmt.Sprintf("'%s'", escaped) + return "'" + escaped + "'" case decimal.Decimal: return litVal.StringFixed(litVal.Exponent() * -1) case []byte: diff --git a/sql/schemas.go b/sql/schemas.go index ce34ce52d7..9455574931 100644 --- a/sql/schemas.go +++ b/sql/schemas.go @@ -101,9 +101,8 @@ func (s Schema) IndexOf(column, source string) int { // IndexOfColName returns the index of the given column in the schema or -1 if it's not present. Only safe for schemas // corresponding to a single table, where the source of the column is irrelevant. func (s Schema) IndexOfColName(column string) int { - column = strings.ToLower(column) for i, col := range s { - if strings.ToLower(col.Name) == column { + if strings.EqualFold(col.Name, column) { return i } } diff --git a/sql/statistics.go b/sql/statistics.go index ae1f9fda95..ef66dac32c 100644 --- a/sql/statistics.go +++ b/sql/statistics.go @@ -117,7 +117,8 @@ func NewStatQualifier(db, schema, table, index string) StatQualifier { Database: strings.ToLower(db), Sch: strings.ToLower(schema), Tab: strings.ToLower(table), - Idx: strings.ToLower(index)} + Idx: strings.ToLower(index), + } } // StatQualifier is the namespace hierarchy for a given statistic. diff --git a/sql/stats/filter.go b/sql/stats/filter.go index f2f2de5a21..ecff9f8553 100644 --- a/sql/stats/filter.go +++ b/sql/stats/filter.go @@ -158,7 +158,14 @@ func nilSafeCmp(ctx *sql.Context, typ sql.Type, left, right interface{}) (int, e } } -func GetNewCounts(buckets []sql.HistogramBucket) (rowCount uint64, distinctCount uint64, nullCount uint64) { +func GetNewRowCounts(buckets []sql.HistogramBucket) (rowCount uint64) { + for _, b := range buckets { + rowCount += b.RowCount() + } + return rowCount +} + +func GetAllNewCounts(buckets []sql.HistogramBucket) (rowCount uint64, distinctCount uint64, nullCount uint64) { if len(buckets) == 0 { return 0, 0, 0 } diff --git a/sql/stats/statistic.go b/sql/stats/statistic.go index c72dc89565..7e17287dc1 100644 --- a/sql/stats/statistic.go +++ b/sql/stats/statistic.go @@ -207,7 +207,7 @@ func (s *Statistic) WithLowerBound(r sql.Row) sql.Statistic { func (s *Statistic) WithHistogram(h sql.Histogram) (sql.Statistic, error) { ret := *s - ret.Hist = nil + ret.Hist = make(sql.Histogram, 0, len(h)) for _, b := range h { sqlB, ok := b.(*Bucket) if !ok { diff --git a/sql/value_row.go b/sql/value_row.go index f9140c41c5..e5fcd838c5 100644 --- a/sql/value_row.go +++ b/sql/value_row.go @@ -29,8 +29,8 @@ type ValueBytes []byte // Value is a logical index into a ValueRow. For efficiency reasons, use sparingly. type Value struct { - Val ValueBytes WrappedVal BytesWrapper + Val ValueBytes Typ query.Type }