Skip to content

Commit 5253251

Browse files
committed
Introduce QueryBuilder type
1 parent a4de971 commit 5253251

File tree

2 files changed

+491
-0
lines changed

2 files changed

+491
-0
lines changed

database/query_builder.go

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
package database
2+
3+
import (
4+
"fmt"
5+
"golang.org/x/exp/slices"
6+
"reflect"
7+
"sort"
8+
"strings"
9+
)
10+
11+
// QueryBuilder is an addon for the [DB] type that takes care of all the database statement building shenanigans.
12+
// The recommended use of QueryBuilder is to only use it to generate a single query at a time and not two different
13+
// ones. If for instance you want to generate `INSERT` and `SELECT` queries, it is best to use two different
14+
// QueryBuilder instances. You can use the DB#QueryBuilder() method to get fully initialised instances each time.
15+
type QueryBuilder struct {
16+
subject any
17+
columns []string
18+
excludedColumns []string
19+
20+
// Indicates whether the generated columns should be sorted in ascending order before generating the
21+
// actual statements. This is intended for unit tests only and shouldn't be necessary for production code.
22+
sort bool
23+
}
24+
25+
// SetColumns sets the DB columns to be used when building the statements.
26+
// When you do not want the columns to be extracted dynamically, you can use this method to specify them manually.
27+
// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls.
28+
func (qb *QueryBuilder) SetColumns(columns ...string) *QueryBuilder {
29+
qb.columns = columns
30+
return qb
31+
}
32+
33+
// SetExcludedColumns excludes the given columns from all the database statements.
34+
// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls.
35+
func (qb *QueryBuilder) SetExcludedColumns(columns ...string) *QueryBuilder {
36+
qb.excludedColumns = columns
37+
return qb
38+
}
39+
40+
// Delete returns a DELETE statement for the query builders subject filtered by ID.
41+
func (qb *QueryBuilder) Delete() string {
42+
return qb.DeleteBy("id")
43+
}
44+
45+
// DeleteBy returns a DELETE statement for the query builders subject filtered by the given column.
46+
func (qb *QueryBuilder) DeleteBy(column string) string {
47+
return fmt.Sprintf(`DELETE FROM "%s" WHERE "%s" IN (?)`, TableName(qb.subject), column)
48+
}
49+
50+
// Insert returns an INSERT INTO statement for the query builders subject.
51+
func (qb *QueryBuilder) Insert(db *DB) (string, int) {
52+
columns := qb.BuildColumns(db)
53+
54+
return fmt.Sprintf(
55+
`INSERT INTO "%s" ("%s") VALUES (:%s)`,
56+
TableName(qb.subject),
57+
strings.Join(columns, `", "`),
58+
strings.Join(columns, ", :"),
59+
), len(columns)
60+
}
61+
62+
// InsertIgnore returns an INSERT statement for the query builders subject for
63+
// which the database ignores rows that have already been inserted.
64+
func (qb *QueryBuilder) InsertIgnore(db *DB) (string, int) {
65+
columns := qb.BuildColumns(db)
66+
67+
var clause string
68+
switch db.DriverName() {
69+
case MySQL:
70+
// MySQL treats UPDATE id = id as a no-op.
71+
clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%[1]s" = "%[1]s"`, columns[0])
72+
case PostgreSQL:
73+
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO NOTHING", qb.getPgsqlOnConflictConstraint())
74+
default:
75+
panic("Driver unsupported: " + db.DriverName())
76+
}
77+
78+
return fmt.Sprintf(
79+
`INSERT INTO "%s" ("%s") VALUES (:%s) %s`,
80+
TableName(qb.subject),
81+
strings.Join(columns, `", "`),
82+
strings.Join(columns, ", :"),
83+
clause,
84+
), len(columns)
85+
}
86+
87+
// Select returns a SELECT statement from the query builders subject and the already set columns.
88+
// If no columns are set, they will be extracted from the query builders subject.
89+
// When the query builders subject is of type Scoper, a WHERE clause is appended to the statement.
90+
func (qb *QueryBuilder) Select(db *DB) string {
91+
var scoper Scoper
92+
if sc, ok := qb.subject.(Scoper); ok {
93+
scoper = sc
94+
}
95+
96+
return qb.SelectScoped(db, scoper)
97+
}
98+
99+
// SelectScoped returns a SELECT statement from the query builders subject and the already set columns filtered
100+
// by the given scoper/column. When no columns are set, they will be extracted from the query builders subject.
101+
// The argument scoper must either be of type Scoper, string or nil to get SELECT statements without a WHERE clause.
102+
func (qb *QueryBuilder) SelectScoped(db *DB, scoper any) string {
103+
query := fmt.Sprintf(`SELECT "%s" FROM "%s"`, strings.Join(qb.BuildColumns(db), `", "`), TableName(qb.subject))
104+
where, placeholders := qb.Where(db, scoper)
105+
if placeholders > 0 {
106+
query += ` WHERE ` + where
107+
}
108+
109+
return query
110+
}
111+
112+
// Update returns an UPDATE statement for the query builders subject filter by ID column.
113+
func (qb *QueryBuilder) Update(db *DB) (string, int) {
114+
return qb.UpdateScoped(db, "id")
115+
}
116+
117+
// UpdateScoped returns an UPDATE statement for the query builders subject filtered by the given column/scoper.
118+
// The argument scoper must either be of type Scoper, string or nil to get UPDATE statements without a WHERE clause.
119+
func (qb *QueryBuilder) UpdateScoped(db *DB, scoper any) (string, int) {
120+
columns := qb.BuildColumns(db)
121+
set := make([]string, 0, len(columns))
122+
123+
for _, col := range columns {
124+
set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
125+
}
126+
127+
placeholders := len(columns)
128+
query := `UPDATE "%s" SET %s`
129+
if where, count := qb.Where(db, scoper); count > 0 {
130+
placeholders += count
131+
query += ` WHERE ` + where
132+
}
133+
134+
return fmt.Sprintf(query, TableName(qb.subject), strings.Join(set, ", ")), placeholders
135+
}
136+
137+
// Upsert returns an upsert statement for the query builders subject.
138+
func (qb *QueryBuilder) Upsert(db *DB) (string, int) {
139+
var updateColumns []string
140+
if upserter, ok := qb.subject.(Upserter); ok {
141+
updateColumns = db.columnMap.Columns(upserter.Upsert())
142+
} else {
143+
updateColumns = qb.BuildColumns(db)
144+
}
145+
146+
return qb.UpsertColumns(db, updateColumns...)
147+
}
148+
149+
// UpsertColumns returns an upsert statement for the query builders subject and the specified update columns.
150+
func (qb *QueryBuilder) UpsertColumns(db *DB, updateColumns ...string) (string, int) {
151+
var clause, setFormat string
152+
switch db.DriverName() {
153+
case MySQL:
154+
clause = "ON DUPLICATE KEY UPDATE"
155+
setFormat = `"%[1]s" = VALUES("%[1]s")`
156+
case PostgreSQL:
157+
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", qb.getPgsqlOnConflictConstraint())
158+
setFormat = `"%[1]s" = EXCLUDED."%[1]s"`
159+
default:
160+
panic("Driver unsupported: " + db.DriverName())
161+
}
162+
163+
set := make([]string, 0, len(updateColumns))
164+
for _, col := range updateColumns {
165+
set = append(set, fmt.Sprintf(setFormat, col))
166+
}
167+
168+
insertColumns := qb.BuildColumns(db)
169+
170+
return fmt.Sprintf(
171+
`INSERT INTO "%s" ("%s") VALUES (:%s) %s %s`,
172+
TableName(qb.subject),
173+
strings.Join(insertColumns, `", "`),
174+
strings.Join(insertColumns, ", :"),
175+
clause,
176+
strings.Join(set, ", "),
177+
), len(insertColumns)
178+
}
179+
180+
// Where returns a WHERE clause with named placeholder conditions built from the
181+
// specified scoper/column combined with the AND operator.
182+
func (qb *QueryBuilder) Where(db *DB, subject any) (string, int) {
183+
t := reflect.TypeOf(subject)
184+
if t == nil { // Subject is a nil interface value.
185+
return "", 0
186+
}
187+
188+
var columns []string
189+
if t.Kind() == reflect.String {
190+
columns = []string{subject.(string)}
191+
} else if t.Kind() == reflect.Struct || t.Kind() == reflect.Pointer {
192+
if scoper, ok := subject.(Scoper); ok {
193+
return qb.Where(db, scoper.Scope())
194+
}
195+
196+
columns = db.columnMap.Columns(subject)
197+
} else { // This should never happen unless someone wants to do some silly things.
198+
panic(fmt.Sprintf("qb.Where: unknown subject type provided: %q", t.Kind().String()))
199+
}
200+
201+
where := make([]string, 0, len(columns))
202+
for _, col := range columns {
203+
where = append(where, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
204+
}
205+
206+
return strings.Join(where, ` AND `), len(columns)
207+
}
208+
209+
// BuildColumns returns all the Query Builder columns (if specified), otherwise they are
210+
// determined dynamically using its subject. Additionally, it checks whether columns need
211+
// to be excluded and proceeds accordingly.
212+
func (qb *QueryBuilder) BuildColumns(db *DB) []string {
213+
var columns []string
214+
if len(qb.columns) > 0 {
215+
columns = qb.columns
216+
} else {
217+
columns = db.columnMap.Columns(qb.subject)
218+
}
219+
220+
if len(qb.excludedColumns) > 0 {
221+
columns = slices.DeleteFunc(slices.Clone(columns), func(column string) bool {
222+
return slices.Contains(qb.excludedColumns, column)
223+
})
224+
}
225+
226+
if qb.sort {
227+
// The order in which the columns appear is not guaranteed as we extract the columns dynamically
228+
// from the struct. So, we've to sort them here to be able to test the generated statements.
229+
sort.Strings(columns)
230+
}
231+
232+
return slices.Clip(columns)
233+
}
234+
235+
// getPgsqlOnConflictConstraint returns the constraint name of the current [QueryBuilder]'s subject.
236+
// If the subject does not implement the PgsqlOnConflictConstrainter interface, it will simply return
237+
// the table name prefixed with `pk_`.
238+
func (qb *QueryBuilder) getPgsqlOnConflictConstraint() string {
239+
if constrainter, ok := qb.subject.(PgsqlOnConflictConstrainter); ok {
240+
return constrainter.PgsqlOnConflictConstraint()
241+
}
242+
243+
return "pk_" + TableName(qb.subject)
244+
}

0 commit comments

Comments
 (0)