Skip to content

Commit a76d6cf

Browse files
committed
fix: e where SQL parsing errors occurred in custom database node
1 parent 2ffd7a8 commit a76d6cf

File tree

2 files changed

+52
-36
lines changed

2 files changed

+52
-36
lines changed

backend/domain/workflow/internal/nodes/database/customsql.go

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import (
2121
"errors"
2222
"fmt"
2323
"reflect"
24+
"regexp"
2425
"strconv"
25-
"strings"
2626

2727
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
2828
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
@@ -34,6 +34,8 @@ import (
3434
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
3535
)
3636

37+
var singleQuotesStringRegexp = regexp.MustCompile("[`']\\{\\{([a-zA-Z_][a-zA-Z0-9_]*(?:\\.\\w+|\\[\\d+\\])*)+\\}\\}[`']")
38+
3739
type CustomSQLConfig struct {
3840
DatabaseInfoID int64
3941
SQLTemplate string
@@ -111,47 +113,60 @@ func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[strin
111113
return nil, err
112114
}
113115

116+
templateParts := nodes.ParseTemplate(singleQuotesStringRegexp.ReplaceAllString(c.sqlTemplate, "?"))
114117
templateSQL := ""
115-
templateParts := nodes.ParseTemplate(c.sqlTemplate)
116-
sqlParams := make([]database.SQLParam, 0, len(templateParts))
117-
var nilError = errors.New("field is nil")
118-
for _, templatePart := range templateParts {
119-
if !templatePart.IsVariable {
120-
templateSQL += templatePart.Value
121-
continue
122-
}
118+
if len(templateParts) > 0 {
119+
if len(templateParts) == 0 {
120+
templateSQL = templateParts[0].Value
121+
} else {
122+
for _, templatePart := range templateParts {
123+
if !templatePart.IsVariable {
124+
templateSQL += templatePart.Value
125+
continue
126+
}
123127

124-
templateSQL += "?"
125-
val, err := templatePart.Render(inputBytes, nodes.WithNilRender(func() (string, error) {
126-
return "", nilError
127-
}),
128-
nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
129-
b := val.(bool)
130-
if b {
131-
return "1", nil
128+
val, err := templatePart.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
129+
b := val.(bool)
130+
if b {
131+
return "1", nil
132+
}
133+
return "0", nil
134+
}))
135+
if err != nil {
136+
return nil, err
132137
}
133-
return "0", nil
134-
}))
138+
templateSQL += val
135139

136-
if err != nil {
137-
if !errors.Is(err, nilError) {
138-
return nil, err
139140
}
140-
sqlParams = append(sqlParams, database.SQLParam{
141-
IsNull: true,
142-
})
143-
} else {
144-
sqlParams = append(sqlParams, database.SQLParam{
145-
Value: val,
146-
IsNull: false,
147-
})
148141
}
149142

143+
} else {
144+
return nil, fmt.Errorf("parse template invalid")
145+
}
146+
147+
sqlParamStrings := singleQuotesStringRegexp.FindAllString(c.sqlTemplate, -1)
148+
sqlParams := make([]database.SQLParam, 0, len(sqlParamStrings))
149+
for _, s := range sqlParamStrings {
150+
parts := nodes.ParseTemplate(s)
151+
for _, part := range parts {
152+
if part.IsVariable {
153+
val, err := part.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
154+
b := val.(bool)
155+
if b {
156+
return "1", nil
157+
}
158+
return "0", nil
159+
}))
160+
if err != nil {
161+
return nil, err
162+
}
163+
sqlParams = append(sqlParams, database.SQLParam{
164+
Value: val,
165+
})
166+
}
167+
}
150168
}
151169

152-
// replace sql template '?' to ?
153-
templateSQL = strings.Replace(templateSQL, "'?'", "?", -1)
154-
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
155170
req.SQL = templateSQL
156171
req.Params = sqlParams
157172
response, err := crossdatabase.DefaultSVC().Execute(ctx, req)

backend/domain/workflow/internal/nodes/database/customsql_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ func TestCustomSQL_Execute(t *testing.T) {
6161
validate: func(req *database.CustomSQLRequest) {
6262
assert.Equal(t, int64(111), req.DatabaseInfoID)
6363
ps := []database.SQLParam{
64-
{Value: "v1_value"},
6564
{Value: "v2_value"},
6665
{Value: "v3_value"},
66+
{Value: "1"},
6767
}
6868
assert.Equal(t, ps, req.Params)
69-
assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL)
69+
assert.Equal(t, "select * from v1 where v1 = v1_value and v2 = ? and v3 = ? and v4 = ?", req.SQL)
7070
},
7171
}
7272

@@ -86,7 +86,7 @@ func TestCustomSQL_Execute(t *testing.T) {
8686

8787
cfg := &CustomSQLConfig{
8888
DatabaseInfoID: 111,
89-
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
89+
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}` and v4 = '{{v4}}'",
9090
}
9191

9292
c1, err := cfg.Build(context.Background(), &schema.NodeSchema{
@@ -104,6 +104,7 @@ func TestCustomSQL_Execute(t *testing.T) {
104104
"v1": "v1_value",
105105
"v2": "v2_value",
106106
"v3": "v3_value",
107+
"v4": true,
107108
})
108109

109110
assert.Nil(t, err)

0 commit comments

Comments
 (0)