Skip to content

Commit 21f4be5

Browse files
d1vbyz3r0n-r-w
andauthored
Fix parameter numbering in CTEs (#28)
* Fix parameter numbering in CTEs * Refactor SQL generation methods to improve placeholder handling and add tests for nested select scenarios --------- Co-authored-by: Roman Nikulenkov <nrw@yandex.ru>
1 parent 2908bf2 commit 21f4be5

File tree

9 files changed

+339
-49
lines changed

9 files changed

+339
-49
lines changed

cte_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,59 @@ func TestCTEPlaceholderFormat(t *testing.T) {
151151
expectedSql = "WITH table1 AS (SELECT col1, col2 FROM table1 WHERE col1 = $1) UPDATE table2 SET col3 = $2"
152152
assert.Equal(t, expectedSql, sql)
153153
}
154+
155+
func TestCTEWithNestedSelects_DollarPlaceholderFormat(t *testing.T) {
156+
b := StatementBuilder.PlaceholderFormat(Dollar)
157+
158+
sub := b.Select("col1", "col2").
159+
From("table1").
160+
Where("col1 = ?", 1)
161+
162+
sub = sub.Where("col2 = ?", "123")
163+
164+
q := b.With("table1").
165+
As(sub).
166+
Cte("table2").
167+
As(
168+
b.Select("col3", "col4").
169+
From("table2").
170+
Where("col3 = ?", "345").
171+
Where("col4 = ?", 2),
172+
).
173+
Select(
174+
b.Select("col1", "col2", "col3", "col4").
175+
From("table1").
176+
Where("col1 = ?", 3).
177+
Join("table2 ON col3 = col4"),
178+
)
179+
180+
sql, args, err := q.ToSql()
181+
assert.NoError(t, err)
182+
183+
expectedSQL := "" +
184+
"WITH table1 AS (SELECT col1, col2 FROM table1 WHERE col1 = $1 AND col2 = $2), " +
185+
"table2 AS (SELECT col3, col4 FROM table2 WHERE col3 = $3 AND col4 = $4) " +
186+
"SELECT col1, col2, col3, col4 FROM table1 JOIN table2 ON col3 = col4 WHERE col1 = $5"
187+
188+
assert.Equal(t, expectedSQL, sql)
189+
assert.Equal(t, []any{1, "123", "345", 2, 3}, args)
190+
}
191+
192+
func TestCTEFinalUpdate_DollarPlaceholderNumberingConflict(t *testing.T) {
193+
b := StatementBuilder.PlaceholderFormat(Dollar)
194+
195+
q := b.With("w1").
196+
As(
197+
b.Select("c").From("t1").Where("a = ?", 1),
198+
).
199+
Update(
200+
b.Update("t2").Set("x", 2).Where("y = ?", 3),
201+
)
202+
203+
sql, args, err := q.ToSql()
204+
assert.NoError(t, err)
205+
206+
expectedSQL := "WITH w1 AS (SELECT c FROM t1 WHERE a = $1) UPDATE t2 SET x = $2 WHERE y = $3"
207+
assert.Equal(t, expectedSQL, sql)
208+
assert.Equal(t, []any{1, 2, 3}, args)
209+
}

delete.go

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ type deleteData struct {
1919
Suffixes []Sqlizer
2020
}
2121

22-
func (d *deleteData) ToSql() (sqlStr string, args []any, err error) {
22+
func (d *deleteData) toSqlRaw() (sqlStr string, args []any, err error) {
2323
if len(d.From) == 0 {
2424
err = fmt.Errorf("delete statements must specify a From table")
2525
return "", nil, err
@@ -33,14 +33,14 @@ func (d *deleteData) ToSql() (sqlStr string, args []any, err error) {
3333
return "", nil, err
3434
}
3535

36-
sql.WriteString(" ")
36+
_, _ = sql.WriteString(" ")
3737
}
3838

39-
sql.WriteString("DELETE FROM ")
40-
sql.WriteString(d.From)
39+
_, _ = sql.WriteString("DELETE FROM ")
40+
_, _ = sql.WriteString(d.From)
4141

4242
if len(d.WhereParts) > 0 {
43-
sql.WriteString(" WHERE ")
43+
_, _ = sql.WriteString(" WHERE ")
4444
args, err = appendToSql(d.WhereParts, sql, " AND ", args)
4545
if err != nil {
4646
return "", nil, err
@@ -70,8 +70,16 @@ func (d *deleteData) ToSql() (sqlStr string, args []any, err error) {
7070
}
7171
}
7272

73-
sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String())
74-
return sqlStr, args, err
73+
return sql.String(), args, nil
74+
}
75+
76+
func (d *deleteData) ToSql() (sqlStr string, args []any, err error) {
77+
s, a, e := d.toSqlRaw()
78+
if e != nil {
79+
return "", nil, e
80+
}
81+
sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(s)
82+
return sqlStr, a, err
7583
}
7684

7785
// Builder
@@ -146,6 +154,12 @@ func (b DeleteBuilder) Offset(offset uint64) DeleteBuilder {
146154
return builder.Set(b, "Offset", fmt.Sprintf("%d", offset)).(DeleteBuilder)
147155
}
148156

157+
// toSqlRaw builds SQL with raw placeholders ("?") without applying PlaceholderFormat.
158+
func (b DeleteBuilder) toSqlRaw() (string, []any, error) {
159+
data := builder.GetStruct(b).(deleteData)
160+
return data.toSqlRaw()
161+
}
162+
149163
// Suffix adds an expression to the end of the query
150164
func (b DeleteBuilder) Suffix(sql string, args ...any) DeleteBuilder {
151165
return b.SuffixExpr(Expr(sql, args...))

expr.go

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func (e expr) ToSql() (sql string, args []any, err error) {
6565

6666
if as, ok := ap[0].(Sqlizer); ok {
6767
// sqlizer argument; expand it and append the result
68-
isql, iargs, err = as.ToSql()
68+
isql, iargs, err = nestedToSql(as)
6969
buf.WriteString(sp[:i])
7070
buf.WriteString(isql)
7171
args = append(args, iargs...)
@@ -93,7 +93,7 @@ func (ce concatExpr) ToSql() (sql string, args []any, err error) {
9393
case string:
9494
sql += p
9595
case Sqlizer:
96-
pSql, pArgs, err := p.ToSql()
96+
pSql, pArgs, err := nestedToSql(p)
9797
if err != nil {
9898
return "", nil, err
9999
}
@@ -132,7 +132,7 @@ func Alias(e Sqlizer, a string) aliasExpr {
132132
}
133133

134134
func (e aliasExpr) ToSql() (sql string, args []any, err error) {
135-
sql, args, err = e.expr.ToSql()
135+
sql, args, err = nestedToSql(e.expr)
136136
if err == nil {
137137
sql = fmt.Sprintf("(%s) AS %s", sql, e.alias)
138138
}
@@ -456,7 +456,7 @@ func Sum(e Sqlizer) sumExpr {
456456
}
457457

458458
func (e sumExpr) ToSql() (sql string, args []any, err error) {
459-
sql, args, err = e.expr.ToSql()
459+
sql, args, err = nestedToSql(e.expr)
460460
if err == nil {
461461
sql = fmt.Sprintf("SUM(%s)", sql)
462462
}
@@ -475,7 +475,7 @@ func Count(e Sqlizer) countExpr {
475475
}
476476

477477
func (e countExpr) ToSql() (sql string, args []any, err error) {
478-
sql, args, err = e.expr.ToSql()
478+
sql, args, err = nestedToSql(e.expr)
479479
if err == nil {
480480
sql = fmt.Sprintf("COUNT(%s)", sql)
481481
}
@@ -494,7 +494,7 @@ func Min(e Sqlizer) minExpr {
494494
}
495495

496496
func (e minExpr) ToSql() (sql string, args []any, err error) {
497-
sql, args, err = e.expr.ToSql()
497+
sql, args, err = nestedToSql(e.expr)
498498
if err == nil {
499499
sql = fmt.Sprintf("MIN(%s)", sql)
500500
}
@@ -513,7 +513,7 @@ func Max(e Sqlizer) maxExpr {
513513
}
514514

515515
func (e maxExpr) ToSql() (sql string, args []any, err error) {
516-
sql, args, err = e.expr.ToSql()
516+
sql, args, err = nestedToSql(e.expr)
517517
if err == nil {
518518
sql = fmt.Sprintf("MAX(%s)", sql)
519519
}
@@ -532,7 +532,7 @@ func Avg(e Sqlizer) avgExpr {
532532
}
533533

534534
func (e avgExpr) ToSql() (sql string, args []any, err error) {
535-
sql, args, err = e.expr.ToSql()
535+
sql, args, err = nestedToSql(e.expr)
536536
if err == nil {
537537
sql = fmt.Sprintf("AVG(%s)", sql)
538538
}
@@ -551,7 +551,7 @@ func Exists(e Sqlizer) existsExpr {
551551
}
552552

553553
func (e existsExpr) ToSql() (sql string, args []any, err error) {
554-
sql, args, err = e.expr.ToSql()
554+
sql, args, err = nestedToSql(e.expr)
555555
if err == nil {
556556
sql = fmt.Sprintf("EXISTS (%s)", sql)
557557
}
@@ -570,7 +570,7 @@ func NotExists(e Sqlizer) notExistsExpr {
570570
}
571571

572572
func (e notExistsExpr) ToSql() (sql string, args []any, err error) {
573-
sql, args, err = e.expr.ToSql()
573+
sql, args, err = nestedToSql(e.expr)
574574
if err == nil {
575575
sql = fmt.Sprintf("NOT EXISTS (%s)", sql)
576576
}
@@ -590,7 +590,7 @@ func Equal(e Sqlizer, v any) equalExpr {
590590
}
591591

592592
func (e equalExpr) ToSql() (sql string, args []any, err error) {
593-
sql, args, err = e.expr.ToSql()
593+
sql, args, err = nestedToSql(e.expr)
594594
if err == nil {
595595
sql = fmt.Sprintf("(%s) = ?", sql)
596596
args = append(args, e.value)
@@ -608,7 +608,7 @@ func NotEqual(e Sqlizer, v any) notEqualExpr {
608608
}
609609

610610
func (e notEqualExpr) ToSql() (sql string, args []any, err error) {
611-
sql, args, err = e.expr.ToSql()
611+
sql, args, err = nestedToSql(e.expr)
612612
if err == nil {
613613
sql = fmt.Sprintf("(%s) <> ?", sql)
614614
args = append(args, e.value)
@@ -626,7 +626,7 @@ func Greater(e Sqlizer, v any) greaterExpr {
626626
}
627627

628628
func (e greaterExpr) ToSql() (sql string, args []any, err error) {
629-
sql, args, err = e.expr.ToSql()
629+
sql, args, err = nestedToSql(e.expr)
630630
if err == nil {
631631
sql = fmt.Sprintf("(%s) > ?", sql)
632632
args = append(args, e.value)
@@ -644,7 +644,7 @@ func GreaterOrEqual(e Sqlizer, v any) greaterOrEqualExpr {
644644
}
645645

646646
func (e greaterOrEqualExpr) ToSql() (sql string, args []any, err error) {
647-
sql, args, err = e.expr.ToSql()
647+
sql, args, err = nestedToSql(e.expr)
648648
if err == nil {
649649
sql = fmt.Sprintf("(%s) >= ?", sql)
650650
args = append(args, e.value)
@@ -662,7 +662,7 @@ func Less(e Sqlizer, v any) lessExpr {
662662
}
663663

664664
func (e lessExpr) ToSql() (sql string, args []any, err error) {
665-
sql, args, err = e.expr.ToSql()
665+
sql, args, err = nestedToSql(e.expr)
666666
if err == nil {
667667
sql = fmt.Sprintf("(%s) < ?", sql)
668668
args = append(args, e.value)
@@ -680,7 +680,7 @@ func LessOrEqual(e Sqlizer, v any) lessOrEqualExpr {
680680
}
681681

682682
func (e lessOrEqualExpr) ToSql() (sql string, args []any, err error) {
683-
sql, args, err = e.expr.ToSql()
683+
sql, args, err = nestedToSql(e.expr)
684684
if err == nil {
685685
sql = fmt.Sprintf("(%s) <= ?", sql)
686686
args = append(args, e.value)
@@ -703,7 +703,7 @@ func In(column string, e any) inExpr {
703703
func (e inExpr) ToSql() (sql string, args []any, err error) {
704704
switch v := e.expr.(type) {
705705
case Sqlizer:
706-
sql, args, err = v.ToSql()
706+
sql, args, err = nestedToSql(v)
707707
if err == nil && sql != "" {
708708
sql = fmt.Sprintf("%s IN (%s)", e.column, sql)
709709
}
@@ -741,7 +741,7 @@ func NotIn(column string, e any) notInExpr {
741741
func (e notInExpr) ToSql() (sql string, args []any, err error) {
742742
switch v := e.expr.(type) {
743743
case Sqlizer:
744-
sql, args, err = v.ToSql()
744+
sql, args, err = nestedToSql(v)
745745
if err == nil && sql != "" {
746746
sql = fmt.Sprintf("%s NOT IN (%s)", e.column, sql)
747747
}
@@ -801,7 +801,7 @@ func (e rangeExpr) ToSql() (sql string, args []any, err error) {
801801
s = LtOrEq{e.column: e.end}
802802
}
803803

804-
return s.ToSql()
804+
return nestedToSql(s)
805805
}
806806

807807
// EqNotEmpty ignores empty and zero values in Eq map.
@@ -818,7 +818,7 @@ func (eq EqNotEmpty) ToSql() (sql string, args []any, err error) {
818818
}
819819
}
820820

821-
return vals.ToSql()
821+
return nestedToSql(vals)
822822
}
823823

824824
// clearEmptyValue recursively clears empty and zero values in any type.
@@ -865,7 +865,7 @@ func Cte(e Sqlizer, cte string) cteExpr {
865865

866866
// ToSql builds the query into a SQL string and bound args.
867867
func (e cteExpr) ToSql() (sql string, args []any, err error) {
868-
sql, args, err = e.expr.ToSql()
868+
sql, args, err = nestedToSql(e.expr)
869869
if err == nil {
870870
sql = fmt.Sprintf("%s AS (%s)", e.cte, sql)
871871
}
@@ -878,7 +878,7 @@ type notExpr struct {
878878

879879
// ToSql builds the query into a SQL string and bound args.
880880
func (e notExpr) ToSql() (sql string, args []any, err error) {
881-
sql, args, err = e.expr.ToSql()
881+
sql, args, err = nestedToSql(e.expr)
882882
if err == nil {
883883
sql = fmt.Sprintf("NOT (%s)", sql)
884884
}
@@ -908,9 +908,11 @@ func Coalesce(nullValue any, exprs ...Sqlizer) Sqlizer {
908908
// ToSql builds the query into a SQL string and bound args.
909909
func (e coalesceExpr) ToSql() (sql string, args []any, err error) {
910910
exprs := make([]string, 0, len(e.exprs))
911+
allArgs := make([]any, 0)
911912
for _, expr := range e.exprs {
912913
var exprSQL string
913-
exprSQL, args, err = expr.ToSql()
914+
var a []any
915+
exprSQL, a, err = nestedToSql(expr)
914916
if err != nil {
915917
return
916918
}
@@ -920,13 +922,16 @@ func (e coalesceExpr) ToSql() (sql string, args []any, err error) {
920922
}
921923

922924
exprs = append(exprs, fmt.Sprintf("(%s)", exprSQL))
925+
if len(a) > 0 {
926+
allArgs = append(allArgs, a...)
927+
}
923928
}
924929

925930
if len(exprs) == 0 {
926931
return "", nil, nil
927932
}
928933

929934
sql = fmt.Sprintf("COALESCE(%s, ?)", strings.Join(exprs, ", "))
930-
args = append(args, e.null)
935+
args = append(allArgs, e.null)
931936
return
932937
}

0 commit comments

Comments
 (0)