Skip to content

Commit 3ee3aa2

Browse files
committed
Add an interpreter base for function creation
1 parent b57eb10 commit 3ee3aa2

File tree

16 files changed

+949
-8
lines changed

16 files changed

+949
-8
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ require (
1010
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d
1111
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
1212
github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90
13-
github.com/dolthub/go-mysql-server v0.19.1-0.20250123191908-97b350e0113f
13+
github.com/dolthub/go-mysql-server v0.19.1-0.20250130122244-1a4cc3db2ffb
1414
github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216
1515
github.com/dolthub/vitess v0.0.0-20250123002143-3b45b8cacbfa
1616
github.com/fatih/color v1.13.0

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
224224
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
225225
github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90 h1:Sni8jrP0sy/w9ZYXoff4g/ixe+7bFCZlfCqXKJSU+zM=
226226
github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA=
227-
github.com/dolthub/go-mysql-server v0.19.1-0.20250123191908-97b350e0113f h1:14ITP1m0kBLZPabq8rGmE9VXbORauhPzXhIAg7E31/Q=
228-
github.com/dolthub/go-mysql-server v0.19.1-0.20250123191908-97b350e0113f/go.mod h1:jYEJ8tNkA7K3k39X8iMqaX3MSMmViRgh222JSLHDgVc=
227+
github.com/dolthub/go-mysql-server v0.19.1-0.20250130122244-1a4cc3db2ffb h1:/GXre5yEwbz8dV189A70LrTLcrzCBcZm4jc61B3vRmA=
228+
github.com/dolthub/go-mysql-server v0.19.1-0.20250130122244-1a4cc3db2ffb/go.mod h1:jYEJ8tNkA7K3k39X8iMqaX3MSMmViRgh222JSLHDgVc=
229229
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI=
230230
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q=
231231
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE=

scripts/format_repo.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ paths=`find . -maxdepth 1 -mindepth 1 ! -name ".idea" ! -name ".git" ! -name ".g
2323

2424
goimports -w -local github.com/dolthub/doltgresql $paths
2525

26-
bad_files=$(find $paths -name '*.go' | while read f; do
26+
bad_files=$(find $paths -name '*.go' ! -path ".idea/*" | while read f; do
2727
if [[ $(awk '/import \(/{flag=1;next}/\)/{flag=0}flag' < $f | egrep -c '$^') -gt 2 ]]; then
2828
echo $f
2929
fi

server/expression/init.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package expression
16+
17+
import (
18+
"github.com/dolthub/doltgresql/server/functions/framework"
19+
pgtypes "github.com/dolthub/doltgresql/server/types"
20+
)
21+
22+
// Init handles all setup needed for this package.
23+
func Init() {
24+
framework.NewUnsafeLiteral = func(val any, t *pgtypes.DoltgresType) framework.LiteralInterface {
25+
return NewUnsafeLiteral(val, t)
26+
}
27+
}

server/functions/framework/catalog.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ func RegisterFunction(f FunctionInterface) {
5252
case Function4:
5353
name := strings.ToLower(f.Name)
5454
Catalog[name] = append(Catalog[name], f)
55+
case InterpretedFunction:
56+
name := strings.ToLower(f.ID.FunctionName())
57+
Catalog[name] = append(Catalog[name], f)
5558
default:
5659
panic("unhandled function type")
5760
}

server/functions/framework/compiled_function.go

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
cerrors "github.com/cockroachdb/errors"
2222
"github.com/dolthub/go-mysql-server/sql"
23+
"github.com/dolthub/go-mysql-server/sql/analyzer"
2324
"github.com/dolthub/go-mysql-server/sql/expression"
2425
"gopkg.in/src-d/go-errors.v1"
2526

@@ -46,15 +47,17 @@ type CompiledFunction struct {
4647
overload overloadMatch
4748
originalTypes []*pgtypes.DoltgresType
4849
callResolved []*pgtypes.DoltgresType
50+
runner analyzer.StatementRunner
4951
stashedErr error
5052
}
5153

5254
var _ sql.FunctionExpression = (*CompiledFunction)(nil)
5355
var _ sql.NonDeterministicExpression = (*CompiledFunction)(nil)
56+
var _ analyzer.Interpreter = (*CompiledFunction)(nil)
5457

5558
// NewCompiledFunction returns a newly compiled function.
5659
func NewCompiledFunction(name string, args []sql.Expression, functions *Overloads, isOperator bool) *CompiledFunction {
57-
return newCompiledFunctionInternal(name, args, functions, functions.overloadsForParams(len(args)), isOperator)
60+
return newCompiledFunctionInternal(name, args, functions, functions.overloadsForParams(len(args)), isOperator, nil)
5861
}
5962

6063
// newCompiledFunctionInternal is called internally, which skips steps that may have already been processed.
@@ -64,8 +67,16 @@ func newCompiledFunctionInternal(
6467
overloads *Overloads,
6568
fnOverloads []Overload,
6669
isOperator bool,
70+
runner analyzer.StatementRunner,
6771
) *CompiledFunction {
68-
c := &CompiledFunction{Name: name, Arguments: args, IsOperator: isOperator, overloads: overloads, fnOverloads: fnOverloads}
72+
c := &CompiledFunction{
73+
Name: name,
74+
Arguments: args,
75+
IsOperator: isOperator,
76+
overloads: overloads,
77+
fnOverloads: fnOverloads,
78+
runner: runner,
79+
}
6980
// First we'll analyze all the parameters.
7081
originalTypes, err := c.analyzeParameters()
7182
if err != nil {
@@ -292,6 +303,8 @@ func (c *CompiledFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, err
292303
return f.Callable(ctx, ([4]*pgtypes.DoltgresType)(c.callResolved), args[0], args[1], args[2])
293304
case Function4:
294305
return f.Callable(ctx, ([5]*pgtypes.DoltgresType)(c.callResolved), args[0], args[1], args[2], args[3])
306+
case InterpretedFunction:
307+
return f.Call(ctx, c.runner, c.callResolved, args)
295308
default:
296309
return nil, cerrors.Errorf("unknown function type in CompiledFunction::Eval")
297310
}
@@ -309,7 +322,13 @@ func (c *CompiledFunction) WithChildren(children ...sql.Expression) (sql.Express
309322
}
310323

311324
// We have to re-resolve here, since the change in children may require it (e.g. we have more type info than we did)
312-
return newCompiledFunctionInternal(c.Name, children, c.overloads, c.fnOverloads, c.IsOperator), nil
325+
return newCompiledFunctionInternal(c.Name, children, c.overloads, c.fnOverloads, c.IsOperator, c.runner), nil
326+
}
327+
328+
// SetStatementRunner implements the interface analyzer.Interpreter.
329+
func (c *CompiledFunction) SetStatementRunner(ctx *sql.Context, runner analyzer.StatementRunner) sql.Expression {
330+
c.runner = runner
331+
return c
313332
}
314333

315334
// GetQuickFunction returns the QuickFunction form of this function, if it exists. If one does not exist, then this

server/functions/framework/intermediate_function.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@ func (f IntermediateFunction) Compile(name string, parameters ...sql.Expression)
3131
if f.Functions == nil {
3232
return nil
3333
}
34-
return newCompiledFunctionInternal(name, parameters, f.Functions, f.AllOverloads, f.IsOperator)
34+
return newCompiledFunctionInternal(name, parameters, f.Functions, f.AllOverloads, f.IsOperator, nil)
3535
}
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package framework
16+
17+
import (
18+
"errors"
19+
"fmt"
20+
"strconv"
21+
"strings"
22+
23+
"github.com/dolthub/go-mysql-server/sql"
24+
"github.com/lib/pq"
25+
26+
"github.com/dolthub/doltgresql/core/id"
27+
pgtypes "github.com/dolthub/doltgresql/server/types"
28+
)
29+
30+
// InterpretedFunction is the implementation of functions created using PL/pgSQL. The created functions are converted to
31+
// a collection of operations, and an interpreter iterates over those operations to handle the logic.
32+
type InterpretedFunction struct {
33+
ID id.Function
34+
ReturnType *pgtypes.DoltgresType
35+
ParameterNames []string
36+
ParameterTypes []*pgtypes.DoltgresType
37+
Variadic bool
38+
IsNonDeterministic bool
39+
Strict bool
40+
Labels map[string]int
41+
Statements []InterpreterOperation
42+
}
43+
44+
var _ FunctionInterface = InterpretedFunction{}
45+
46+
// GetExpectedParameterCount implements the interface FunctionInterface.
47+
func (iFunc InterpretedFunction) GetExpectedParameterCount() int {
48+
return len(iFunc.ParameterTypes)
49+
}
50+
51+
// GetName implements the interface FunctionInterface.
52+
func (iFunc InterpretedFunction) GetName() string {
53+
return iFunc.ID.FunctionName()
54+
}
55+
56+
// GetParameters implements the interface FunctionInterface.
57+
func (iFunc InterpretedFunction) GetParameters() []*pgtypes.DoltgresType {
58+
return iFunc.ParameterTypes
59+
}
60+
61+
// GetReturn implements the interface FunctionInterface.
62+
func (iFunc InterpretedFunction) GetReturn() *pgtypes.DoltgresType {
63+
return iFunc.ReturnType
64+
}
65+
66+
// InternalID implements the interface FunctionInterface.
67+
func (iFunc InterpretedFunction) InternalID() id.Id {
68+
return iFunc.ID.AsId()
69+
}
70+
71+
// IsStrict implements the interface FunctionInterface.
72+
func (iFunc InterpretedFunction) IsStrict() bool {
73+
return iFunc.Strict
74+
}
75+
76+
// Return implements the interface plan.Interpreter.
77+
func (iFunc InterpretedFunction) Return(ctx *sql.Context) sql.Type {
78+
return iFunc.ReturnType
79+
}
80+
81+
// NonDeterministic implements the interface FunctionInterface.
82+
func (iFunc InterpretedFunction) NonDeterministic() bool {
83+
return iFunc.IsNonDeterministic
84+
}
85+
86+
// VariadicIndex implements the interface FunctionInterface.
87+
func (iFunc InterpretedFunction) VariadicIndex() int {
88+
// TODO: implement variadic
89+
return -1
90+
}
91+
92+
// querySingleReturn handles queries that are supposed to return a single value.
93+
func (iFunc InterpretedFunction) querySingleReturn(ctx *sql.Context, stack InterpreterStack, stmt string, targetType *pgtypes.DoltgresType, bindings []string) (val any, err error) {
94+
if len(bindings) > 0 {
95+
for i, bindingName := range bindings {
96+
variable := stack.GetVariable(bindingName)
97+
if variable == nil {
98+
return nil, fmt.Errorf("variable `%s` could not be found", bindingName)
99+
}
100+
formattedVar, err := variable.Type.FormatValue(variable.Value)
101+
if err != nil {
102+
return nil, err
103+
}
104+
switch variable.Type.TypCategory {
105+
case pgtypes.TypeCategory_ArrayTypes, pgtypes.TypeCategory_StringTypes:
106+
formattedVar = pq.QuoteLiteral(formattedVar)
107+
}
108+
stmt = strings.Replace(stmt, "$"+strconv.Itoa(i+1), formattedVar, 1)
109+
}
110+
}
111+
sch, rowIter, _, err := stack.runner.QueryWithBindings(ctx, stmt, nil, nil, nil)
112+
if err != nil {
113+
return nil, err
114+
}
115+
rows, err := sql.RowIterToRows(ctx, rowIter)
116+
if err != nil {
117+
return nil, err
118+
}
119+
if len(sch) != 1 {
120+
return nil, errors.New("expression does not result in a single value")
121+
}
122+
if len(rows) != 1 {
123+
return nil, errors.New("expression returned multiple result sets")
124+
}
125+
if len(rows[0]) != 1 {
126+
return nil, errors.New("expression returned multiple results")
127+
}
128+
if targetType == nil {
129+
return rows[0][0], nil
130+
}
131+
fromType, ok := sch[0].Type.(*pgtypes.DoltgresType)
132+
if !ok {
133+
fromType, err = pgtypes.FromGmsTypeToDoltgresType(sch[0].Type)
134+
if err != nil {
135+
return nil, err
136+
}
137+
}
138+
castFunc := GetAssignmentCast(fromType, targetType)
139+
if castFunc == nil {
140+
// TODO: try I/O casting
141+
return nil, errors.New("no valid cast for return value")
142+
}
143+
return castFunc(ctx, rows[0][0], targetType)
144+
}
145+
146+
// queryMultiReturn handles queries that may return multiple values over multiple rows.
147+
func (iFunc InterpretedFunction) queryMultiReturn(ctx *sql.Context, stack InterpreterStack, stmt string, bindings []string) (rowIter sql.RowIter, err error) {
148+
if len(bindings) > 0 {
149+
for i, bindingName := range bindings {
150+
variable := stack.GetVariable(bindingName)
151+
if variable == nil {
152+
return nil, fmt.Errorf("variable `%s` could not be found", bindingName)
153+
}
154+
formattedVar, err := variable.Type.FormatValue(variable.Value)
155+
if err != nil {
156+
return nil, err
157+
}
158+
switch variable.Type.TypCategory {
159+
case pgtypes.TypeCategory_ArrayTypes, pgtypes.TypeCategory_StringTypes:
160+
formattedVar = pq.QuoteLiteral(formattedVar)
161+
}
162+
stmt = strings.Replace(stmt, "$"+strconv.Itoa(i+1), formattedVar, 1)
163+
}
164+
}
165+
_, rowIter, _, err = stack.runner.QueryWithBindings(ctx, stmt, nil, nil, nil)
166+
return rowIter, err
167+
}
168+
169+
// enforceInterfaceInheritance implements the interface FunctionInterface.
170+
func (iFunc InterpretedFunction) enforceInterfaceInheritance(error) {}

0 commit comments

Comments
 (0)