-
Notifications
You must be signed in to change notification settings - Fork 436
Expand file tree
/
Copy pathstmtcache.go
More file actions
200 lines (168 loc) · 5.34 KB
/
stmtcache.go
File metadata and controls
200 lines (168 loc) · 5.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
// Package stmtcache contains tools for managing the prepared-statement cache.
package stmtcache
import (
"context"
"database/sql"
"fmt"
"strings"
"sync"
"github.com/google/trillian/monitoring"
"k8s.io/klog/v2"
)
// once guards the lazy creation of errStmtCounter; it is created the first
// time any statement execution error is handled.
var once sync.Once

// errStmtCounter counts statement execution errors ("sql_stmt_errors").
var errStmtCounter monitoring.Counter
// PlaceholderSQL is the marker inside an SQL statement that is replaced by
// the expanded list of parameter slots before the statement is prepared.
const (
	PlaceholderSQL = "<placeholder>"
)
// Stmt wraps a sql.Stmt so that SQL execution errors can be handled and
// monitored. When executing a Stmt fails, the affected statement is closed
// automatically, which also evicts it from the StmtCache it came from.
type Stmt struct {
	// statement is the unexpanded SQL text; it is the first-level cache key.
	statement string
	// placeholderNum is the number of parameter slots the placeholder was
	// expanded to; it is the second-level cache key.
	placeholderNum int
	// stmtCache is the cache this statement was obtained from; nil for
	// transaction-specific statements created by WithTx.
	stmtCache *StmtCache
	// stmt is the underlying prepared statement.
	stmt *sql.Stmt
	// parentStmt is the cached statement a transaction-specific Stmt was
	// derived from; nil for statements returned by GetStmt.
	parentStmt *Stmt
}
// errHandler reacts to a failed statement execution. It closes the statement
// the failure relates to (the parent statement when s was derived via
// WithTx), which also evicts it from its StmtCache. If that statement belongs
// to a cache, it then increments the "sql_stmt_errors" counter.
//
// The error argument is currently ignored; it is kept so that different
// errors can be handled and monitored separately in the future.
//
// NOTE(review): the counter is created only once (package-level sync.Once),
// from the MetricFactory of the first cache that observes an error. Errors
// from caches using a different factory are still counted on that first
// counter.
func (s *Stmt) errHandler(_ error) {
	target := s.parentStmt
	if target == nil {
		target = s
	}

	if closeErr := target.Close(); closeErr != nil {
		klog.Warningf("Failed to close stmt: %s", closeErr)
	}

	cache := target.stmtCache
	if cache == nil {
		return
	}
	once.Do(func() {
		errStmtCounter = cache.mf.NewCounter("sql_stmt_errors", "Number of statement execution errors")
	})
	errStmtCounter.Inc()
}
// SQLStmt exposes the underlying *sql.Stmt wrapped by s.
func (s *Stmt) SQLStmt() *sql.Stmt {
	underlying := s.stmt
	return underlying
}
// Close closes the underlying sql.Stmt and returns its error. If s came from
// a StmtCache, it is first evicted from that cache (only if the cache still
// holds this exact statement). Later GetStmt calls then prepare a fresh
// statement instead of reusing the closed one.
func (s *Stmt) Close() error {
	owner := s.stmtCache
	if owner != nil {
		owner.clearOne(s)
	}
	err := s.stmt.Close()
	return err
}
// WithTx derives a transaction-specific prepared statement from s via
// tx.StmtContext. The result always references the top-level cached
// statement as its parent, even when s is itself transaction-specific. It is
// not tracked by any StmtCache. The caller is responsible for closing the
// returned Stmt.
func (s *Stmt) WithTx(ctx context.Context, tx *sql.Tx) *Stmt {
	root := s
	if p := s.parentStmt; p != nil {
		root = p
	}
	txStmt := tx.StmtContext(ctx, root.stmt)
	return &Stmt{stmt: txStmt, parentStmt: root}
}
// ExecContext executes the prepared statement with args and returns a Result
// summarizing its effect. On failure the statement is passed to errHandler
// (which closes and evicts it) before the error is returned.
func (s *Stmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) {
	result, err := s.stmt.ExecContext(ctx, args...)
	if err == nil {
		return result, nil
	}
	s.errHandler(err)
	return result, err
}
// QueryContext executes the prepared query with args and returns the result
// set as *sql.Rows. On failure the statement is passed to errHandler (which
// closes and evicts it) before the error is returned.
func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) {
	rows, err := s.stmt.QueryContext(ctx, args...)
	if err == nil {
		return rows, nil
	}
	s.errHandler(err)
	return rows, err
}
// QueryRowContext executes the prepared query with args and returns a
// *sql.Row, which is always non-nil. An execution error is reported by the
// returned Row's Scan; if the query selects no rows, Scan returns
// sql.ErrNoRows, otherwise it scans the first row and discards the rest.
// If the Row already carries an execution error (Row.Err), the statement is
// passed to errHandler, which closes and evicts it.
func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *sql.Row {
	row := s.stmt.QueryRowContext(ctx, args...)
	queryErr := row.Err()
	if queryErr == nil {
		return row
	}
	s.errHandler(queryErr)
	return row
}
// StmtCache caches prepared sql.Stmt values. The first-level key is the
// unexpanded SQL text; the second-level key is the number of placeholder
// parameters.
type StmtCache struct {
	// db is the database statements are prepared against.
	db *sql.DB
	// statementMutex guards statements.
	statementMutex sync.Mutex
	// statements maps SQL text -> parameter count -> prepared statement.
	statements map[string]map[int]*sql.Stmt
	// mf is used to create metrics for statements from this cache.
	mf monitoring.MetricFactory
}
// New returns an empty StmtCache that prepares statements on db. A nil mf is
// replaced by monitoring.InertMetricFactory{}.
func New(db *sql.DB, mf monitoring.MetricFactory) *StmtCache {
	sc := &StmtCache{
		db:         db,
		statements: make(map[string]map[int]*sql.Stmt),
		mf:         mf,
	}
	if sc.mf == nil {
		sc.mf = monitoring.InertMetricFactory{}
	}
	return sc
}
// clearOne evicts s from the cache, but only if the cache still holds the
// exact *sql.Stmt that s wraps. A statement re-prepared after s was cached is
// left untouched. It is a no-op for a nil Stmt, a Stmt without an underlying
// statement, or a Stmt that belongs to a different cache.
func (sc *StmtCache) clearOne(s *Stmt) {
	if s == nil || s.stmt == nil || s.stmtCache != sc {
		return
	}
	sc.statementMutex.Lock()
	defer sc.statementMutex.Unlock()

	byNum, ok := sc.statements[s.statement]
	if !ok || byNum[s.placeholderNum] != s.stmt {
		return
	}
	// Delete rather than store nil so evicted keys do not accumulate in the
	// map; getStmt treats a missing entry (or missing inner map) the same as
	// a nil one and re-prepares the statement.
	delete(byNum, s.placeholderNum)
	if len(byNum) == 0 {
		delete(sc.statements, s.statement)
	}
}
// getStmt returns the cached *sql.Stmt for (statement, num). On a miss it
// expands the placeholder in statement to num parameters (via
// expandPlaceholderSQL), prepares the result on sc.db, and caches it. The
// cache mutex is held for the whole call, including the PrepareContext call.
func (sc *StmtCache) getStmt(ctx context.Context, statement string, num int, first, rest string) (*sql.Stmt, error) {
	sc.statementMutex.Lock()
	defer sc.statementMutex.Unlock()

	byNum := sc.statements[statement]
	if byNum == nil {
		byNum = make(map[int]*sql.Stmt)
		sc.statements[statement] = byNum
	} else if cached := byNum[num]; cached != nil {
		return cached, nil
	}

	prepared, err := sc.db.PrepareContext(ctx, expandPlaceholderSQL(statement, num, first, rest))
	if err != nil {
		klog.Warningf("Failed to prepare statement %d: %s", num, err)
		return nil, err
	}
	byNum[num] = prepared
	return prepared, nil
}
// expandPlaceholderSQL returns sql with its first PlaceholderSQL marker
// replaced by num comma-separated parameter slots: first, then num-1 copies
// of rest. Only the first marker is replaced. It panics if num is not
// positive.
func expandPlaceholderSQL(sql string, num int, first, rest string) string {
	if num < 1 {
		panic(fmt.Errorf("trying to expand SQL placeholder with <= 0 parameters: %s", sql))
	}
	tail := strings.Repeat(","+rest, num-1)
	return strings.Replace(sql, PlaceholderSQL, first+tail, 1)
}
// GetStmt returns a Stmt wrapping the cached prepared statement for the
// given SQL text and parameter count, preparing and caching it first if
// needed. The returned Stmt remembers its cache keys so that Close (or an
// execution error) can evict it from sc.
func (sc *StmtCache) GetStmt(ctx context.Context, statement string, num int, first, rest string) (*Stmt, error) {
	cached, err := sc.getStmt(ctx, statement, num, first, rest)
	if err != nil {
		return nil, err
	}
	return &Stmt{
		stmt:           cached,
		stmtCache:      sc,
		placeholderNum: num,
		statement:      statement,
	}, nil
}