Skip to content

Commit 97b850a

Browse files
committed
Introduce QueryBuilder type
1 parent 2d47d95 commit 97b850a

File tree

2 files changed

+485
-0
lines changed

2 files changed

+485
-0
lines changed

database/query_builder.go

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
package database
2+
3+
import (
4+
"fmt"
5+
"golang.org/x/exp/slices"
6+
"reflect"
7+
"sort"
8+
"strings"
9+
)
10+
11+
// QueryBuilder is an addon for the [DB] type that takes care of all the database statement building shenanigans.
12+
// The recommended use of QueryBuilder is to only use it to generate a single query at a time and not two different
13+
// ones. If for instance you want to generate `INSERT` and `SELECT` queries, it is best to use two different
14+
// QueryBuilder instances. You can use the DB#QueryBuilder() method to get fully initialised instances each time.
15+
type QueryBuilder struct {
16+
db *DB
17+
18+
subject any
19+
columns []string
20+
excludedColumns []string
21+
22+
// Indicates whether the generated columns should be sorted in ascending order before generating the
23+
// actual statements. This is intended for unit tests only and shouldn't be necessary for production code.
24+
sort bool
25+
}
26+
27+
// SetColumns sets the DB columns to be used when building the statements.
28+
// When you do not want the columns to be extracted dynamically, you can use this method to specify them manually.
29+
// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls.
30+
func (qb *QueryBuilder) SetColumns(columns ...string) *QueryBuilder {
31+
qb.columns = columns
32+
return qb
33+
}
34+
35+
// SetExcludedColumns excludes the given columns from all the database statements.
36+
// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls.
37+
func (qb *QueryBuilder) SetExcludedColumns(columns ...string) *QueryBuilder {
38+
qb.excludedColumns = columns
39+
return qb
40+
}
41+
42+
// Delete returns a DELETE statement for the query builders subject filtered by ID.
43+
func (qb *QueryBuilder) Delete() string {
44+
return qb.DeleteBy("id")
45+
}
46+
47+
// DeleteBy returns a DELETE statement for the query builders subject filtered by the given column.
48+
func (qb *QueryBuilder) DeleteBy(column string) string {
49+
return fmt.Sprintf(`DELETE FROM "%s" WHERE "%s" IN (?)`, TableName(qb.subject), column)
50+
}
51+
52+
// Insert returns an INSERT INTO statement for the query builders subject.
53+
func (qb *QueryBuilder) Insert() (string, int) {
54+
columns := qb.BuildColumns()
55+
56+
return fmt.Sprintf(
57+
`INSERT INTO "%s" ("%s") VALUES (%s)`,
58+
TableName(qb.subject),
59+
strings.Join(columns, `", "`),
60+
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
61+
), len(columns)
62+
}
63+
64+
// InsertIgnore returns an INSERT statement for the query builders subject for
65+
// which the database ignores rows that have already been inserted.
66+
func (qb *QueryBuilder) InsertIgnore() (string, int) {
67+
columns := qb.BuildColumns()
68+
69+
var clause string
70+
switch qb.db.DriverName() {
71+
case MySQL:
72+
// MySQL treats UPDATE id = id as a no-op.
73+
clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%[1]s" = "%[1]s"`, columns[0])
74+
case PostgreSQL:
75+
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO NOTHING", qb.getPgsqlOnConflictConstraint())
76+
default:
77+
panic("Driver unsupported: " + qb.db.DriverName())
78+
}
79+
80+
return fmt.Sprintf(
81+
`INSERT INTO "%s" ("%s") VALUES (%s) %s`,
82+
TableName(qb.subject),
83+
strings.Join(columns, `", "`),
84+
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
85+
clause,
86+
), len(columns)
87+
}
88+
89+
// Select returns a SELECT statement from the query builders subject and the already set columns.
90+
// If no columns are set, they will be extracted from the query builders subject.
91+
// When the query builders subject is of type Scoper, a WHERE clause is appended to the statement.
92+
func (qb *QueryBuilder) Select() string {
93+
var scoper Scoper
94+
if sc, ok := qb.subject.(Scoper); ok {
95+
scoper = sc
96+
}
97+
98+
return qb.SelectScoped(scoper)
99+
}
100+
101+
// SelectScoped returns a SELECT statement from the query builders subject and the already set columns filtered
102+
// by the given scoper/column. When no columns are set, they will be extracted from the query builders subject.
103+
// The argument scoper must either be of type Scoper, string or nil to get SELECT statements without a WHERE clause.
104+
func (qb *QueryBuilder) SelectScoped(scoper any) string {
105+
query := fmt.Sprintf(`SELECT "%s" FROM "%s"`, strings.Join(qb.BuildColumns(), `", "`), TableName(qb.subject))
106+
where, placeholders := qb.Where(scoper)
107+
if placeholders > 0 {
108+
query += ` WHERE ` + where
109+
}
110+
111+
return query
112+
}
113+
114+
// Update returns an UPDATE statement for the query builders subject filter by ID column.
115+
func (qb *QueryBuilder) Update() (string, int) {
116+
return qb.UpdateScoped("id")
117+
}
118+
119+
// UpdateScoped returns an UPDATE statement for the query builders subject filtered by the given column/scoper.
120+
// The argument scoper must either be of type Scoper, string or nil to get UPDATE statements without a WHERE clause.
121+
func (qb *QueryBuilder) UpdateScoped(scoper any) (string, int) {
122+
columns := qb.BuildColumns()
123+
set := make([]string, 0, len(columns))
124+
125+
for _, col := range columns {
126+
set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
127+
}
128+
129+
placeholders := len(columns)
130+
query := `UPDATE "%s" SET %s`
131+
if where, count := qb.Where(scoper); count > 0 {
132+
placeholders += count
133+
query += ` WHERE ` + where
134+
}
135+
136+
return fmt.Sprintf(query, TableName(qb.subject), strings.Join(set, ", ")), placeholders
137+
}
138+
139+
// Upsert returns an upsert statement for the query builders subject.
140+
func (qb *QueryBuilder) Upsert() (string, int) {
141+
var updateColumns []string
142+
if upserter, ok := qb.subject.(Upserter); ok {
143+
updateColumns = qb.db.columnMap.Columns(upserter.Upsert())
144+
} else {
145+
updateColumns = qb.BuildColumns()
146+
}
147+
148+
return qb.UpsertColumns(updateColumns...)
149+
}
150+
151+
// UpsertColumns returns an upsert statement for the query builders subject and the specified update columns.
152+
func (qb *QueryBuilder) UpsertColumns(updateColumns ...string) (string, int) {
153+
var clause, setFormat string
154+
switch qb.db.DriverName() {
155+
case MySQL:
156+
clause = "ON DUPLICATE KEY UPDATE"
157+
setFormat = `"%[1]s" = VALUES("%[1]s")`
158+
case PostgreSQL:
159+
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", qb.getPgsqlOnConflictConstraint())
160+
setFormat = `"%[1]s" = EXCLUDED."%[1]s"`
161+
default:
162+
panic("Driver unsupported: " + qb.db.DriverName())
163+
}
164+
165+
set := make([]string, 0, len(updateColumns))
166+
for _, col := range updateColumns {
167+
set = append(set, fmt.Sprintf(setFormat, col))
168+
}
169+
170+
insertColumns := qb.BuildColumns()
171+
172+
return fmt.Sprintf(
173+
`INSERT INTO "%s" ("%s") VALUES (%s) %s %s`,
174+
TableName(qb.subject),
175+
strings.Join(insertColumns, `", "`),
176+
fmt.Sprintf(":%s", strings.Join(insertColumns, ", :")),
177+
clause,
178+
strings.Join(set, ", "),
179+
), len(insertColumns)
180+
}
181+
182+
// Where returns a WHERE clause with named placeholder conditions built from the
183+
// specified scoper/column combined with the AND operator.
184+
func (qb *QueryBuilder) Where(subject any) (string, int) {
185+
t := reflect.TypeOf(subject)
186+
if t == nil { // Subject is a nil interface value.
187+
return "", 0
188+
}
189+
190+
var columns []string
191+
if t.Kind() == reflect.String {
192+
columns = []string{subject.(string)}
193+
} else if t.Kind() == reflect.Struct || t.Kind() == reflect.Pointer {
194+
if scoper, ok := subject.(Scoper); ok {
195+
return qb.Where(scoper.Scope())
196+
}
197+
198+
columns = qb.db.columnMap.Columns(subject)
199+
} else { // This should never happen unless someone wants to do some silly things.
200+
panic(fmt.Sprintf("qb.Where: unknown subject type provided: %q", t.Kind().String()))
201+
}
202+
203+
where := make([]string, 0, len(columns))
204+
for _, col := range columns {
205+
where = append(where, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
206+
}
207+
208+
return strings.Join(where, ` AND `), len(columns)
209+
}
210+
211+
// BuildColumns returns all the Query Builder columns (if specified), otherwise they are
212+
// determined dynamically using its subject. Additionally, it checks whether columns need
213+
// to be excluded and proceeds accordingly.
214+
func (qb *QueryBuilder) BuildColumns() []string {
215+
var columns []string
216+
if len(qb.columns) > 0 {
217+
columns = qb.columns
218+
} else {
219+
columns = qb.db.columnMap.Columns(qb.subject)
220+
}
221+
222+
if len(qb.excludedColumns) > 0 {
223+
columns = slices.DeleteFunc(append([]string(nil), columns...), func(column string) bool {
224+
for _, exclude := range qb.excludedColumns {
225+
if exclude == column {
226+
return true
227+
}
228+
}
229+
230+
return false
231+
})
232+
}
233+
234+
if qb.sort {
235+
// The order in which the columns appear is not guaranteed as we extract the columns dynamically
236+
// from the struct. So, we've to sort them here to be able to test the generated statements.
237+
sort.SliceStable(columns, func(a, b int) bool {
238+
return columns[a] < columns[b]
239+
})
240+
}
241+
242+
return columns
243+
}
244+
245+
// getPgsqlOnConflictConstraint returns the constraint name of the current [QueryBuilder]'s subject.
246+
// If the subject does not implement the PgsqlOnConflictConstrainter interface, it will simply return
247+
// the table name prefixed with `pk_`.
248+
func (qb *QueryBuilder) getPgsqlOnConflictConstraint() string {
249+
if constrainter, ok := qb.subject.(PgsqlOnConflictConstrainter); ok {
250+
return constrainter.PgsqlOnConflictConstraint()
251+
}
252+
253+
return "pk_" + TableName(qb.subject)
254+
}

0 commit comments

Comments
 (0)