Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 43 additions & 54 deletions server/functions/framework/interpreted_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,25 +100,6 @@ func (iFunc InterpretedFunction) VariadicIndex() int {
return -1
}

// We need to call QueryWithBindings. GMS's QueryWithBindings
// currently wraps the result iterator in a tracking iterator which
// calls ProcessList.EndQuery() on the ctx which is used to build the
// query after the iterator is closed.
//
// Here we hack to get a subcontext that will not cause the context
// associated with the top-level query to get canceled when it gets
// passed to ProcessList.EndQuery.
//
// TODO: Fix GMS to not do this.
func HackNewSubqueryContext(ctx *sql.Context) *sql.Context {
res := *ctx
res.ApplyOpts(sql.WithPid(1<<64 - 1))
if res.Pid() == ctx.Pid() {
panic("pids matched when they shouldn't")
}
return &res
}

// QuerySingleReturn handles queries that are supposed to return a single value.
func (InterpretedFunction) QuerySingleReturn(ctx *sql.Context, stack plpgsql.InterpreterStack, stmt string, targetType *pgtypes.DoltgresType, bindings []string) (val any, err error) {
if len(bindings) > 0 {
Expand All @@ -138,44 +119,45 @@ func (InterpretedFunction) QuerySingleReturn(ctx *sql.Context, stack plpgsql.Int
stmt = strings.Replace(stmt, "$"+strconv.Itoa(i+1), formattedVar, 1)
}
}
subCtx := HackNewSubqueryContext(ctx)
sch, rowIter, _, err := stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil)
if err != nil {
return nil, err
}
rows, err := sql.RowIterToRows(ctx, rowIter)
if err != nil {
return nil, err
}
if len(sch) != 1 {
return nil, errors.New("expression does not result in a single value")
}
if len(rows) != 1 {
return nil, errors.New("expression returned multiple result sets")
}
if len(rows[0]) != 1 {
return nil, errors.New("expression returned multiple results")
}
if targetType == nil {
return rows[0][0], nil
}
fromType, ok := sch[0].Type.(*pgtypes.DoltgresType)
if !ok {
fromType, err = pgtypes.FromGmsTypeToDoltgresType(sch[0].Type)
return sql.RunInterpreted(ctx, func(subCtx *sql.Context) (any, error) {
sch, rowIter, _, err := stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil)
if err != nil {
return nil, err
}
}
castFunc := GetAssignmentCast(fromType, targetType)
if castFunc == nil {
// TODO: try I/O casting
return nil, errors.New("no valid cast for return value")
}
return castFunc(ctx, rows[0][0], targetType)
rows, err := sql.RowIterToRows(subCtx, rowIter)
if err != nil {
return nil, err
}
if len(sch) != 1 {
return nil, errors.New("expression does not result in a single value")
}
if len(rows) != 1 {
return nil, errors.New("expression returned multiple result sets")
}
if len(rows[0]) != 1 {
return nil, errors.New("expression returned multiple results")
}
if targetType == nil {
return rows[0][0], nil
}
fromType, ok := sch[0].Type.(*pgtypes.DoltgresType)
if !ok {
fromType, err = pgtypes.FromGmsTypeToDoltgresType(sch[0].Type)
if err != nil {
return nil, err
}
}
castFunc := GetAssignmentCast(fromType, targetType)
if castFunc == nil {
// TODO: try I/O casting
return nil, errors.New("no valid cast for return value")
}
return castFunc(subCtx, rows[0][0], targetType)
})
}

// QueryMultiReturn handles queries that may return multiple values over multiple rows.
func (InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.InterpreterStack, stmt string, bindings []string) (rowIter sql.RowIter, err error) {
func (InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.InterpreterStack, stmt string, bindings []string) (rows []sql.Row, err error) {
if len(bindings) > 0 {
for i, bindingName := range bindings {
variable := stack.GetVariable(bindingName)
Expand All @@ -193,9 +175,16 @@ func (InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.Inte
stmt = strings.Replace(stmt, "$"+strconv.Itoa(i+1), formattedVar, 1)
}
}
subCtx := HackNewSubqueryContext(ctx)
_, rowIter, _, err = stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil)
return rowIter, err
return sql.RunInterpreted(ctx, func(subCtx *sql.Context) ([]sql.Row, error) {
_, rowIter, _, err := stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil)
if err != nil {
return nil, err
}
// TODO: we should come up with a good way of carrying the RowIter out of the function without needing to wrap
// each call to QueryMultiReturn with RunInterpreted. For now, we don't check the returned rows, so this is
// fine.
return sql.RowIterToRows(subCtx, rowIter)
})
}

// enforceInterfaceInheritance implements the interface FunctionInterface.
Expand Down
12 changes: 3 additions & 9 deletions server/plpgsql/interpreter_logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type InterpretedFunction interface {
GetParameterNames() []string
GetReturn() *pgtypes.DoltgresType
GetStatements() []InterpreterOperation
QueryMultiReturn(ctx *sql.Context, stack InterpreterStack, stmt string, bindings []string) (rowIter sql.RowIter, err error)
QueryMultiReturn(ctx *sql.Context, stack InterpreterStack, stmt string, bindings []string) (rows []sql.Row, err error)
QuerySingleReturn(ctx *sql.Context, stack InterpreterStack, stmt string, targetType *pgtypes.DoltgresType, bindings []string) (val any, err error)
}

Expand Down Expand Up @@ -141,13 +141,10 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.Statement
return nil, err
}
} else {
rowIter, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData)
_, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData)
if err != nil {
return nil, err
}
if _, err = sql.RowIterToRows(ctx, rowIter); err != nil {
return nil, err
}
}
case OpCode_Get:
// TODO: implement
Expand Down Expand Up @@ -185,13 +182,10 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.Statement
case OpCode_InsertInto:
// TODO: implement
case OpCode_Perform:
rowIter, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData)
_, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData)
if err != nil {
return nil, err
}
if _, err = sql.RowIterToRows(ctx, rowIter); err != nil {
return nil, err
}
case OpCode_Raise:
// TODO: Use the client_min_messages config param to determine which
// notice levels to send to the client.
Expand Down
28 changes: 28 additions & 0 deletions testing/go/create_function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -808,5 +808,33 @@ $$ LANGUAGE plpgsql;`,
},
},
},
{
Name: "INSERT values from function",
SetUpScript: []string{
"CREATE TABLE test (v1 TEXT);",
`CREATE FUNCTION insertion_text() RETURNS TEXT AS $$
DECLARE
var1 TEXT;
BEGIN
var1 := 'example';
RETURN var1;
END;
$$ LANGUAGE plpgsql;
`,
},
Assertions: []ScriptTestAssertion{
{
Query: "INSERT INTO test VALUES (insertion_text()), (insertion_text());",
Expected: []sql.Row{},
},
{
Query: "SELECT * FROM test;",
Expected: []sql.Row{
{"example"},
{"example"},
},
},
},
},
})
}