@@ -21,8 +21,8 @@ import (
2121 "errors"
2222 "fmt"
2323 "reflect"
24+ "regexp"
2425 "strconv"
25- "strings"
2626
2727 "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
2828 crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
@@ -34,6 +34,8 @@ import (
3434 "github.com/coze-dev/coze-studio/backend/pkg/sonic"
3535)
3636
37+ var singleQuotesStringRegexp = regexp .MustCompile ("[`']\\ {\\ {([a-zA-Z_][a-zA-Z0-9_]*(?:\\ .\\ w+|\\ [\\ d+\\ ])*)+\\ }\\ }[`']" )
38+
3739type CustomSQLConfig struct {
3840 DatabaseInfoID int64
3941 SQLTemplate string
@@ -111,47 +113,60 @@ func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[strin
111113 return nil , err
112114 }
113115
116+ templateParts := nodes .ParseTemplate (singleQuotesStringRegexp .ReplaceAllString (c .sqlTemplate , "?" ))
114117 templateSQL := ""
115- templateParts := nodes .ParseTemplate (c .sqlTemplate )
116- sqlParams := make ([]database.SQLParam , 0 , len (templateParts ))
117- var nilError = errors .New ("field is nil" )
118- for _ , templatePart := range templateParts {
119- if ! templatePart .IsVariable {
120- templateSQL += templatePart .Value
121- continue
122- }
118+ if len (templateParts ) > 0 {
119+ if len (templateParts ) == 0 {
120+ templateSQL = templateParts [0 ].Value
121+ } else {
122+ for _ , templatePart := range templateParts {
123+ if ! templatePart .IsVariable {
124+ templateSQL += templatePart .Value
125+ continue
126+ }
123127
124- templateSQL += "?"
125- val , err := templatePart .Render (inputBytes , nodes .WithNilRender (func () (string , error ) {
126- return "" , nilError
127- }),
128- nodes .WithCustomRender (reflect .TypeOf (false ), func (val any ) (string , error ) {
129- b := val .(bool )
130- if b {
131- return "1" , nil
128+ val , err := templatePart .Render (inputBytes , nodes .WithCustomRender (reflect .TypeOf (false ), func (val any ) (string , error ) {
129+ b := val .(bool )
130+ if b {
131+ return "1" , nil
132+ }
133+ return "0" , nil
134+ }))
135+ if err != nil {
136+ return nil , err
132137 }
133- return "0" , nil
134- }))
138+ templateSQL += val
135139
136- if err != nil {
137- if ! errors .Is (err , nilError ) {
138- return nil , err
139140 }
140- sqlParams = append (sqlParams , database.SQLParam {
141- IsNull : true ,
142- })
143- } else {
144- sqlParams = append (sqlParams , database.SQLParam {
145- Value : val ,
146- IsNull : false ,
147- })
148141 }
149142
143+ } else {
144+ return nil , fmt .Errorf ("parse template invalid" )
145+ }
146+
147+ sqlParamStrings := singleQuotesStringRegexp .FindAllString (c .sqlTemplate , - 1 )
148+ sqlParams := make ([]database.SQLParam , 0 , len (sqlParamStrings ))
149+ for _ , s := range sqlParamStrings {
150+ parts := nodes .ParseTemplate (s )
151+ for _ , part := range parts {
152+ if part .IsVariable {
153+ val , err := part .Render (inputBytes , nodes .WithCustomRender (reflect .TypeOf (false ), func (val any ) (string , error ) {
154+ b := val .(bool )
155+ if b {
156+ return "1" , nil
157+ }
158+ return "0" , nil
159+ }))
160+ if err != nil {
161+ return nil , err
162+ }
163+ sqlParams = append (sqlParams , database.SQLParam {
164+ Value : val ,
165+ })
166+ }
167+ }
150168 }
151169
152- // replace sql template '?' to ?
153- templateSQL = strings .Replace (templateSQL , "'?'" , "?" , - 1 )
154- templateSQL = strings .Replace (templateSQL , "`?`" , "?" , - 1 )
155170 req .SQL = templateSQL
156171 req .Params = sqlParams
157172 response , err := crossdatabase .DefaultSVC ().Execute (ctx , req )
0 commit comments