@@ -100,25 +100,6 @@ func (iFunc InterpretedFunction) VariadicIndex() int {
100100 return - 1
101101}
102102
103- // We need to call QueryWithBindings. GMS's QueryWithBindings
104- // currently wraps the result iterator in a tracking iterator which
105- // calls ProcessList.EndQuery() on the ctx which is used to build the
106- // query after the iterator is closed.
107- //
108- // Here we hack to get a subcontext that will not cause the context
109- // associated with the top-level query to get canceled when it gets
110- // passed to ProcessList.EndQuery.
111- //
112- // TODO: Fix GMS to not do this.
113- func HackNewSubqueryContext (ctx * sql.Context ) * sql.Context {
114- res := * ctx
115- res .ApplyOpts (sql .WithPid (1 << 64 - 1 ))
116- if res .Pid () == ctx .Pid () {
117- panic ("pids matched when they shouldn't" )
118- }
119- return & res
120- }
121-
122103// QuerySingleReturn handles queries that are supposed to return a single value.
123104func (InterpretedFunction ) QuerySingleReturn (ctx * sql.Context , stack plpgsql.InterpreterStack , stmt string , targetType * pgtypes.DoltgresType , bindings []string ) (val any , err error ) {
124105 if len (bindings ) > 0 {
@@ -138,44 +119,45 @@ func (InterpretedFunction) QuerySingleReturn(ctx *sql.Context, stack plpgsql.Int
138119 stmt = strings .Replace (stmt , "$" + strconv .Itoa (i + 1 ), formattedVar , 1 )
139120 }
140121 }
141- subCtx := HackNewSubqueryContext (ctx )
142- sch , rowIter , _ , err := stack .Runner ().QueryWithBindings (subCtx , stmt , nil , nil , nil )
143- if err != nil {
144- return nil , err
145- }
146- rows , err := sql .RowIterToRows (ctx , rowIter )
147- if err != nil {
148- return nil , err
149- }
150- if len (sch ) != 1 {
151- return nil , errors .New ("expression does not result in a single value" )
152- }
153- if len (rows ) != 1 {
154- return nil , errors .New ("expression returned multiple result sets" )
155- }
156- if len (rows [0 ]) != 1 {
157- return nil , errors .New ("expression returned multiple results" )
158- }
159- if targetType == nil {
160- return rows [0 ][0 ], nil
161- }
162- fromType , ok := sch [0 ].Type .(* pgtypes.DoltgresType )
163- if ! ok {
164- fromType , err = pgtypes .FromGmsTypeToDoltgresType (sch [0 ].Type )
122+ return sql .RunInterpreted (ctx , func (subCtx * sql.Context ) (any , error ) {
123+ sch , rowIter , _ , err := stack .Runner ().QueryWithBindings (subCtx , stmt , nil , nil , nil )
165124 if err != nil {
166125 return nil , err
167126 }
168- }
169- castFunc := GetAssignmentCast (fromType , targetType )
170- if castFunc == nil {
171- // TODO: try I/O casting
172- return nil , errors .New ("no valid cast for return value" )
173- }
174- return castFunc (ctx , rows [0 ][0 ], targetType )
127+ rows , err := sql .RowIterToRows (subCtx , rowIter )
128+ if err != nil {
129+ return nil , err
130+ }
131+ if len (sch ) != 1 {
132+ return nil , errors .New ("expression does not result in a single value" )
133+ }
134+ if len (rows ) != 1 {
135+ return nil , errors .New ("expression returned multiple result sets" )
136+ }
137+ if len (rows [0 ]) != 1 {
138+ return nil , errors .New ("expression returned multiple results" )
139+ }
140+ if targetType == nil {
141+ return rows [0 ][0 ], nil
142+ }
143+ fromType , ok := sch [0 ].Type .(* pgtypes.DoltgresType )
144+ if ! ok {
145+ fromType , err = pgtypes .FromGmsTypeToDoltgresType (sch [0 ].Type )
146+ if err != nil {
147+ return nil , err
148+ }
149+ }
150+ castFunc := GetAssignmentCast (fromType , targetType )
151+ if castFunc == nil {
152+ // TODO: try I/O casting
153+ return nil , errors .New ("no valid cast for return value" )
154+ }
155+ return castFunc (subCtx , rows [0 ][0 ], targetType )
156+ })
175157}
176158
177159// QueryMultiReturn handles queries that may return multiple values over multiple rows.
178- func (InterpretedFunction ) QueryMultiReturn (ctx * sql.Context , stack plpgsql.InterpreterStack , stmt string , bindings []string ) (rowIter sql.RowIter , err error ) {
160+ func (InterpretedFunction ) QueryMultiReturn (ctx * sql.Context , stack plpgsql.InterpreterStack , stmt string , bindings []string ) (rows [] sql.Row , err error ) {
179161 if len (bindings ) > 0 {
180162 for i , bindingName := range bindings {
181163 variable := stack .GetVariable (bindingName )
@@ -193,9 +175,16 @@ func (InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.Inte
193175 stmt = strings .Replace (stmt , "$" + strconv .Itoa (i + 1 ), formattedVar , 1 )
194176 }
195177 }
196- subCtx := HackNewSubqueryContext (ctx )
197- _ , rowIter , _ , err = stack .Runner ().QueryWithBindings (subCtx , stmt , nil , nil , nil )
198- return rowIter , err
178+ return sql .RunInterpreted (ctx , func (subCtx * sql.Context ) ([]sql.Row , error ) {
179+ _ , rowIter , _ , err := stack .Runner ().QueryWithBindings (subCtx , stmt , nil , nil , nil )
180+ if err != nil {
181+ return nil , err
182+ }
183+ // TODO: we should come up with a good way of carrying the RowIter out of the function without needing to wrap
184+ // each call to QueryMultiReturn with RunInterpreted. For now, we don't check the returned rows, so this is
185+ // fine.
186+ return sql .RowIterToRows (subCtx , rowIter )
187+ })
199188}
200189
201190// enforceInterfaceInheritance implements the interface FunctionInterface.
0 commit comments