diff --git a/server/functions/framework/interpreted_function.go b/server/functions/framework/interpreted_function.go index 1f116936a3..9defc7b86e 100644 --- a/server/functions/framework/interpreted_function.go +++ b/server/functions/framework/interpreted_function.go @@ -100,25 +100,6 @@ func (iFunc InterpretedFunction) VariadicIndex() int { return -1 } -// We need to call QueryWithBindings. GMS's QueryWithBindings -// currently wraps the result iterator in a tracking iterator which -// calls ProcessList.EndQuery() on the ctx which is used to build the -// query after the iterator is closed. -// -// Here we hack to get a subcontext that will not cause the context -// associated with the top-level query to get canceled when it gets -// passed to ProcessList.EndQuery. -// -// TODO: Fix GMS to not do this. -func HackNewSubqueryContext(ctx *sql.Context) *sql.Context { - res := *ctx - res.ApplyOpts(sql.WithPid(1<<64 - 1)) - if res.Pid() == ctx.Pid() { - panic("pids matched when they shouldn't") - } - return &res -} - // QuerySingleReturn handles queries that are supposed to return a single value. func (InterpretedFunction) QuerySingleReturn(ctx *sql.Context, stack plpgsql.InterpreterStack, stmt string, targetType *pgtypes.DoltgresType, bindings []string) (val any, err error) { if len(bindings) > 0 { @@ -138,44 +119,45 @@ func (InterpretedFunction) QuerySingleReturn(ctx *sql.Context, stack plpgsql.Int stmt = strings.Replace(stmt, "$"+strconv.Itoa(i+1), formattedVar, 1) } } - subCtx := HackNewSubqueryContext(ctx) - sch, rowIter, _, err := stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil) - if err != nil { - return nil, err - } - rows, err := sql.RowIterToRows(ctx, rowIter) - if err != nil { - return nil, err - } - if len(sch) != 1 { - return nil, errors.New("expression does not result in a single value") - } - if len(rows) != 1 { - return nil, errors.New("expression returned multiple result sets") - } - if len(rows[0]) != 1 { - return nil, errors.New("expression returned multiple results") - } - if targetType == nil { - return rows[0][0], nil - } - fromType, ok := sch[0].Type.(*pgtypes.DoltgresType) - if !ok { - fromType, err = pgtypes.FromGmsTypeToDoltgresType(sch[0].Type) + return sql.RunInterpreted(ctx, func(subCtx *sql.Context) (any, error) { + sch, rowIter, _, err := stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil) if err != nil { return nil, err } - } - castFunc := GetAssignmentCast(fromType, targetType) - if castFunc == nil { - // TODO: try I/O casting - return nil, errors.New("no valid cast for return value") - } - return castFunc(ctx, rows[0][0], targetType) + rows, err := sql.RowIterToRows(subCtx, rowIter) + if err != nil { + return nil, err + } + if len(sch) != 1 { + return nil, errors.New("expression does not result in a single value") + } + if len(rows) != 1 { + return nil, errors.New("expression returned multiple result sets") + } + if len(rows[0]) != 1 { + return nil, errors.New("expression returned multiple results") + } + if targetType == nil { + return rows[0][0], nil + } + fromType, ok := sch[0].Type.(*pgtypes.DoltgresType) + if !ok { + fromType, err = pgtypes.FromGmsTypeToDoltgresType(sch[0].Type) + if err != nil { + return nil, err + } + } + castFunc := GetAssignmentCast(fromType, targetType) + if castFunc == nil { + // TODO: try I/O casting + return nil, errors.New("no valid cast for return value") + } + return castFunc(subCtx, rows[0][0], targetType) + }) } // QueryMultiReturn handles queries that may return multiple values over multiple rows. -func (InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.InterpreterStack, stmt string, bindings []string) (rowIter sql.RowIter, err error) { +func (InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.InterpreterStack, stmt string, bindings []string) (rows []sql.Row, err error) { if len(bindings) > 0 { for i, bindingName := range bindings { variable := stack.GetVariable(bindingName) @@ -193,9 +175,16 @@ func (InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.Inte stmt = strings.Replace(stmt, "$"+strconv.Itoa(i+1), formattedVar, 1) } } - subCtx := HackNewSubqueryContext(ctx) - _, rowIter, _, err = stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil) - return rowIter, err + return sql.RunInterpreted(ctx, func(subCtx *sql.Context) ([]sql.Row, error) { + _, rowIter, _, err := stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil) + if err != nil { + return nil, err + } + // TODO: we should come up with a good way of carrying the RowIter out of the function without needing to wrap + // each call to QueryMultiReturn with RunInterpreted. For now, we don't check the returned rows, so this is + // fine. + return sql.RowIterToRows(subCtx, rowIter) + }) } // enforceInterfaceInheritance implements the interface FunctionInterface. diff --git a/server/plpgsql/interpreter_logic.go b/server/plpgsql/interpreter_logic.go index 4366f9905d..634e799777 100644 --- a/server/plpgsql/interpreter_logic.go +++ b/server/plpgsql/interpreter_logic.go @@ -37,7 +37,7 @@ type InterpretedFunction interface { GetParameterNames() []string GetReturn() *pgtypes.DoltgresType GetStatements() []InterpreterOperation - QueryMultiReturn(ctx *sql.Context, stack InterpreterStack, stmt string, bindings []string) (rowIter sql.RowIter, err error) + QueryMultiReturn(ctx *sql.Context, stack InterpreterStack, stmt string, bindings []string) (rows []sql.Row, err error) QuerySingleReturn(ctx *sql.Context, stack InterpreterStack, stmt string, targetType *pgtypes.DoltgresType, bindings []string) (val any, err error) } @@ -141,13 +141,10 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.Statement return nil, err } } else { - rowIter, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData) + _, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData) if err != nil { return nil, err } - if _, err = sql.RowIterToRows(ctx, rowIter); err != nil { - return nil, err - } } case OpCode_Get: // TODO: implement @@ -185,13 +182,10 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.Statement case OpCode_InsertInto: // TODO: implement case OpCode_Perform: - rowIter, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData) + _, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData) if err != nil { return nil, err } - if _, err = sql.RowIterToRows(ctx, rowIter); err != nil { - return nil, err - } case OpCode_Raise: // TODO: Use the client_min_messages config param to determine which // notice levels to send to the client. diff --git a/testing/go/create_function_test.go b/testing/go/create_function_test.go index 70f1ae1b48..819ef153be 100644 --- a/testing/go/create_function_test.go +++ b/testing/go/create_function_test.go @@ -808,5 +808,33 @@ $$ LANGUAGE plpgsql;`, }, }, }, + { + Name: "INSERT values from function", + SetUpScript: []string{ + "CREATE TABLE test (v1 TEXT);", + `CREATE FUNCTION insertion_text() RETURNS TEXT AS $$ + DECLARE + var1 TEXT; + BEGIN + var1 := 'example'; + RETURN var1; + END; + $$ LANGUAGE plpgsql; + `, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "INSERT INTO test VALUES (insertion_text()), (insertion_text());", + Expected: []sql.Row{}, + }, + { + Query: "SELECT * FROM test;", + Expected: []sql.Row{ + {"example"}, + {"example"}, + }, + }, + }, + }, }) }