@@ -102,23 +102,10 @@ func (iFunc InterpretedFunction) VariadicIndex() int {
102102}
103103
104104// QuerySingleReturn handles queries that are supposed to return a single value.
105- func (InterpretedFunction ) QuerySingleReturn (ctx * sql.Context , stack plpgsql.InterpreterStack , stmt string , targetType * pgtypes.DoltgresType , bindings []string ) (val any , err error ) {
106- if len (bindings ) > 0 {
107- for i , bindingName := range bindings {
108- variable := stack .GetVariable (bindingName )
109- if variable .Type == nil {
110- return nil , fmt .Errorf ("variable `%s` could not be found" , bindingName )
111- }
112- formattedVar , err := variable .Type .FormatValue (* variable .Value )
113- if err != nil {
114- return nil , err
115- }
116- switch variable .Type .TypCategory {
117- case pgtypes .TypeCategory_ArrayTypes , pgtypes .TypeCategory_StringTypes :
118- formattedVar = pq .QuoteLiteral (formattedVar )
119- }
120- stmt = strings .Replace (stmt , "$" + strconv .Itoa (i + 1 ), formattedVar , 1 )
121- }
105+ func (iFunc InterpretedFunction ) QuerySingleReturn (ctx * sql.Context , stack plpgsql.InterpreterStack , stmt string , targetType * pgtypes.DoltgresType , bindings []string ) (val any , err error ) {
106+ stmt , _ , err = iFunc .ApplyBindings (ctx , stack , stmt , bindings , true )
107+ if err != nil {
108+ return nil , err
122109 }
123110 return sql .RunInterpreted (ctx , func (subCtx * sql.Context ) (any , error ) {
124111 sch , rowIter , _ , err := stack .Runner ().QueryWithBindings (subCtx , stmt , nil , nil , nil )
@@ -177,23 +164,10 @@ func (InterpretedFunction) QuerySingleReturn(ctx *sql.Context, stack plpgsql.Int
177164}
178165
179166// QueryMultiReturn handles queries that may return multiple values over multiple rows.
180- func (InterpretedFunction ) QueryMultiReturn (ctx * sql.Context , stack plpgsql.InterpreterStack , stmt string , bindings []string ) (rows []sql.Row , err error ) {
181- if len (bindings ) > 0 {
182- for i , bindingName := range bindings {
183- variable := stack .GetVariable (bindingName )
184- if variable .Type == nil {
185- return nil , fmt .Errorf ("variable `%s` could not be found" , bindingName )
186- }
187- formattedVar , err := variable .Type .FormatValue (* variable .Value )
188- if err != nil {
189- return nil , err
190- }
191- switch variable .Type .TypCategory {
192- case pgtypes .TypeCategory_ArrayTypes , pgtypes .TypeCategory_StringTypes :
193- formattedVar = pq .QuoteLiteral (formattedVar )
194- }
195- stmt = strings .Replace (stmt , "$" + strconv .Itoa (i + 1 ), formattedVar , 1 )
196- }
167+ func (iFunc InterpretedFunction ) QueryMultiReturn (ctx * sql.Context , stack plpgsql.InterpreterStack , stmt string , bindings []string ) (rows []sql.Row , err error ) {
168+ stmt , _ , err = iFunc .ApplyBindings (ctx , stack , stmt , bindings , true )
169+ if err != nil {
170+ return nil , err
197171 }
198172 return sql .RunInterpreted (ctx , func (subCtx * sql.Context ) ([]sql.Row , error ) {
199173 _ , rowIter , _ , err := stack .Runner ().QueryWithBindings (subCtx , stmt , nil , nil , nil )
@@ -207,5 +181,43 @@ func (InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.Inte
207181 })
208182}
209183
184+ // ApplyBindings applies the given bindings to the statement. If `varFound` is false, then the error will be state that
185+ // the variable was not found (which means the error may be ignored if you're only concerned with finding a variable).
186+ // If `varFound` is true, then the error is related to formatting the variable. `enforceType` adds casting and quotes to
187+ // ensure that the value is correctly represented in the string.
188+ func (InterpretedFunction ) ApplyBindings (ctx * sql.Context , stack plpgsql.InterpreterStack , stmt string , bindings []string , enforceType bool ) (newStmt string , varFound bool , err error ) {
189+ if len (bindings ) == 0 {
190+ return stmt , false , nil
191+ }
192+ newStmt = stmt
193+ for i , bindingName := range bindings {
194+ variable := stack .GetVariable (bindingName )
195+ if variable .Type == nil {
196+ return newStmt , false , fmt .Errorf ("variable `%s` could not be found" , bindingName )
197+ }
198+ var formattedVar string
199+ if * variable .Value != nil {
200+ formattedVar , err = variable .Type .FormatValue (* variable .Value )
201+ if err != nil {
202+ return newStmt , true , err
203+ }
204+ if enforceType {
205+ switch variable .Type .TypCategory {
206+ case pgtypes .TypeCategory_ArrayTypes , pgtypes .TypeCategory_DateTimeTypes , pgtypes .TypeCategory_StringTypes :
207+ formattedVar = pq .QuoteLiteral (formattedVar )
208+ }
209+ }
210+ } else {
211+ formattedVar = "NULL"
212+ }
213+ if enforceType {
214+ newStmt = strings .Replace (newStmt , "$" + strconv .Itoa (i + 1 ), fmt .Sprintf (`(%s)::%s` , formattedVar , variable .Type .String ()), 1 )
215+ } else {
216+ newStmt = strings .Replace (newStmt , "$" + strconv .Itoa (i + 1 ), formattedVar , 1 )
217+ }
218+ }
219+ return newStmt , true , nil
220+ }
221+
210222// enforceInterfaceInheritance implements the interface FunctionInterface.
211223func (iFunc InterpretedFunction ) enforceInterfaceInheritance (error ) {}
0 commit comments