Skip to content

Commit d020f43

Browse files
committed
Introduce QueryBuilder type
1 parent 2d47d95 commit d020f43

File tree

2 files changed

+434
-0
lines changed

2 files changed

+434
-0
lines changed

database/query_builder.go

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
package database
2+
3+
import (
4+
"fmt"
5+
"go.uber.org/zap"
6+
"golang.org/x/exp/slices"
7+
"reflect"
8+
"sort"
9+
"strings"
10+
)
11+
12+
// QueryBuilder is an addon for the DB type that takes care of all the database statement building shenanigans.
13+
// Note: This type is designed primarily for one-off use (monouso) and subsequent disposal and should only be
14+
// used to generate a single database query type.
15+
type QueryBuilder struct {
16+
subject any
17+
columns []string
18+
excludedColumns []string
19+
20+
// Indicates whether the generated columns should be sorted in ascending order before generating the
21+
// actual statements. This is intended for unit tests only and shouldn't be necessary for production code.
22+
sort bool
23+
}
24+
25+
// NewQB returns a fully initialized *QueryBuilder instance for the given subject/struct.
26+
func NewQB(subject any) *QueryBuilder {
27+
return &QueryBuilder{subject: subject}
28+
}
29+
30+
// SetColumns sets the DB columns to be used when building the statements.
31+
// When you do not want the columns to be extracted dynamically, you can use this method to specify them manually.
32+
// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls.
33+
func (qb *QueryBuilder) SetColumns(columns ...string) *QueryBuilder {
34+
qb.columns = columns
35+
return qb
36+
}
37+
38+
// SetExcludedColumns excludes the given columns from all the database statements.
39+
// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls.
40+
func (qb *QueryBuilder) SetExcludedColumns(columns ...string) *QueryBuilder {
41+
qb.excludedColumns = columns
42+
return qb
43+
}
44+
45+
// Delete returns a DELETE statement for the query builders subject filtered by ID.
46+
func (qb *QueryBuilder) Delete() string {
47+
return qb.DeleteBy("id")
48+
}
49+
50+
// DeleteBy returns a DELETE statement for the query builders subject filtered by the given column.
51+
func (qb *QueryBuilder) DeleteBy(column string) string {
52+
return fmt.Sprintf(`DELETE FROM "%s" WHERE "%s" IN (?)`, TableName(qb.subject), column)
53+
}
54+
55+
// Insert returns an INSERT INTO statement for the query builders subject.
56+
func (qb *QueryBuilder) Insert(db *DB) (string, int) {
57+
columns := qb.BuildColumns(db)
58+
59+
return fmt.Sprintf(
60+
`INSERT INTO "%s" ("%s") VALUES (%s)`,
61+
TableName(qb.subject),
62+
strings.Join(columns, `", "`),
63+
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
64+
), len(columns)
65+
}
66+
67+
// InsertIgnore returns an INSERT statement for the query builders subject for
68+
// which the database ignores rows that have already been inserted.
69+
func (qb *QueryBuilder) InsertIgnore(db *DB) (string, int) {
70+
columns := qb.BuildColumns(db)
71+
table := TableName(qb.subject)
72+
73+
var clause string
74+
switch db.DriverName() {
75+
case MySQL:
76+
// MySQL treats UPDATE id = id as a no-op.
77+
clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%[1]s" = "%[1]s"`, columns[0])
78+
case PostgreSQL:
79+
var constraint string
80+
if constrainter, ok := qb.subject.(PgsqlOnConflictConstrainter); ok {
81+
constraint = constrainter.PgsqlOnConflictConstraint()
82+
} else {
83+
constraint = "pk_" + table
84+
}
85+
86+
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO NOTHING", constraint)
87+
default:
88+
db.logger.Fatalw("Driver unsupported", zap.String("driver", db.DriverName()))
89+
}
90+
91+
return fmt.Sprintf(
92+
`INSERT INTO "%s" ("%s") VALUES (%s) %s`,
93+
table,
94+
strings.Join(columns, `", "`),
95+
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
96+
clause,
97+
), len(columns)
98+
}
99+
100+
// Select returns a SELECT statement from the query builders subject and the already set columns.
101+
// If no columns are set, they will be extracted from the query builders subject.
102+
// When the query builders subject is of type Scoper, a WHERE clause is appended to the statement.
103+
func (qb *QueryBuilder) Select(db *DB) string {
104+
var scoper Scoper
105+
if sc, ok := qb.subject.(Scoper); ok {
106+
scoper = sc
107+
}
108+
109+
return qb.SelectScoped(db, scoper)
110+
}
111+
112+
// SelectScoped returns a SELECT statement from the query builders subject and the already set columns filtered
113+
// by the given scoper/column. When no columns are set, they will be extracted from the query builders subject.
114+
// The argument scoper must either be of type Scoper, string or nil to get SELECT statements without a WHERE clause.
115+
func (qb *QueryBuilder) SelectScoped(db *DB, scoper any) string {
116+
query := fmt.Sprintf(`SELECT "%s" FROM "%s"`, strings.Join(qb.BuildColumns(db), `", "`), TableName(qb.subject))
117+
where, placeholders := qb.Where(db, scoper)
118+
if placeholders > 0 {
119+
query += ` WHERE ` + where
120+
}
121+
122+
return query
123+
}
124+
125+
// Update returns an UPDATE statement for the query builders subject filter by ID column.
126+
func (qb *QueryBuilder) Update(db *DB) (string, int) {
127+
return qb.UpdateScoped(db, "id")
128+
}
129+
130+
// UpdateScoped returns an UPDATE statement for the query builders subject filtered by the given column/scoper.
131+
// The argument scoper must either be of type Scoper, string or nil to get UPDATE statements without a WHERE clause.
132+
func (qb *QueryBuilder) UpdateScoped(db *DB, scoper any) (string, int) {
133+
columns := qb.BuildColumns(db)
134+
set := make([]string, 0, len(columns))
135+
136+
for _, col := range columns {
137+
set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
138+
}
139+
140+
placeholders := len(columns)
141+
query := `UPDATE "%s" SET %s`
142+
if where, count := qb.Where(db, scoper); count > 0 {
143+
placeholders += count
144+
query += ` WHERE ` + where
145+
}
146+
147+
return fmt.Sprintf(query, TableName(qb.subject), strings.Join(set, ", ")), placeholders
148+
}
149+
150+
// Upsert returns an upsert statement for the query builders subject.
151+
func (qb *QueryBuilder) Upsert(db *DB) (string, int) {
152+
var updateColumns []string
153+
if upserter, ok := qb.subject.(Upserter); ok {
154+
updateColumns = db.columnMap.Columns(upserter.Upsert())
155+
} else {
156+
updateColumns = qb.BuildColumns(db)
157+
}
158+
159+
return qb.UpsertColumns(db, updateColumns...)
160+
}
161+
162+
// UpsertColumns returns an upsert statement for the query builders subject and the specified update columns.
163+
func (qb *QueryBuilder) UpsertColumns(db *DB, updateColumns ...string) (string, int) {
164+
insertColumns := qb.BuildColumns(db)
165+
table := TableName(qb.subject)
166+
167+
var clause, setFormat string
168+
switch db.DriverName() {
169+
case MySQL:
170+
clause = "ON DUPLICATE KEY UPDATE"
171+
setFormat = `"%[1]s" = VALUES("%[1]s")`
172+
case PostgreSQL:
173+
var constraint string
174+
if constrainter, ok := qb.subject.(PgsqlOnConflictConstrainter); ok {
175+
constraint = constrainter.PgsqlOnConflictConstraint()
176+
} else {
177+
constraint = "pk_" + table
178+
}
179+
180+
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", constraint)
181+
setFormat = `"%[1]s" = EXCLUDED."%[1]s"`
182+
default:
183+
db.logger.Fatalw("Driver unsupported", zap.String("driver", db.DriverName()))
184+
}
185+
186+
set := make([]string, 0, len(updateColumns))
187+
for _, col := range updateColumns {
188+
set = append(set, fmt.Sprintf(setFormat, col))
189+
}
190+
191+
return fmt.Sprintf(
192+
`INSERT INTO "%s" ("%s") VALUES (%s) %s %s`,
193+
table,
194+
strings.Join(insertColumns, `", "`),
195+
fmt.Sprintf(":%s", strings.Join(insertColumns, ", :")),
196+
clause,
197+
strings.Join(set, ", "),
198+
), len(insertColumns)
199+
}
200+
201+
// Where returns a WHERE clause with named placeholder conditions built from the
202+
// specified scoper/column combined with the AND operator.
203+
func (qb *QueryBuilder) Where(db *DB, subject any) (string, int) {
204+
t := reflect.TypeOf(subject)
205+
if t == nil { // Subject is a nil interface value.
206+
return "", 0
207+
}
208+
209+
var columns []string
210+
if t.Kind() == reflect.String {
211+
columns = []string{subject.(string)}
212+
} else if t.Kind() == reflect.Struct || t.Kind() == reflect.Pointer {
213+
if scoper, ok := subject.(Scoper); ok {
214+
return qb.Where(db, scoper.Scope())
215+
}
216+
217+
columns = db.columnMap.Columns(subject)
218+
}
219+
220+
where := make([]string, 0, len(columns))
221+
for _, col := range columns {
222+
where = append(where, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
223+
}
224+
225+
return strings.Join(where, ` AND `), len(columns)
226+
}
227+
228+
// BuildColumns returns all the Query Builder columns (if specified), otherwise they are
229+
// determined dynamically using its subject. Additionally, it checks whether columns need
230+
// to be excluded and proceeds accordingly.
231+
func (qb *QueryBuilder) BuildColumns(db *DB) []string {
232+
var columns []string
233+
if len(qb.columns) > 0 {
234+
columns = qb.columns
235+
} else {
236+
columns = db.columnMap.Columns(qb.subject)
237+
}
238+
239+
if len(qb.excludedColumns) > 0 {
240+
columns = slices.DeleteFunc(append([]string(nil), columns...), func(column string) bool {
241+
for _, exclude := range qb.excludedColumns {
242+
if exclude == column {
243+
return true
244+
}
245+
}
246+
247+
return false
248+
})
249+
}
250+
251+
if qb.sort {
252+
// The order in which the columns appear is not guaranteed as we extract the columns dynamically
253+
// from the struct. So, we've to sort them here to be able to test the generated statements.
254+
sort.SliceStable(columns, func(a, b int) bool {
255+
return columns[a] < columns[b]
256+
})
257+
}
258+
259+
return columns
260+
}

0 commit comments

Comments
 (0)