diff --git a/memory/index.go b/memory/index.go index 44eae97c4c..0d07d5ab77 100644 --- a/memory/index.go +++ b/memory/index.go @@ -73,6 +73,15 @@ func (idx *Index) Expressions() []string { return exprs } +func (idx *Index) UnqualifiedExpressions() []string { + exprs := make([]string, len(idx.Exprs)) + for i, e := range idx.Exprs { + str := e.String() + exprs[i] = str[strings.IndexByte(str, '.')+1:] + } + return exprs +} + func (idx *Index) ExtendedExpressions() []string { var exprs []string foundCols := make(map[string]struct{}) @@ -298,8 +307,9 @@ func (idx *Index) Reversible() bool { return true } -func (idx Index) copy() *Index { - return &idx +func (idx *Index) copy() *Index { + newIdx := *idx + return &newIdx } // columnIndexes returns the indexes in the given schema for the fields in this index diff --git a/sql/analyzer/common_test.go b/sql/analyzer/common_test.go index e4abd9e797..caa2ca925c 100644 --- a/sql/analyzer/common_test.go +++ b/sql/analyzer/common_test.go @@ -157,6 +157,8 @@ func runTestCases(t *testing.T, ctx *sql.Context, testCases []analyzerFnTestCase if expected == nil { expected = tt.node } + // Schema of certain nodes aren't filled until needed + expected.Schema() assertNodesEqualWithDiff(t, expected, result) }) diff --git a/sql/analyzer/index_analyzer_test.go b/sql/analyzer/index_analyzer_test.go index 944c5029c8..aba72c77bf 100644 --- a/sql/analyzer/index_analyzer_test.go +++ b/sql/analyzer/index_analyzer_test.go @@ -15,6 +15,7 @@ package analyzer import ( + "strings" "testing" "github.com/stretchr/testify/require" @@ -142,6 +143,15 @@ func (i dummyIdx) Expressions() []string { } return exprs } + +func (i dummyIdx) UnqualifiedExpressions() []string { + exprs := make([]string, len(i.expr)) + for i, e := range i.expr { + str := e.String() + exprs[i] = str[strings.IndexByte(str, '.')+1:] + } + return exprs +} func (i *dummyIdx) ID() string { return i.id } func (i *dummyIdx) Database() string { return i.database } func (i *dummyIdx) Table() string { return i.table } diff --git a/sql/index.go b/sql/index.go index 62aac79cc6..43b28314d7 100644 --- a/sql/index.go +++ b/sql/index.go @@ -102,6 +102,8 @@ type Index interface { // one expression, it means the index has multiple columns indexed. If it's // just one, it means it may be an expression or a column. Expressions() []string + // UnqualifiedExpressions returns the indexed expressions without the source. + UnqualifiedExpressions() []string // IsUnique returns whether this index is unique IsUnique() bool // IsSpatial returns whether this index is a spatial index diff --git a/sql/index_builder_test.go b/sql/index_builder_test.go index a7bde2a244..9f116883aa 100644 --- a/sql/index_builder_test.go +++ b/sql/index_builder_test.go @@ -188,6 +188,14 @@ func (i testIndex) Expressions() []string { return res } +func (i testIndex) UnqualifiedExpressions() []string { + res := make([]string, i.numcols) + for i := range res { + res[i] = fmt.Sprintf("column_%d", i) + } + return res +} + func (testIndex) IsUnique() bool { return false } diff --git a/sql/index_registry_test.go b/sql/index_registry_test.go index dbdbc2e415..2dd7530b5a 100644 --- a/sql/index_registry_test.go +++ b/sql/index_registry_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "strings" "testing" "github.com/stretchr/testify/require" @@ -443,6 +444,16 @@ func (i dummyIdx) Expressions() []string { } return exprs } + +func (i dummyIdx) UnqualifiedExpressions() []string { + exprs := make([]string, len(i.expr)) + for i, e := range i.expr { + str := e.String() + exprs[i] = str[strings.IndexByte(str, '.')+1:] + } + return exprs +} + func (i dummyIdx) ID() string { return i.id } func (i dummyIdx) Database() string { return i.database } func (i dummyIdx) Table() string { return i.table } diff --git a/sql/memo/rel_props.go b/sql/memo/rel_props.go index 91cc182ddc..7ee114c20c 100644 --- a/sql/memo/rel_props.go +++ b/sql/memo/rel_props.go @@ -131,13 +131,11 @@ func newRelProps(rel RelExpr) *relProps { } // idxExprsColumns returns the column names used in an index's expressions. -// TODO: this is unstable as long as periods in Index.Expressions() -// identifiers are ambiguous. +// Identifiers are ambiguous. func idxExprsColumns(idx sql.Index) []string { - columns := make([]string, len(idx.Expressions())) - for i, e := range idx.Expressions() { - parts := strings.Split(e, ".") - columns[i] = strings.ToLower(parts[1]) + columns := idx.UnqualifiedExpressions() + for i := 0; i < len(columns); i++ { + columns[i] = strings.ToLower(columns[i]) } return columns } @@ -791,17 +789,10 @@ func sortedColsForRel(rel RelExpr) sql.Schema { } case *MergeJoin: var ret sql.Schema - for _, e := range r.InnerScan.Table.Index().Expressions() { + for _, e := range r.InnerScan.Table.Index().UnqualifiedExpressions() { // TODO columns can have "." characters, this will miss cases - parts := strings.Split(e, ".") - var name string - if len(parts) == 2 { - name = parts[1] - } else { - return nil - } ret = append(ret, &sql.Column{ - Name: strings.ToLower(name), + Name: strings.ToLower(e), Source: strings.ToLower(r.InnerScan.Table.Name()), Nullable: true}, ) diff --git a/sql/memo/rel_props_test.go b/sql/memo/rel_props_test.go index a5034768c8..9f8bc0dbc6 100644 --- a/sql/memo/rel_props_test.go +++ b/sql/memo/rel_props_test.go @@ -2,6 +2,7 @@ package memo import ( "fmt" + "strings" "testing" "github.com/stretchr/testify/require" @@ -217,6 +218,14 @@ func (i dummyIndex) Expressions() []string { return i.cols } +func (i dummyIndex) UnqualifiedExpressions() []string { + res := make([]string, len(i.cols)) + for idx, col := range i.cols { + res[idx] = col[strings.IndexByte(col, '.')+1:] + } + return res +} + func (dummyIndex) IsUnique() bool { return true } diff --git a/sql/plan/project.go b/sql/plan/project.go index 9e377794c2..f1ea767a0a 100644 --- a/sql/plan/project.go +++ b/sql/plan/project.go @@ -35,6 +35,8 @@ type Project struct { // a RowIter. IncludesNestedIters bool deps sql.ColSet + + sch sql.Schema } var _ sql.Expressioner = (*Project)(nil) @@ -122,14 +124,16 @@ func ExprDeps(exprs ...sql.Expression) sql.ColSet { // Schema implements the Node interface. func (p *Project) Schema() sql.Schema { - var s = make(sql.Schema, len(p.Projections)) - for i, expr := range p.Projections { - s[i] = transform.ExpressionToColumn(expr, AliasSubqueryString(expr)) - if gf := unwrapGetField(expr); gf != nil { - s[i].Default = findDefault(p.Child, gf) + if p.sch == nil { + p.sch = make(sql.Schema, len(p.Projections)) + for i, expr := range p.Projections { + p.sch[i] = transform.ExpressionToColumn(expr, AliasSubqueryString(expr)) + if gf := unwrapGetField(expr); gf != nil { + p.sch[i].Default = findDefault(p.Child, gf) + } } } - return s + return p.sch } // Resolved implements the Resolvable interface. diff --git a/sql/plan/tablealias.go b/sql/plan/tablealias.go index de92e1fd3f..856a21e188 100644 --- a/sql/plan/tablealias.go +++ b/sql/plan/tablealias.go @@ -21,10 +21,12 @@ import ( // TableAlias is a node that acts as a table with a given name. type TableAlias struct { *UnaryNode - name string - comment string - id sql.TableId - cols sql.ColSet + name string + comment string + id sql.TableId + cols sql.ColSet + sch sql.Schema + cachedSch bool } var _ sql.RenameableNode = (*TableAlias)(nil) @@ -33,7 +35,10 @@ var _ sql.CollationCoercible = (*TableAlias)(nil) // NewTableAlias returns a new Table alias node. func NewTableAlias(name string, node sql.Node) *TableAlias { - ret := &TableAlias{UnaryNode: &UnaryNode{Child: node}, name: name} + ret := &TableAlias{ + UnaryNode: &UnaryNode{Child: node}, + name: name, + } if tin, ok := node.(TableIdNode); ok { ret.id = tin.Id() ret.cols = tin.Columns() @@ -87,14 +92,16 @@ func (t *TableAlias) Comment() string { // Schema implements the Node interface. TableAlias alters the schema of its child element to rename the source of // columns to the alias. func (t *TableAlias) Schema() sql.Schema { - childSchema := t.Child.Schema() - copy := make(sql.Schema, len(childSchema)) - for i, col := range childSchema { - colCopy := *col - colCopy.Source = t.name - copy[i] = &colCopy + if t.sch == nil { + childSchema := t.Child.Schema() + t.sch = make(sql.Schema, len(childSchema)) + for i, col := range childSchema { + newCol := *col + newCol.Source = t.name + t.sch[i] = &newCol + } } - return copy + return t.sch } // WithChildren implements the Node interface. @@ -118,21 +125,22 @@ func (t *TableAlias) CollationCoercibility(ctx *sql.Context) (collation sql.Coll return sql.Collation_binary, 7 } -func (t TableAlias) String() string { +func (t *TableAlias) String() string { pr := sql.NewTreePrinter() _ = pr.WriteNode("TableAlias(%s)", t.name) _ = pr.WriteChildren(t.Child.String()) return pr.String() } -func (t TableAlias) DebugString() string { +func (t *TableAlias) DebugString() string { pr := sql.NewTreePrinter() _ = pr.WriteNode("TableAlias(%s)", t.name) _ = pr.WriteChildren(sql.DebugString(t.Child)) return pr.String() } -func (t TableAlias) WithName(name string) sql.Node { - t.name = name - return &t +func (t *TableAlias) WithName(name string) sql.Node { + nt := *t + nt.name = name + return &nt } diff --git a/sql/planbuilder/ddl.go b/sql/planbuilder/ddl.go index c5f79201e9..5447518d18 100644 --- a/sql/planbuilder/ddl.go +++ b/sql/planbuilder/ddl.go @@ -395,11 +395,11 @@ func (b *Builder) getIndexDefs(table sql.Table) sql.IndexDefs { if !isIdxTbl { return nil } - var idxDefs sql.IndexDefs idxs, err := idxTbl.GetIndexes(b.ctx) if err != nil { b.handleErr(err) } + idxDefs := make(sql.IndexDefs, 0, len(idxs)) for _, idx := range idxs { if idx.IsGenerated() { continue @@ -412,10 +412,9 @@ func (b *Builder) getIndexDefs(table sql.Table) sql.IndexDefs { constraint = sql.IndexConstraint_Unique } } - columns := make([]sql.IndexColumn, len(idx.Expressions())) - for i, col := range idx.Expressions() { - // TODO: find a better way to get only the column name if the table is present - col = strings.TrimPrefix(col, idxTbl.Name()+".") + exprs := idx.UnqualifiedExpressions() + columns := make([]sql.IndexColumn, len(exprs)) + for i, col := range exprs { columns[i] = sql.IndexColumn{Name: col} } idxDefs = append(idxDefs, &sql.IndexDef{