Skip to content

Commit c263c7e

Browse files
committed
Fixed interpretation calls within DML
1 parent a794bac commit c263c7e

File tree

3 files changed

+74
-63
lines changed

3 files changed

+74
-63
lines changed

server/functions/framework/interpreted_function.go

Lines changed: 43 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -100,25 +100,6 @@ func (iFunc InterpretedFunction) VariadicIndex() int {
100100
return -1
101101
}
102102

103-
// We need to call QueryWithBindings. GMS's QueryWithBindings
104-
// currently wraps the result iterator in a tracking iterator which
105-
// calls ProcessList.EndQuery() on the ctx which is used to build the
106-
// query after the iterator is closed.
107-
//
108-
// Here we hack to get a subcontext that will not cause the context
109-
// associated with the top-level query to get canceled when it gets
110-
// passed to ProcessList.EndQuery.
111-
//
112-
// TODO: Fix GMS to not do this.
113-
func HackNewSubqueryContext(ctx *sql.Context) *sql.Context {
114-
res := *ctx
115-
res.ApplyOpts(sql.WithPid(1<<64 - 1))
116-
if res.Pid() == ctx.Pid() {
117-
panic("pids matched when they shouldn't")
118-
}
119-
return &res
120-
}
121-
122103
// QuerySingleReturn handles queries that are supposed to return a single value.
123104
func (InterpretedFunction) QuerySingleReturn(ctx *sql.Context, stack plpgsql.InterpreterStack, stmt string, targetType *pgtypes.DoltgresType, bindings []string) (val any, err error) {
124105
if len(bindings) > 0 {
@@ -138,44 +119,45 @@ func (InterpretedFunction) QuerySingleReturn(ctx *sql.Context, stack plpgsql.Int
138119
stmt = strings.Replace(stmt, "$"+strconv.Itoa(i+1), formattedVar, 1)
139120
}
140121
}
141-
subCtx := HackNewSubqueryContext(ctx)
142-
sch, rowIter, _, err := stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil)
143-
if err != nil {
144-
return nil, err
145-
}
146-
rows, err := sql.RowIterToRows(ctx, rowIter)
147-
if err != nil {
148-
return nil, err
149-
}
150-
if len(sch) != 1 {
151-
return nil, errors.New("expression does not result in a single value")
152-
}
153-
if len(rows) != 1 {
154-
return nil, errors.New("expression returned multiple result sets")
155-
}
156-
if len(rows[0]) != 1 {
157-
return nil, errors.New("expression returned multiple results")
158-
}
159-
if targetType == nil {
160-
return rows[0][0], nil
161-
}
162-
fromType, ok := sch[0].Type.(*pgtypes.DoltgresType)
163-
if !ok {
164-
fromType, err = pgtypes.FromGmsTypeToDoltgresType(sch[0].Type)
122+
return sql.RunInterpreted(ctx, func(subCtx *sql.Context) (any, error) {
123+
sch, rowIter, _, err := stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil)
165124
if err != nil {
166125
return nil, err
167126
}
168-
}
169-
castFunc := GetAssignmentCast(fromType, targetType)
170-
if castFunc == nil {
171-
// TODO: try I/O casting
172-
return nil, errors.New("no valid cast for return value")
173-
}
174-
return castFunc(ctx, rows[0][0], targetType)
127+
rows, err := sql.RowIterToRows(subCtx, rowIter)
128+
if err != nil {
129+
return nil, err
130+
}
131+
if len(sch) != 1 {
132+
return nil, errors.New("expression does not result in a single value")
133+
}
134+
if len(rows) != 1 {
135+
return nil, errors.New("expression returned multiple result sets")
136+
}
137+
if len(rows[0]) != 1 {
138+
return nil, errors.New("expression returned multiple results")
139+
}
140+
if targetType == nil {
141+
return rows[0][0], nil
142+
}
143+
fromType, ok := sch[0].Type.(*pgtypes.DoltgresType)
144+
if !ok {
145+
fromType, err = pgtypes.FromGmsTypeToDoltgresType(sch[0].Type)
146+
if err != nil {
147+
return nil, err
148+
}
149+
}
150+
castFunc := GetAssignmentCast(fromType, targetType)
151+
if castFunc == nil {
152+
// TODO: try I/O casting
153+
return nil, errors.New("no valid cast for return value")
154+
}
155+
return castFunc(subCtx, rows[0][0], targetType)
156+
})
175157
}
176158

177159
// QueryMultiReturn handles queries that may return multiple values over multiple rows.
178-
func (InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.InterpreterStack, stmt string, bindings []string) (rowIter sql.RowIter, err error) {
160+
func (InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.InterpreterStack, stmt string, bindings []string) (rows []sql.Row, err error) {
179161
if len(bindings) > 0 {
180162
for i, bindingName := range bindings {
181163
variable := stack.GetVariable(bindingName)
@@ -193,9 +175,16 @@ func (InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.Inte
193175
stmt = strings.Replace(stmt, "$"+strconv.Itoa(i+1), formattedVar, 1)
194176
}
195177
}
196-
subCtx := HackNewSubqueryContext(ctx)
197-
_, rowIter, _, err = stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil)
198-
return rowIter, err
178+
return sql.RunInterpreted(ctx, func(subCtx *sql.Context) ([]sql.Row, error) {
179+
_, rowIter, _, err := stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil)
180+
if err != nil {
181+
return nil, err
182+
}
183+
// TODO: we should come up with a good way of carrying the RowIter out of the function without needing to wrap
184+
// each call to QueryMultiReturn with RunInterpreted. For now, we don't check the returned rows, so this is
185+
// fine.
186+
return sql.RowIterToRows(subCtx, rowIter)
187+
})
199188
}
200189

201190
// enforceInterfaceInheritance implements the interface FunctionInterface.

server/plpgsql/interpreter_logic.go

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ type InterpretedFunction interface {
3737
GetParameterNames() []string
3838
GetReturn() *pgtypes.DoltgresType
3939
GetStatements() []InterpreterOperation
40-
QueryMultiReturn(ctx *sql.Context, stack InterpreterStack, stmt string, bindings []string) (rowIter sql.RowIter, err error)
40+
QueryMultiReturn(ctx *sql.Context, stack InterpreterStack, stmt string, bindings []string) (rows []sql.Row, err error)
4141
QuerySingleReturn(ctx *sql.Context, stack InterpreterStack, stmt string, targetType *pgtypes.DoltgresType, bindings []string) (val any, err error)
4242
}
4343

@@ -141,13 +141,10 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.Statement
141141
return nil, err
142142
}
143143
} else {
144-
rowIter, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData)
144+
_, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData)
145145
if err != nil {
146146
return nil, err
147147
}
148-
if _, err = sql.RowIterToRows(ctx, rowIter); err != nil {
149-
return nil, err
150-
}
151148
}
152149
case OpCode_Get:
153150
// TODO: implement
@@ -185,13 +182,10 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.Statement
185182
case OpCode_InsertInto:
186183
// TODO: implement
187184
case OpCode_Perform:
188-
rowIter, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData)
185+
_, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData)
189186
if err != nil {
190187
return nil, err
191188
}
192-
if _, err = sql.RowIterToRows(ctx, rowIter); err != nil {
193-
return nil, err
194-
}
195189
case OpCode_Raise:
196190
// TODO: Use the client_min_messages config param to determine which
197191
// notice levels to send to the client.

testing/go/create_function_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,5 +808,33 @@ $$ LANGUAGE plpgsql;`,
808808
},
809809
},
810810
},
811+
{
812+
Name: "INSERT values from function",
813+
SetUpScript: []string{
814+
"CREATE TABLE test (v1 TEXT);",
815+
`CREATE FUNCTION insertion_text() RETURNS TEXT AS $$
816+
DECLARE
817+
var1 TEXT;
818+
BEGIN
819+
var1 := 'example';
820+
RETURN var1;
821+
END;
822+
$$ LANGUAGE plpgsql;
823+
`,
824+
},
825+
Assertions: []ScriptTestAssertion{
826+
{
827+
Query: "INSERT INTO test VALUES (insertion_text()), (insertion_text());",
828+
Expected: []sql.Row{},
829+
},
830+
{
831+
Query: "SELECT * FROM test;",
832+
Expected: []sql.Row{
833+
{"example"},
834+
{"example"},
835+
},
836+
},
837+
},
838+
},
811839
})
812840
}

0 commit comments

Comments
 (0)