Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 48 additions & 33 deletions backend/domain/workflow/internal/nodes/database/customsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import (
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"

"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
Expand All @@ -34,6 +34,8 @@ import (
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)

var singleQuotesStringRegexp = regexp.MustCompile("[`']\\{\\{([a-zA-Z_][a-zA-Z0-9_]*(?:\\.\\w+|\\[\\d+\\])*)+\\}\\}[`']")

type CustomSQLConfig struct {
DatabaseInfoID int64
SQLTemplate string
Expand Down Expand Up @@ -111,47 +113,60 @@ func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[strin
return nil, err
}

templateParts := nodes.ParseTemplate(singleQuotesStringRegexp.ReplaceAllString(c.sqlTemplate, "?"))
templateSQL := ""
templateParts := nodes.ParseTemplate(c.sqlTemplate)
sqlParams := make([]database.SQLParam, 0, len(templateParts))
var nilError = errors.New("field is nil")
for _, templatePart := range templateParts {
if !templatePart.IsVariable {
templateSQL += templatePart.Value
continue
}
if len(templateParts) > 0 {
if len(templateParts) == 0 {
templateSQL = templateParts[0].Value
} else {
for _, templatePart := range templateParts {
if !templatePart.IsVariable {
templateSQL += templatePart.Value
continue
}

templateSQL += "?"
val, err := templatePart.Render(inputBytes, nodes.WithNilRender(func() (string, error) {
return "", nilError
}),
nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
b := val.(bool)
if b {
return "1", nil
val, err := templatePart.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
b := val.(bool)
if b {
return "1", nil
}
return "0", nil
}))
if err != nil {
return nil, err
}
return "0", nil
}))
templateSQL += val

if err != nil {
if !errors.Is(err, nilError) {
return nil, err
}
sqlParams = append(sqlParams, database.SQLParam{
IsNull: true,
})
} else {
sqlParams = append(sqlParams, database.SQLParam{
Value: val,
IsNull: false,
})
}

} else {
return nil, fmt.Errorf("parse template invalid")
}

sqlParamStrings := singleQuotesStringRegexp.FindAllString(c.sqlTemplate, -1)
sqlParams := make([]database.SQLParam, 0, len(sqlParamStrings))
for _, s := range sqlParamStrings {
parts := nodes.ParseTemplate(s)
for _, part := range parts {
if part.IsVariable {
val, err := part.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
b := val.(bool)
if b {
return "1", nil
}
return "0", nil
}))
if err != nil {
return nil, err
}
sqlParams = append(sqlParams, database.SQLParam{
Value: val,
})
}
}
}

// replace sql template '?' to ?
templateSQL = strings.Replace(templateSQL, "'?'", "?", -1)
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
req.SQL = templateSQL
req.Params = sqlParams
response, err := crossdatabase.DefaultSVC().Execute(ctx, req)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ func TestCustomSQL_Execute(t *testing.T) {
validate: func(req *database.CustomSQLRequest) {
assert.Equal(t, int64(111), req.DatabaseInfoID)
ps := []database.SQLParam{
{Value: "v1_value"},
{Value: "v2_value"},
{Value: "v3_value"},
{Value: "1"},
}
assert.Equal(t, ps, req.Params)
assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL)
assert.Equal(t, "select * from v1 where v1 = v1_value and v2 = ? and v3 = ? and v4 = ?", req.SQL)
},
}

Expand All @@ -86,7 +86,7 @@ func TestCustomSQL_Execute(t *testing.T) {

cfg := &CustomSQLConfig{
DatabaseInfoID: 111,
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}` and v4 = '{{v4}}'",
}

c1, err := cfg.Build(context.Background(), &schema.NodeSchema{
Expand All @@ -104,6 +104,7 @@ func TestCustomSQL_Execute(t *testing.T) {
"v1": "v1_value",
"v2": "v2_value",
"v3": "v3_value",
"v4": true,
})

assert.Nil(t, err)
Expand Down
Loading