Skip to content

Commit ac03ee0

Browse files
authored
Merge pull request #174 from arran4/perf-sql-connection-reuse-10357083104816326878
⚡ Reuse SQL connection in SQLProvider
2 parents 9b86f5f + b944f5e commit ac03ee0

File tree

2 files changed

+51
-43
lines changed

2 files changed

+51
-43
lines changed

cmd/gobookmarks/provider_helper.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88

99
func getConfiguredProvider(cfg *Config) (Provider, error) {
1010
if cfg.DBConnectionProvider != "" && cfg.DBConnectionString != "" {
11-
return SQLProvider{}, nil
11+
return &SQLProvider{}, nil
1212
}
1313
if cfg.LocalGitPath != "" {
1414
return GitProvider{}, nil

provider_sql.go

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"errors"
1010
"fmt"
1111
"strings"
12+
"sync"
1213
"time"
1314

1415
_ "github.com/go-sql-driver/mysql"
@@ -17,30 +18,48 @@ import (
1718
"golang.org/x/oauth2"
1819
)
1920

20-
type SQLProvider struct{}
21+
type SQLProvider struct {
22+
db *sql.DB
23+
mu sync.Mutex
24+
}
2125

2226
const sqlSchemaVersion = 1
2327

2428
//go:embed sql/schema*.sql
2529
var sqlSchemas embed.FS
2630

2731
func init() {
28-
RegisterProvider(SQLProvider{})
32+
RegisterProvider(&SQLProvider{})
2933
}
3034

31-
func (SQLProvider) Name() string { return "sql" }
32-
func (SQLProvider) DefaultServer() string { return "" }
33-
func (SQLProvider) Config(clientID, clientSecret, redirectURL string) *oauth2.Config { return nil }
34-
func (SQLProvider) CurrentUser(ctx context.Context, token *oauth2.Token) (*User, error) {
35+
func (p *SQLProvider) getDB() (*sql.DB, error) {
36+
p.mu.Lock()
37+
defer p.mu.Unlock()
38+
39+
if p.db != nil {
40+
return p.db, nil
41+
}
42+
43+
db, err := OpenDB()
44+
if err != nil {
45+
return nil, err
46+
}
47+
p.db = db
48+
return p.db, nil
49+
}
50+
51+
func (p *SQLProvider) Name() string { return "sql" }
52+
func (p *SQLProvider) DefaultServer() string { return "" }
53+
func (p *SQLProvider) Config(clientID, clientSecret, redirectURL string) *oauth2.Config { return nil }
54+
func (p *SQLProvider) CurrentUser(ctx context.Context, token *oauth2.Token) (*User, error) {
3555
return nil, errors.New("not implemented")
3656
}
3757

38-
func (p SQLProvider) GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*Tag, error) {
39-
db, err := OpenDB()
58+
func (p *SQLProvider) GetTags(ctx context.Context, user string, token *oauth2.Token) ([]*Tag, error) {
59+
db, err := p.getDB()
4060
if err != nil {
4161
return nil, err
4262
}
43-
defer db.Close()
4463

4564
rows, err := db.QueryContext(ctx, "SELECT name FROM tags WHERE user=? ORDER BY name", user)
4665
if err != nil {
@@ -59,12 +78,11 @@ func (p SQLProvider) GetTags(ctx context.Context, user string, token *oauth2.Tok
5978
return tags, rows.Err()
6079
}
6180

62-
func (p SQLProvider) GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*Branch, error) {
63-
db, err := OpenDB()
81+
func (p *SQLProvider) GetBranches(ctx context.Context, user string, token *oauth2.Token) ([]*Branch, error) {
82+
db, err := p.getDB()
6483
if err != nil {
6584
return nil, err
6685
}
67-
defer db.Close()
6886

6987
rows, err := db.QueryContext(ctx, "SELECT name FROM branches WHERE user=? ORDER BY name", user)
7088
if err != nil {
@@ -86,12 +104,11 @@ func (p SQLProvider) GetBranches(ctx context.Context, user string, token *oauth2
86104
return branches, rows.Err()
87105
}
88106

89-
func (p SQLProvider) GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*Commit, error) {
90-
db, err := OpenDB()
107+
func (p *SQLProvider) GetCommits(ctx context.Context, user string, token *oauth2.Token, ref string, page, perPage int) ([]*Commit, error) {
108+
db, err := p.getDB()
91109
if err != nil {
92110
return nil, err
93111
}
94-
defer db.Close()
95112

96113
query := "SELECT sha, message, date FROM history WHERE user=? ORDER BY id DESC"
97114
args := []any{user}
@@ -123,12 +140,11 @@ func (p SQLProvider) GetCommits(ctx context.Context, user string, token *oauth2.
123140
return commits, rows.Err()
124141
}
125142

126-
func (p SQLProvider) AdjacentCommits(ctx context.Context, user string, token *oauth2.Token, ref, sha string) (string, string, error) {
127-
db, err := OpenDB()
143+
func (p *SQLProvider) AdjacentCommits(ctx context.Context, user string, token *oauth2.Token, ref, sha string) (string, string, error) {
144+
db, err := p.getDB()
128145
if err != nil {
129146
return "", "", err
130147
}
131-
defer db.Close()
132148

133149
var prev, next sql.NullString
134150
err = db.QueryRowContext(ctx, `
@@ -148,12 +164,11 @@ func (p SQLProvider) AdjacentCommits(ctx context.Context, user string, token *oa
148164
return prev.String, next.String, nil
149165
}
150166

151-
func (p SQLProvider) GetBookmarks(ctx context.Context, user, ref string, token *oauth2.Token) (string, string, error) {
152-
db, err := OpenDB()
167+
func (p *SQLProvider) GetBookmarks(ctx context.Context, user, ref string, token *oauth2.Token) (string, string, error) {
168+
db, err := p.getDB()
153169
if err != nil {
154170
return "", "", err
155171
}
156-
defer db.Close()
157172

158173
if ref == "" {
159174
ref = "refs/heads/main"
@@ -194,15 +209,14 @@ func (p SQLProvider) GetBookmarks(ctx context.Context, user, ref string, token *
194209
return text, sha, nil
195210
}
196211

197-
func (p SQLProvider) UpdateBookmarks(ctx context.Context, user string, token *oauth2.Token, sourceRef, branch, text, expectSHA string) error {
212+
func (p *SQLProvider) UpdateBookmarks(ctx context.Context, user string, token *oauth2.Token, sourceRef, branch, text, expectSHA string) error {
198213
if branch == "" {
199214
branch = "main"
200215
}
201-
db, err := OpenDB()
216+
db, err := p.getDB()
202217
if err != nil {
203218
return err
204219
}
205-
defer db.Close()
206220

207221
tx, err := db.BeginTx(ctx, nil)
208222
if err != nil {
@@ -266,15 +280,14 @@ func (p SQLProvider) UpdateBookmarks(ctx context.Context, user string, token *oa
266280
return tx.Commit()
267281
}
268282

269-
func (p SQLProvider) CreateBookmarks(ctx context.Context, user string, token *oauth2.Token, branch, text string) error {
283+
func (p *SQLProvider) CreateBookmarks(ctx context.Context, user string, token *oauth2.Token, branch, text string) error {
270284
if branch == "" {
271285
branch = "main"
272286
}
273-
db, err := OpenDB()
287+
db, err := p.getDB()
274288
if err != nil {
275289
return err
276290
}
277-
defer db.Close()
278291

279292
tx, err := db.BeginTx(ctx, nil)
280293
if err != nil {
@@ -350,12 +363,11 @@ func (p SQLProvider) CreateBookmarks(ctx context.Context, user string, token *oa
350363
return tx.Commit()
351364
}
352365

353-
func (p SQLProvider) CreateRepo(ctx context.Context, user string, token *oauth2.Token, name string) error {
354-
db, err := OpenDB()
366+
func (p *SQLProvider) CreateRepo(ctx context.Context, user string, token *oauth2.Token, name string) error {
367+
db, err := p.getDB()
355368
if err != nil {
356369
return err
357370
}
358-
defer db.Close()
359371

360372
tx, err := db.BeginTx(ctx, nil)
361373
if err != nil {
@@ -403,24 +415,22 @@ func (p SQLProvider) CreateRepo(ctx context.Context, user string, token *oauth2.
403415
return tx.Commit()
404416
}
405417

406-
func (p SQLProvider) RepoExists(ctx context.Context, user string, token *oauth2.Token, name string) (bool, error) {
407-
db, err := OpenDB()
418+
func (p *SQLProvider) RepoExists(ctx context.Context, user string, token *oauth2.Token, name string) (bool, error) {
419+
db, err := p.getDB()
408420
if err != nil {
409421
return false, err
410422
}
411-
defer db.Close()
412423

413424
var count int
414425
err = db.QueryRowContext(ctx, "SELECT COUNT(1) FROM bookmarks WHERE user=?", user).Scan(&count)
415426
return count > 0, err
416427
}
417428

418-
func (p SQLProvider) CreateUser(ctx context.Context, user, password string) error {
419-
db, err := OpenDB()
429+
func (p *SQLProvider) CreateUser(ctx context.Context, user, password string) error {
430+
db, err := p.getDB()
420431
if err != nil {
421432
return err
422433
}
423-
defer db.Close()
424434

425435
var count int
426436
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM passwords WHERE user=?", user).Scan(&count); err != nil {
@@ -438,12 +448,11 @@ func (p SQLProvider) CreateUser(ctx context.Context, user, password string) erro
438448
return err
439449
}
440450

441-
func (p SQLProvider) SetPassword(ctx context.Context, user, password string) error {
442-
db, err := OpenDB()
451+
func (p *SQLProvider) SetPassword(ctx context.Context, user, password string) error {
452+
db, err := p.getDB()
443453
if err != nil {
444454
return err
445455
}
446-
defer db.Close()
447456

448457
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
449458
if err != nil {
@@ -460,12 +469,11 @@ func (p SQLProvider) SetPassword(ctx context.Context, user, password string) err
460469
return nil
461470
}
462471

463-
func (p SQLProvider) CheckPassword(ctx context.Context, user, password string) (bool, error) {
464-
db, err := OpenDB()
472+
func (p *SQLProvider) CheckPassword(ctx context.Context, user, password string) (bool, error) {
473+
db, err := p.getDB()
465474
if err != nil {
466475
return false, err
467476
}
468-
defer db.Close()
469477

470478
var hash []byte
471479
err = db.QueryRowContext(ctx, "SELECT hash FROM passwords WHERE user=?", user).Scan(&hash)

0 commit comments

Comments
 (0)