Skip to content

Commit 32e6835

Browse files
committed
fix: e where SQL parsing errors occurred in custom database node
1 parent 2ffd7a8 commit 32e6835

File tree

2 files changed

+39
-29
lines changed

2 files changed

+39
-29
lines changed

backend/domain/workflow/internal/nodes/database/customsql.go

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"errors"
2222
"fmt"
2323
"reflect"
24+
"regexp"
2425
"strconv"
2526
"strings"
2627

@@ -34,6 +35,8 @@ import (
3435
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
3536
)
3637

38+
var singleQuotesStringRegexp = regexp.MustCompile("[`']\\{\\{([a-zA-Z_][a-zA-Z0-9_]*(?:\\.\\w+|\\[\\d+\\])*)+\\}\\}[`']")
39+
3740
type CustomSQLConfig struct {
3841
DatabaseInfoID int64
3942
SQLTemplate string
@@ -111,46 +114,54 @@ func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[strin
111114
return nil, err
112115
}
113116

117+
templateParts := nodes.ParseTemplate(singleQuotesStringRegexp.ReplaceAllString(c.sqlTemplate, "?"))
114118
templateSQL := ""
115-
templateParts := nodes.ParseTemplate(c.sqlTemplate)
116-
sqlParams := make([]database.SQLParam, 0, len(templateParts))
117-
var nilError = errors.New("field is nil")
118-
for _, templatePart := range templateParts {
119-
if !templatePart.IsVariable {
120-
templateSQL += templatePart.Value
121-
continue
122-
}
119+
if len(templateParts) > 0 {
120+
for _, templatePart := range templateParts {
121+
if !templatePart.IsVariable {
122+
templateSQL += templatePart.Value
123+
continue
124+
}
123125

124-
templateSQL += "?"
125-
val, err := templatePart.Render(inputBytes, nodes.WithNilRender(func() (string, error) {
126-
return "", nilError
127-
}),
128-
nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
126+
val, err := templatePart.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
129127
b := val.(bool)
130128
if b {
131129
return "1", nil
132130
}
133131
return "0", nil
134132
}))
135-
136-
if err != nil {
137-
if !errors.Is(err, nilError) {
133+
if err != nil {
138134
return nil, err
139135
}
140-
sqlParams = append(sqlParams, database.SQLParam{
141-
IsNull: true,
142-
})
143-
} else {
144-
sqlParams = append(sqlParams, database.SQLParam{
145-
Value: val,
146-
IsNull: false,
147-
})
148-
}
136+
templateSQL += val
149137

138+
}
139+
} else {
140+
templateSQL += templateParts[0].Value
150141
}
151142

152-
// replace sql template '?' to ?
153-
templateSQL = strings.Replace(templateSQL, "'?'", "?", -1)
143+
sqlParamStrings := singleQuotesStringRegexp.FindAllString(c.sqlTemplate, -1)
144+
sqlParams := make([]database.SQLParam, 0, len(sqlParamStrings))
145+
for _, s := range sqlParamStrings {
146+
parts := nodes.ParseTemplate(s)
147+
for _, part := range parts {
148+
if part.IsVariable {
149+
val, err := part.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
150+
b := val.(bool)
151+
if b {
152+
return "1", nil
153+
}
154+
return "0", nil
155+
}))
156+
if err != nil {
157+
return nil, err
158+
}
159+
sqlParams = append(sqlParams, database.SQLParam{
160+
Value: val,
161+
})
162+
}
163+
}
164+
}
154165
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
155166
req.SQL = templateSQL
156167
req.Params = sqlParams

backend/domain/workflow/internal/nodes/database/customsql_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,11 @@ func TestCustomSQL_Execute(t *testing.T) {
6161
validate: func(req *database.CustomSQLRequest) {
6262
assert.Equal(t, int64(111), req.DatabaseInfoID)
6363
ps := []database.SQLParam{
64-
{Value: "v1_value"},
6564
{Value: "v2_value"},
6665
{Value: "v3_value"},
6766
}
6867
assert.Equal(t, ps, req.Params)
69-
assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL)
68+
assert.Equal(t, "select * from v1 where v1 = v1_value and v2 = ? and v3 = ?", req.SQL)
7069
},
7170
}
7271

0 commit comments

Comments
 (0)