Skip to content

Commit 24a3cae

Browse files
committed
Introduce QueryBuilder type
1 parent 2d47d95 commit 24a3cae

File tree

2 files changed

+487
-0
lines changed

2 files changed

+487
-0
lines changed

database/query_builder.go

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
package database
2+
3+
import (
4+
"fmt"
5+
"golang.org/x/exp/slices"
6+
"reflect"
7+
"sort"
8+
"strings"
9+
)
10+
11+
// QueryBuilder is an addon for the [DB] type that takes care of all the database statement building shenanigans.
12+
// The recommended use of QueryBuilder is to only use it to generate a single query at a time and not two different
13+
// ones. If for instance you want to generate `INSERT` and `SELECT` queries, it is best to use two different
14+
// QueryBuilder instances. You can use the DB#QueryBuilder() method to get fully initialised instances each time.
15+
type QueryBuilder struct {
16+
db *DB
17+
18+
subject any
19+
columns []string
20+
excludedColumns []string
21+
22+
// Indicates whether the generated columns should be sorted in ascending order before generating the
23+
// actual statements. This is intended for unit tests only and shouldn't be necessary for production code.
24+
sort bool
25+
}
26+
27+
// SetColumns sets the DB columns to be used when building the statements.
28+
// When you do not want the columns to be extracted dynamically, you can use this method to specify them manually.
29+
// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls.
30+
func (qb *QueryBuilder) SetColumns(columns ...string) *QueryBuilder {
31+
qb.columns = columns
32+
return qb
33+
}
34+
35+
// SetExcludedColumns excludes the given columns from all the database statements.
36+
// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls.
37+
func (qb *QueryBuilder) SetExcludedColumns(columns ...string) *QueryBuilder {
38+
qb.excludedColumns = columns
39+
return qb
40+
}
41+
42+
// Delete returns a DELETE statement for the query builders subject filtered by ID.
43+
func (qb *QueryBuilder) Delete() string {
44+
return qb.DeleteBy("id")
45+
}
46+
47+
// DeleteBy returns a DELETE statement for the query builders subject filtered by the given column.
48+
func (qb *QueryBuilder) DeleteBy(column string) string {
49+
return fmt.Sprintf(`DELETE FROM "%s" WHERE "%s" IN (?)`, TableName(qb.subject), column)
50+
}
51+
52+
// Insert returns an INSERT INTO statement for the query builders subject.
53+
func (qb *QueryBuilder) Insert() (string, int) {
54+
columns := qb.BuildColumns()
55+
56+
return fmt.Sprintf(
57+
`INSERT INTO "%s" ("%s") VALUES (%s)`,
58+
TableName(qb.subject),
59+
strings.Join(columns, `", "`),
60+
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
61+
), len(columns)
62+
}
63+
64+
// InsertIgnore returns an INSERT statement for the query builders subject for
65+
// which the database ignores rows that have already been inserted.
66+
func (qb *QueryBuilder) InsertIgnore() (string, int) {
67+
columns := qb.BuildColumns()
68+
table := TableName(qb.subject)
69+
70+
var clause string
71+
switch qb.db.DriverName() {
72+
case MySQL:
73+
// MySQL treats UPDATE id = id as a no-op.
74+
clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%[1]s" = "%[1]s"`, columns[0])
75+
case PostgreSQL:
76+
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO NOTHING", qb.getPgsqlOnConflictConstraint())
77+
default:
78+
panic("Driver unsupported: " + qb.db.DriverName())
79+
}
80+
81+
return fmt.Sprintf(
82+
`INSERT INTO "%s" ("%s") VALUES (%s) %s`,
83+
table,
84+
strings.Join(columns, `", "`),
85+
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
86+
clause,
87+
), len(columns)
88+
}
89+
90+
// Select returns a SELECT statement from the query builders subject and the already set columns.
91+
// If no columns are set, they will be extracted from the query builders subject.
92+
// When the query builders subject is of type Scoper, a WHERE clause is appended to the statement.
93+
func (qb *QueryBuilder) Select() string {
94+
var scoper Scoper
95+
if sc, ok := qb.subject.(Scoper); ok {
96+
scoper = sc
97+
}
98+
99+
return qb.SelectScoped(scoper)
100+
}
101+
102+
// SelectScoped returns a SELECT statement from the query builders subject and the already set columns filtered
103+
// by the given scoper/column. When no columns are set, they will be extracted from the query builders subject.
104+
// The argument scoper must either be of type Scoper, string or nil to get SELECT statements without a WHERE clause.
105+
func (qb *QueryBuilder) SelectScoped(scoper any) string {
106+
query := fmt.Sprintf(`SELECT "%s" FROM "%s"`, strings.Join(qb.BuildColumns(), `", "`), TableName(qb.subject))
107+
where, placeholders := qb.Where(scoper)
108+
if placeholders > 0 {
109+
query += ` WHERE ` + where
110+
}
111+
112+
return query
113+
}
114+
115+
// Update returns an UPDATE statement for the query builders subject filter by ID column.
116+
func (qb *QueryBuilder) Update() (string, int) {
117+
return qb.UpdateScoped("id")
118+
}
119+
120+
// UpdateScoped returns an UPDATE statement for the query builders subject filtered by the given column/scoper.
121+
// The argument scoper must either be of type Scoper, string or nil to get UPDATE statements without a WHERE clause.
122+
func (qb *QueryBuilder) UpdateScoped(scoper any) (string, int) {
123+
columns := qb.BuildColumns()
124+
set := make([]string, 0, len(columns))
125+
126+
for _, col := range columns {
127+
set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
128+
}
129+
130+
placeholders := len(columns)
131+
query := `UPDATE "%s" SET %s`
132+
if where, count := qb.Where(scoper); count > 0 {
133+
placeholders += count
134+
query += ` WHERE ` + where
135+
}
136+
137+
return fmt.Sprintf(query, TableName(qb.subject), strings.Join(set, ", ")), placeholders
138+
}
139+
140+
// Upsert returns an upsert statement for the query builders subject.
141+
func (qb *QueryBuilder) Upsert() (string, int) {
142+
var updateColumns []string
143+
if upserter, ok := qb.subject.(Upserter); ok {
144+
updateColumns = qb.db.columnMap.Columns(upserter.Upsert())
145+
} else {
146+
updateColumns = qb.BuildColumns()
147+
}
148+
149+
return qb.UpsertColumns(updateColumns...)
150+
}
151+
152+
// UpsertColumns returns an upsert statement for the query builders subject and the specified update columns.
153+
func (qb *QueryBuilder) UpsertColumns(updateColumns ...string) (string, int) {
154+
insertColumns := qb.BuildColumns()
155+
table := TableName(qb.subject)
156+
157+
var clause, setFormat string
158+
switch qb.db.DriverName() {
159+
case MySQL:
160+
clause = "ON DUPLICATE KEY UPDATE"
161+
setFormat = `"%[1]s" = VALUES("%[1]s")`
162+
case PostgreSQL:
163+
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", qb.getPgsqlOnConflictConstraint())
164+
setFormat = `"%[1]s" = EXCLUDED."%[1]s"`
165+
default:
166+
panic("Driver unsupported: " + qb.db.DriverName())
167+
}
168+
169+
set := make([]string, 0, len(updateColumns))
170+
for _, col := range updateColumns {
171+
set = append(set, fmt.Sprintf(setFormat, col))
172+
}
173+
174+
return fmt.Sprintf(
175+
`INSERT INTO "%s" ("%s") VALUES (%s) %s %s`,
176+
table,
177+
strings.Join(insertColumns, `", "`),
178+
fmt.Sprintf(":%s", strings.Join(insertColumns, ", :")),
179+
clause,
180+
strings.Join(set, ", "),
181+
), len(insertColumns)
182+
}
183+
184+
// Where returns a WHERE clause with named placeholder conditions built from the
185+
// specified scoper/column combined with the AND operator.
186+
func (qb *QueryBuilder) Where(subject any) (string, int) {
187+
t := reflect.TypeOf(subject)
188+
if t == nil { // Subject is a nil interface value.
189+
return "", 0
190+
}
191+
192+
var columns []string
193+
if t.Kind() == reflect.String {
194+
columns = []string{subject.(string)}
195+
} else if t.Kind() == reflect.Struct || t.Kind() == reflect.Pointer {
196+
if scoper, ok := subject.(Scoper); ok {
197+
return qb.Where(scoper.Scope())
198+
}
199+
200+
columns = qb.db.columnMap.Columns(subject)
201+
} else { // This should never happen unless someone wants to do some silly things.
202+
panic(fmt.Sprintf("qb.Where: unknown subject type provided: %q", t.Kind().String()))
203+
}
204+
205+
where := make([]string, 0, len(columns))
206+
for _, col := range columns {
207+
where = append(where, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
208+
}
209+
210+
return strings.Join(where, ` AND `), len(columns)
211+
}
212+
213+
// BuildColumns returns all the Query Builder columns (if specified), otherwise they are
214+
// determined dynamically using its subject. Additionally, it checks whether columns need
215+
// to be excluded and proceeds accordingly.
216+
func (qb *QueryBuilder) BuildColumns() []string {
217+
var columns []string
218+
if len(qb.columns) > 0 {
219+
columns = qb.columns
220+
} else {
221+
columns = qb.db.columnMap.Columns(qb.subject)
222+
}
223+
224+
if len(qb.excludedColumns) > 0 {
225+
columns = slices.DeleteFunc(append([]string(nil), columns...), func(column string) bool {
226+
for _, exclude := range qb.excludedColumns {
227+
if exclude == column {
228+
return true
229+
}
230+
}
231+
232+
return false
233+
})
234+
}
235+
236+
if qb.sort {
237+
// The order in which the columns appear is not guaranteed as we extract the columns dynamically
238+
// from the struct. So, we've to sort them here to be able to test the generated statements.
239+
sort.SliceStable(columns, func(a, b int) bool {
240+
return columns[a] < columns[b]
241+
})
242+
}
243+
244+
return columns
245+
}
246+
247+
// getPgsqlOnConflictConstraint returns the constraint name of the current [QueryBuilder]'s subject.
248+
// If the subject does not implement the PgsqlOnConflictConstrainter interface, it will simply return
249+
// the table name prefixed with `pk_`.
250+
func (qb *QueryBuilder) getPgsqlOnConflictConstraint() string {
251+
if constrainter, ok := qb.subject.(PgsqlOnConflictConstrainter); ok {
252+
return constrainter.PgsqlOnConflictConstraint()
253+
}
254+
255+
return "pk_" + TableName(qb.subject)
256+
}

0 commit comments

Comments
 (0)