99 "errors"
1010 "fmt"
1111 "strings"
12+ "sync"
1213 "time"
1314
1415 _ "github.com/go-sql-driver/mysql"
@@ -17,30 +18,48 @@ import (
1718 "golang.org/x/oauth2"
1819)
1920
20- type SQLProvider struct {}
21+ type SQLProvider struct {
22+ db * sql.DB
23+ mu sync.Mutex
24+ }
2125
2226const sqlSchemaVersion = 1
2327
2428//go:embed sql/schema*.sql
2529var sqlSchemas embed.FS
2630
2731func init () {
28- RegisterProvider (SQLProvider {})
32+ RegisterProvider (& SQLProvider {})
2933}
3034
31- func (SQLProvider ) Name () string { return "sql" }
32- func (SQLProvider ) DefaultServer () string { return "" }
33- func (SQLProvider ) Config (clientID , clientSecret , redirectURL string ) * oauth2.Config { return nil }
34- func (SQLProvider ) CurrentUser (ctx context.Context , token * oauth2.Token ) (* User , error ) {
35+ func (p * SQLProvider ) getDB () (* sql.DB , error ) {
36+ p .mu .Lock ()
37+ defer p .mu .Unlock ()
38+
39+ if p .db != nil {
40+ return p .db , nil
41+ }
42+
43+ db , err := OpenDB ()
44+ if err != nil {
45+ return nil , err
46+ }
47+ p .db = db
48+ return p .db , nil
49+ }
50+
51+ func (p * SQLProvider ) Name () string { return "sql" }
52+ func (p * SQLProvider ) DefaultServer () string { return "" }
53+ func (p * SQLProvider ) Config (clientID , clientSecret , redirectURL string ) * oauth2.Config { return nil }
54+ func (p * SQLProvider ) CurrentUser (ctx context.Context , token * oauth2.Token ) (* User , error ) {
3555 return nil , errors .New ("not implemented" )
3656}
3757
38- func (p SQLProvider ) GetTags (ctx context.Context , user string , token * oauth2.Token ) ([]* Tag , error ) {
39- db , err := OpenDB ()
58+ func (p * SQLProvider ) GetTags (ctx context.Context , user string , token * oauth2.Token ) ([]* Tag , error ) {
59+ db , err := p . getDB ()
4060 if err != nil {
4161 return nil , err
4262 }
43- defer db .Close ()
4463
4564 rows , err := db .QueryContext (ctx , "SELECT name FROM tags WHERE user=? ORDER BY name" , user )
4665 if err != nil {
@@ -59,12 +78,11 @@ func (p SQLProvider) GetTags(ctx context.Context, user string, token *oauth2.Tok
5978 return tags , rows .Err ()
6079}
6180
62- func (p SQLProvider ) GetBranches (ctx context.Context , user string , token * oauth2.Token ) ([]* Branch , error ) {
63- db , err := OpenDB ()
81+ func (p * SQLProvider ) GetBranches (ctx context.Context , user string , token * oauth2.Token ) ([]* Branch , error ) {
82+ db , err := p . getDB ()
6483 if err != nil {
6584 return nil , err
6685 }
67- defer db .Close ()
6886
6987 rows , err := db .QueryContext (ctx , "SELECT name FROM branches WHERE user=? ORDER BY name" , user )
7088 if err != nil {
@@ -86,12 +104,11 @@ func (p SQLProvider) GetBranches(ctx context.Context, user string, token *oauth2
86104 return branches , rows .Err ()
87105}
88106
89- func (p SQLProvider ) GetCommits (ctx context.Context , user string , token * oauth2.Token , ref string , page , perPage int ) ([]* Commit , error ) {
90- db , err := OpenDB ()
107+ func (p * SQLProvider ) GetCommits (ctx context.Context , user string , token * oauth2.Token , ref string , page , perPage int ) ([]* Commit , error ) {
108+ db , err := p . getDB ()
91109 if err != nil {
92110 return nil , err
93111 }
94- defer db .Close ()
95112
96113 query := "SELECT sha, message, date FROM history WHERE user=? ORDER BY id DESC"
97114 args := []any {user }
@@ -123,12 +140,11 @@ func (p SQLProvider) GetCommits(ctx context.Context, user string, token *oauth2.
123140 return commits , rows .Err ()
124141}
125142
126- func (p SQLProvider ) AdjacentCommits (ctx context.Context , user string , token * oauth2.Token , ref , sha string ) (string , string , error ) {
127- db , err := OpenDB ()
143+ func (p * SQLProvider ) AdjacentCommits (ctx context.Context , user string , token * oauth2.Token , ref , sha string ) (string , string , error ) {
144+ db , err := p . getDB ()
128145 if err != nil {
129146 return "" , "" , err
130147 }
131- defer db .Close ()
132148
133149 var prev , next sql.NullString
134150 err = db .QueryRowContext (ctx , `
@@ -148,12 +164,11 @@ func (p SQLProvider) AdjacentCommits(ctx context.Context, user string, token *oa
148164 return prev .String , next .String , nil
149165}
150166
151- func (p SQLProvider ) GetBookmarks (ctx context.Context , user , ref string , token * oauth2.Token ) (string , string , error ) {
152- db , err := OpenDB ()
167+ func (p * SQLProvider ) GetBookmarks (ctx context.Context , user , ref string , token * oauth2.Token ) (string , string , error ) {
168+ db , err := p . getDB ()
153169 if err != nil {
154170 return "" , "" , err
155171 }
156- defer db .Close ()
157172
158173 if ref == "" {
159174 ref = "refs/heads/main"
@@ -194,15 +209,14 @@ func (p SQLProvider) GetBookmarks(ctx context.Context, user, ref string, token *
194209 return text , sha , nil
195210}
196211
197- func (p SQLProvider ) UpdateBookmarks (ctx context.Context , user string , token * oauth2.Token , sourceRef , branch , text , expectSHA string ) error {
212+ func (p * SQLProvider ) UpdateBookmarks (ctx context.Context , user string , token * oauth2.Token , sourceRef , branch , text , expectSHA string ) error {
198213 if branch == "" {
199214 branch = "main"
200215 }
201- db , err := OpenDB ()
216+ db , err := p . getDB ()
202217 if err != nil {
203218 return err
204219 }
205- defer db .Close ()
206220
207221 tx , err := db .BeginTx (ctx , nil )
208222 if err != nil {
@@ -266,15 +280,14 @@ func (p SQLProvider) UpdateBookmarks(ctx context.Context, user string, token *oa
266280 return tx .Commit ()
267281}
268282
269- func (p SQLProvider ) CreateBookmarks (ctx context.Context , user string , token * oauth2.Token , branch , text string ) error {
283+ func (p * SQLProvider ) CreateBookmarks (ctx context.Context , user string , token * oauth2.Token , branch , text string ) error {
270284 if branch == "" {
271285 branch = "main"
272286 }
273- db , err := OpenDB ()
287+ db , err := p . getDB ()
274288 if err != nil {
275289 return err
276290 }
277- defer db .Close ()
278291
279292 tx , err := db .BeginTx (ctx , nil )
280293 if err != nil {
@@ -350,12 +363,11 @@ func (p SQLProvider) CreateBookmarks(ctx context.Context, user string, token *oa
350363 return tx .Commit ()
351364}
352365
353- func (p SQLProvider ) CreateRepo (ctx context.Context , user string , token * oauth2.Token , name string ) error {
354- db , err := OpenDB ()
366+ func (p * SQLProvider ) CreateRepo (ctx context.Context , user string , token * oauth2.Token , name string ) error {
367+ db , err := p . getDB ()
355368 if err != nil {
356369 return err
357370 }
358- defer db .Close ()
359371
360372 tx , err := db .BeginTx (ctx , nil )
361373 if err != nil {
@@ -403,24 +415,22 @@ func (p SQLProvider) CreateRepo(ctx context.Context, user string, token *oauth2.
403415 return tx .Commit ()
404416}
405417
406- func (p SQLProvider ) RepoExists (ctx context.Context , user string , token * oauth2.Token , name string ) (bool , error ) {
407- db , err := OpenDB ()
418+ func (p * SQLProvider ) RepoExists (ctx context.Context , user string , token * oauth2.Token , name string ) (bool , error ) {
419+ db , err := p . getDB ()
408420 if err != nil {
409421 return false , err
410422 }
411- defer db .Close ()
412423
413424 var count int
414425 err = db .QueryRowContext (ctx , "SELECT COUNT(1) FROM bookmarks WHERE user=?" , user ).Scan (& count )
415426 return count > 0 , err
416427}
417428
418- func (p SQLProvider ) CreateUser (ctx context.Context , user , password string ) error {
419- db , err := OpenDB ()
429+ func (p * SQLProvider ) CreateUser (ctx context.Context , user , password string ) error {
430+ db , err := p . getDB ()
420431 if err != nil {
421432 return err
422433 }
423- defer db .Close ()
424434
425435 var count int
426436 if err := db .QueryRowContext (ctx , "SELECT COUNT(1) FROM passwords WHERE user=?" , user ).Scan (& count ); err != nil {
@@ -438,12 +448,11 @@ func (p SQLProvider) CreateUser(ctx context.Context, user, password string) erro
438448 return err
439449}
440450
441- func (p SQLProvider ) SetPassword (ctx context.Context , user , password string ) error {
442- db , err := OpenDB ()
451+ func (p * SQLProvider ) SetPassword (ctx context.Context , user , password string ) error {
452+ db , err := p . getDB ()
443453 if err != nil {
444454 return err
445455 }
446- defer db .Close ()
447456
448457 hash , err := bcrypt .GenerateFromPassword ([]byte (password ), bcrypt .DefaultCost )
449458 if err != nil {
@@ -460,12 +469,11 @@ func (p SQLProvider) SetPassword(ctx context.Context, user, password string) err
460469 return nil
461470}
462471
463- func (p SQLProvider ) CheckPassword (ctx context.Context , user , password string ) (bool , error ) {
464- db , err := OpenDB ()
472+ func (p * SQLProvider ) CheckPassword (ctx context.Context , user , password string ) (bool , error ) {
473+ db , err := p . getDB ()
465474 if err != nil {
466475 return false , err
467476 }
468- defer db .Close ()
469477
470478 var hash []byte
471479 err = db .QueryRowContext (ctx , "SELECT hash FROM passwords WHERE user=?" , user ).Scan (& hash )
0 commit comments