Skip to content

Commit eca5bf7

Browse files
committed
allow using function and table function
1 parent 09a7e80 commit eca5bf7

File tree

8 files changed

+154
-35
lines changed

8 files changed

+154
-35
lines changed

enginetest/join_stats_tests.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,12 +360,12 @@ func (t TestProvider) Function(ctx *sql.Context, name string) (sql.Function, boo
360360
return nil, false
361361
}
362362

363-
func (t TestProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, error) {
363+
func (t TestProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, bool) {
364364
if tf, ok := t.tableFunctions[strings.ToLower(name)]; ok {
365-
return tf, nil
365+
return tf, true
366366
}
367367

368-
return nil, sql.ErrTableFunctionNotFound.New(name)
368+
return nil, false
369369
}
370370

371371
func (t TestProvider) WithTableFunctions(fns ...sql.TableFunction) (sql.TableFunctionProvider, error) {

memory/provider.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,10 @@ func (pro *DbProvider) ExternalStoredProcedures(_ *sql.Context, name string) ([]
194194
}
195195

196196
// TableFunction implements sql.TableFunctionProvider
197-
func (pro *DbProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, error) {
197+
func (pro *DbProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, bool) {
198198
if tableFunction, ok := pro.tableFunctions[name]; ok {
199-
return tableFunction, nil
199+
return tableFunction, true
200200
}
201201

202-
return nil, sql.ErrTableFunctionNotFound.New(name)
202+
return nil, false
203203
}

sql/analyzer/catalog.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -384,17 +384,14 @@ func (c *Catalog) ExternalStoredProcedures(ctx *sql.Context, name string) ([]sql
384384
}
385385

386386
// TableFunction implements the TableFunctionProvider interface
387-
func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, error) {
387+
func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, bool) {
388388
if fp, ok := c.DbProvider.(sql.TableFunctionProvider); ok {
389-
tf, err := fp.TableFunction(ctx, name)
390-
if err != nil {
391-
return nil, err
392-
} else if tf != nil {
393-
return tf, nil
389+
tf, found := fp.TableFunction(ctx, name)
390+
if found && tf != nil {
391+
return tf, true
394392
}
395393
}
396-
397-
return nil, sql.ErrTableFunctionNotFound.New(name)
394+
return nil, false
398395
}
399396

400397
func (c *Catalog) RefreshTableStats(ctx *sql.Context, table sql.Table, db string) error {

sql/catalog_map.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ func (t MapCatalog) Function(ctx *Context, name string) (Function, bool) {
2525
return nil, false
2626
}
2727

28-
func (t MapCatalog) TableFunction(ctx *Context, name string) (TableFunction, error) {
28+
func (t MapCatalog) TableFunction(ctx *Context, name string) (TableFunction, bool) {
2929
if f, ok := t.tabFuncs[name]; ok {
30-
return f, nil
30+
return f, true
3131
}
32-
return nil, fmt.Errorf("table func not found")
32+
return nil, false
3333
}
3434

3535
func (t MapCatalog) ExternalStoredProcedure(ctx *Context, name string, numOfParams int) (*ExternalStoredProcedureDetails, error) {

sql/databases.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ type CollatedDatabaseProvider interface {
5858
// always) implemented by a DatabaseProvider.
5959
type TableFunctionProvider interface {
6060
// TableFunction returns the table function with the name provided, case-insensitive
61-
TableFunction(ctx *Context, name string) (TableFunction, error)
61+
TableFunction(ctx *Context, name string) (TableFunction, bool)
6262
// WithTableFunctions returns a new provider with (only) the list of table functions arguments
6363
WithTableFunctions(fns ...TableFunction) (TableFunctionProvider, error)
6464
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package dtablefunctions
16+
17+
import (
18+
"fmt"
19+
"strings"
20+
21+
"github.com/dolthub/go-mysql-server/sql"
22+
)
23+
24+
var _ sql.TableFunction = &TableFunction{}
25+
var _ sql.ExecSourceRel = &TableFunction{}
26+
27+
type TableFunction struct {
28+
underlyingFunc sql.Function
29+
30+
args []sql.Expression
31+
database sql.Database
32+
funcExpr sql.Expression
33+
}
34+
35+
func NewTableFunction(f sql.Function) sql.TableFunction {
36+
return &TableFunction{
37+
underlyingFunc: f,
38+
}
39+
}
40+
41+
func (t *TableFunction) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
42+
nt := *t
43+
nt.database = db
44+
nt.args = args
45+
f, err := nt.underlyingFunc.NewInstance(args)
46+
if err != nil {
47+
return nil, err
48+
}
49+
nt.funcExpr = f
50+
return &nt, nil
51+
}
52+
53+
func (t *TableFunction) Children() []sql.Node {
54+
return nil
55+
}
56+
57+
func (t *TableFunction) Database() sql.Database {
58+
return t.database
59+
}
60+
61+
func (t *TableFunction) Expressions() []sql.Expression {
62+
return t.funcExpr.Children()
63+
}
64+
65+
func (t *TableFunction) IsReadOnly() bool {
66+
return true
67+
}
68+
69+
func (t *TableFunction) Name() string {
70+
return t.underlyingFunc.FunctionName()
71+
}
72+
73+
func (t *TableFunction) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) {
74+
v, err := t.funcExpr.Eval(ctx, r)
75+
if err != nil {
76+
return nil, err
77+
}
78+
return sql.RowsToRowIter(sql.Row{v}), nil
79+
}
80+
81+
func (t *TableFunction) Resolved() bool {
82+
for _, expr := range t.args {
83+
return expr.Resolved()
84+
}
85+
return true
86+
}
87+
88+
func (t *TableFunction) Schema() sql.Schema {
89+
return sql.Schema{&sql.Column{Name: t.underlyingFunc.FunctionName(), Type: t.funcExpr.Type()}}
90+
}
91+
92+
func (t *TableFunction) String() string {
93+
var args []string
94+
for _, expr := range t.args {
95+
args = append(args, expr.String())
96+
}
97+
return fmt.Sprintf("%s(%s)", t.underlyingFunc.FunctionName(), strings.Join(args, ", "))
98+
}
99+
100+
func (t *TableFunction) WithChildren(children ...sql.Node) (sql.Node, error) {
101+
if len(children) != 0 {
102+
return nil, fmt.Errorf("unexpected children")
103+
}
104+
return t, nil
105+
}
106+
107+
func (t *TableFunction) WithDatabase(database sql.Database) (sql.Node, error) {
108+
nt := *t
109+
nt.database = database
110+
return &nt, nil
111+
}
112+
113+
func (t *TableFunction) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
114+
l := len(t.funcExpr.Children())
115+
if len(exprs) != l {
116+
return nil, sql.ErrInvalidChildrenNumber.New(t, len(exprs), l)
117+
}
118+
nt := *t
119+
nf, err := nt.funcExpr.WithChildren(exprs...)
120+
if err != nil {
121+
return nil, err
122+
}
123+
nt.funcExpr = nf
124+
return &nt, nil
125+
}

sql/planbuilder/from.go

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package planbuilder
1616

1717
import (
1818
"fmt"
19+
dtablefunctions "github.com/dolthub/go-mysql-server/sql/expression/tablefunction"
1920
"strings"
2021

2122
ast "github.com/dolthub/vitess/go/vt/sqlparser"
@@ -447,30 +448,26 @@ func (b *Builder) resolveTable(tab, db string, asOf interface{}) *plan.ResolvedT
447448
func (b *Builder) buildTableFunc(inScope *scope, t *ast.TableFuncExpr) (outScope *scope) {
448449
//TODO what are valid mysql table arguments
449450
args := make([]sql.Expression, 0, len(t.Exprs))
450-
for _, e := range t.Exprs {
451-
switch e := e.(type) {
451+
for _, expr := range t.Exprs {
452+
switch e := expr.(type) {
452453
case *ast.AliasedExpr:
453-
expr := b.buildScalar(inScope, e.Expr)
454-
455-
if !e.As.IsEmpty() {
456-
b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e)))
457-
}
458-
459-
if selectExprNeedsAlias(e, expr) {
460-
b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e)))
461-
}
462-
463-
args = append(args, expr)
454+
scalarExpr := b.buildScalar(inScope, e.Expr)
455+
args = append(args, scalarExpr)
464456
default:
465457
b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e)))
466458
}
467459
}
468460

469461
utf := expression.NewUnresolvedTableFunction(t.Name, args)
470462

471-
tableFunction, err := b.cat.TableFunction(b.ctx, utf.Name())
472-
if err != nil {
473-
b.handleErr(err)
463+
tableFunction, found := b.cat.TableFunction(b.ctx, utf.Name())
464+
if !found {
465+
// try getting regular function
466+
f, funcFound := b.cat.Function(b.ctx, utf.Name())
467+
if !funcFound {
468+
b.handleErr(sql.ErrTableFunctionNotFound.New(utf.Name()))
469+
}
470+
tableFunction = dtablefunctions.NewTableFunction(f)
474471
}
475472

476473
database := b.currentDb()

test/test_catalog.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func (c *Catalog) UnlockTables(ctx *sql.Context, id uint32) error {
159159
return nil
160160
}
161161

162-
func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, error) {
162+
func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, bool) {
163163
//TODO implement me
164164
panic("implement me")
165165
}

0 commit comments

Comments
 (0)