Skip to content

Commit eba49ef

Browse files
committed
Build SET clause from columns
1 parent 405a2dc commit eba49ef

File tree

2 files changed

+51
-23
lines changed

2 files changed

+51
-23
lines changed

database/query_builder.go

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -183,41 +183,51 @@ func (qb *queryBuilder) SelectStatement(stmt SelectStatement) string {
183183
}
184184

185185
func (qb *queryBuilder) UpdateStatement(stmt UpdateStatement) (string, error) {
186+
columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns())
187+
186188
table := stmt.Table()
187189
if table == "" {
188190
table = TableName(stmt.Entity())
189191
}
190-
set := stmt.Set()
191-
if set == "" {
192-
return "", errors.New("set cannot be empty")
193-
}
192+
194193
where := stmt.Where()
195194
if where == "" {
196-
return "", errors.New("cannot use UpdateStatement() without where statement - use UpdateAllStatement() instead")
195+
return "", fmt.Errorf("%w: %s", ErrMissingStatementPart, "where statement - use UpdateAllStatement() instead")
196+
}
197+
198+
var set []string
199+
200+
for _, col := range columns {
201+
set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
197202
}
198203

199204
return fmt.Sprintf(
200-
`UPDATE "%s" SET %s%s`,
205+
`UPDATE "%s" SET %s WHERE %s`,
201206
table,
202-
set,
207+
strings.Join(set, ", "),
203208
where,
204209
), nil
205210
}
206211

207212
func (qb *queryBuilder) UpdateAllStatement(stmt UpdateStatement) (string, error) {
213+
columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns())
214+
208215
table := stmt.Table()
209216
if table == "" {
210217
table = TableName(stmt.Entity())
211218
}
212-
set := stmt.Set()
213-
if set == "" {
214-
return "", errors.New("set cannot be empty")
215-
}
219+
216220
where := stmt.Where()
217221
if where != "" {
218222
return "", errors.New("cannot use UpdateAllStatement() with where statement - use UpdateStatement() instead")
219223
}
220224

225+
var set []string
226+
227+
for _, col := range columns {
228+
set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
229+
}
230+
221231
return fmt.Sprintf(
222232
`UPDATE "%s" SET %s`,
223233
table,

database/update.go

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ type UpdateStatement interface {
88
// Overrides the table name provided by the entity.
99
SetTable(table string) UpdateStatement
1010

11-
// SetSet sets the set clause for the UPDATE statement.
12-
SetSet(set string) UpdateStatement
11+
// SetColumns sets the columns to be updated.
12+
SetColumns(columns ...string) UpdateStatement
13+
14+
// SetExcludedColumns sets the columns to be excluded from the UPDATE statement.
15+
// Excludes also columns set by SetColumns.
16+
SetExcludedColumns(columns ...string) UpdateStatement
1317

1418
// SetWhere sets the where clause for the UPDATE statement.
1519
SetWhere(where string) UpdateStatement
@@ -20,8 +24,11 @@ type UpdateStatement interface {
2024
// Table returns the table name for the UPDATE statement.
2125
Table() string
2226

23-
// Set returns the set clause for the UPDATE statement.
24-
Set() string
27+
// Columns returns the columns to be updated.
28+
Columns() []string
29+
30+
// ExcludedColumns returns the columns to be excluded from the UPDATE statement.
31+
ExcludedColumns() []string
2532

2633
// Where returns the where clause for the UPDATE statement.
2734
Where() string
@@ -39,10 +46,11 @@ func NewUpdateStatement(entity Entity) UpdateStatement {
3946

4047
// updateStatement is the default implementation of the UpdateStatement interface.
4148
type updateStatement struct {
42-
entity Entity
43-
table string
44-
set string
45-
where string
49+
entity Entity
50+
table string
51+
columns []string
52+
excludedColumns []string
53+
where string
4654
}
4755

4856
func (u *updateStatement) SetTable(table string) UpdateStatement {
@@ -51,8 +59,14 @@ func (u *updateStatement) SetTable(table string) UpdateStatement {
5159
return u
5260
}
5361

54-
func (u *updateStatement) SetSet(set string) UpdateStatement {
55-
u.set = set
62+
func (u *updateStatement) SetColumns(columns ...string) UpdateStatement {
63+
u.columns = columns
64+
65+
return u
66+
}
67+
68+
func (u *updateStatement) SetExcludedColumns(columns ...string) UpdateStatement {
69+
u.excludedColumns = columns
5670

5771
return u
5872
}
@@ -71,8 +85,12 @@ func (u *updateStatement) Table() string {
7185
return u.table
7286
}
7387

74-
func (u *updateStatement) Set() string {
75-
return u.set
88+
func (u *updateStatement) Columns() []string {
89+
return u.columns
90+
}
91+
92+
func (u *updateStatement) ExcludedColumns() []string {
93+
return u.excludedColumns
7694
}
7795

7896
func (u *updateStatement) Where() string {

0 commit comments

Comments
 (0)