99 "errors"
1010 "fmt"
1111 "strings"
12+ "sync"
1213 "time"
1314
1415 _ "github.com/go-sql-driver/mysql"
@@ -17,30 +18,48 @@ import (
1718 "golang.org/x/oauth2"
1819)
1920
20- type SQLProvider struct {}
21+ type SQLProvider struct {
22+ db * sql.DB
23+ mu sync.Mutex
24+ }
2125
2226const sqlSchemaVersion = 1
2327
2428//go:embed sql/schema*.sql
2529var sqlSchemas embed.FS
2630
2731func init () {
28- RegisterProvider (SQLProvider {})
32+ RegisterProvider (& SQLProvider {})
2933}
3034
31- func (SQLProvider ) Name () string { return "sql" }
32- func (SQLProvider ) DefaultServer () string { return "" }
33- func (SQLProvider ) Config (clientID , clientSecret , redirectURL string ) * oauth2.Config { return nil }
34- func (SQLProvider ) CurrentUser (ctx context.Context , token * oauth2.Token ) (* User , error ) {
35+ func (p * SQLProvider ) getDB () (* sql.DB , error ) {
36+ p .mu .Lock ()
37+ defer p .mu .Unlock ()
38+
39+ if p .db != nil {
40+ return p .db , nil
41+ }
42+
43+ db , err := OpenDB ()
44+ if err != nil {
45+ return nil , err
46+ }
47+ p .db = db
48+ return p .db , nil
49+ }
50+
51+ func (p * SQLProvider ) Name () string { return "sql" }
52+ func (p * SQLProvider ) DefaultServer () string { return "" }
53+ func (p * SQLProvider ) Config (clientID , clientSecret , redirectURL string ) * oauth2.Config { return nil }
54+ func (p * SQLProvider ) CurrentUser (ctx context.Context , token * oauth2.Token ) (* User , error ) {
3555 return nil , errors .New ("not implemented" )
3656}
3757
38- func (p SQLProvider ) GetTags (ctx context.Context , user string , token * oauth2.Token ) ([]* Tag , error ) {
39- db , err := OpenDB ()
58+ func (p * SQLProvider ) GetTags (ctx context.Context , user string , token * oauth2.Token ) ([]* Tag , error ) {
59+ db , err := p . getDB ()
4060 if err != nil {
4161 return nil , err
4262 }
43- defer db .Close ()
4463
4564 rows , err := db .QueryContext (ctx , "SELECT name FROM tags WHERE user=? ORDER BY name" , user )
4665 if err != nil {
@@ -59,12 +78,11 @@ func (p SQLProvider) GetTags(ctx context.Context, user string, token *oauth2.Tok
5978 return tags , rows .Err ()
6079}
6180
62- func (p SQLProvider ) GetBranches (ctx context.Context , user string , token * oauth2.Token ) ([]* Branch , error ) {
63- db , err := OpenDB ()
81+ func (p * SQLProvider ) GetBranches (ctx context.Context , user string , token * oauth2.Token ) ([]* Branch , error ) {
82+ db , err := p . getDB ()
6483 if err != nil {
6584 return nil , err
6685 }
67- defer db .Close ()
6886
6987 rows , err := db .QueryContext (ctx , "SELECT name FROM branches WHERE user=? ORDER BY name" , user )
7088 if err != nil {
@@ -86,12 +104,11 @@ func (p SQLProvider) GetBranches(ctx context.Context, user string, token *oauth2
86104 return branches , rows .Err ()
87105}
88106
89- func (p SQLProvider ) GetCommits (ctx context.Context , user string , token * oauth2.Token , ref string , page , perPage int ) ([]* Commit , error ) {
90- db , err := OpenDB ()
107+ func (p * SQLProvider ) GetCommits (ctx context.Context , user string , token * oauth2.Token , ref string , page , perPage int ) ([]* Commit , error ) {
108+ db , err := p . getDB ()
91109 if err != nil {
92110 return nil , err
93111 }
94- defer db .Close ()
95112
96113 query := "SELECT sha, message, date FROM history WHERE user=? ORDER BY id DESC"
97114 args := []any {user }
@@ -123,12 +140,11 @@ func (p SQLProvider) GetCommits(ctx context.Context, user string, token *oauth2.
123140 return commits , rows .Err ()
124141}
125142
126- func (p SQLProvider ) AdjacentCommits (ctx context.Context , user string , token * oauth2.Token , ref , sha string ) (string , string , error ) {
127- db , err := OpenDB ()
143+ func (p * SQLProvider ) AdjacentCommits (ctx context.Context , user string , token * oauth2.Token , ref , sha string ) (string , string , error ) {
144+ db , err := p . getDB ()
128145 if err != nil {
129146 return "" , "" , err
130147 }
131- defer db .Close ()
132148
133149 var id int
134150 err = db .QueryRowContext (ctx , "SELECT id FROM history WHERE user=? AND sha=?" , user , sha ).Scan (& id )
@@ -150,12 +166,11 @@ func (p SQLProvider) AdjacentCommits(ctx context.Context, user string, token *oa
150166 return prev .String , next .String , nil
151167}
152168
153- func (p SQLProvider ) GetBookmarks (ctx context.Context , user , ref string , token * oauth2.Token ) (string , string , error ) {
154- db , err := OpenDB ()
169+ func (p * SQLProvider ) GetBookmarks (ctx context.Context , user , ref string , token * oauth2.Token ) (string , string , error ) {
170+ db , err := p . getDB ()
155171 if err != nil {
156172 return "" , "" , err
157173 }
158- defer db .Close ()
159174
160175 if ref == "" {
161176 ref = "refs/heads/main"
@@ -196,15 +211,14 @@ func (p SQLProvider) GetBookmarks(ctx context.Context, user, ref string, token *
196211 return text , sha , nil
197212}
198213
199- func (p SQLProvider ) UpdateBookmarks (ctx context.Context , user string , token * oauth2.Token , sourceRef , branch , text , expectSHA string ) error {
214+ func (p * SQLProvider ) UpdateBookmarks (ctx context.Context , user string , token * oauth2.Token , sourceRef , branch , text , expectSHA string ) error {
200215 if branch == "" {
201216 branch = "main"
202217 }
203- db , err := OpenDB ()
218+ db , err := p . getDB ()
204219 if err != nil {
205220 return err
206221 }
207- defer db .Close ()
208222
209223 tx , err := db .BeginTx (ctx , nil )
210224 if err != nil {
@@ -268,15 +282,14 @@ func (p SQLProvider) UpdateBookmarks(ctx context.Context, user string, token *oa
268282 return tx .Commit ()
269283}
270284
271- func (p SQLProvider ) CreateBookmarks (ctx context.Context , user string , token * oauth2.Token , branch , text string ) error {
285+ func (p * SQLProvider ) CreateBookmarks (ctx context.Context , user string , token * oauth2.Token , branch , text string ) error {
272286 if branch == "" {
273287 branch = "main"
274288 }
275- db , err := OpenDB ()
289+ db , err := p . getDB ()
276290 if err != nil {
277291 return err
278292 }
279- defer db .Close ()
280293
281294 tx , err := db .BeginTx (ctx , nil )
282295 if err != nil {
@@ -352,12 +365,11 @@ func (p SQLProvider) CreateBookmarks(ctx context.Context, user string, token *oa
352365 return tx .Commit ()
353366}
354367
355- func (p SQLProvider ) CreateRepo (ctx context.Context , user string , token * oauth2.Token , name string ) error {
356- db , err := OpenDB ()
368+ func (p * SQLProvider ) CreateRepo (ctx context.Context , user string , token * oauth2.Token , name string ) error {
369+ db , err := p . getDB ()
357370 if err != nil {
358371 return err
359372 }
360- defer db .Close ()
361373
362374 tx , err := db .BeginTx (ctx , nil )
363375 if err != nil {
@@ -405,24 +417,22 @@ func (p SQLProvider) CreateRepo(ctx context.Context, user string, token *oauth2.
405417 return tx .Commit ()
406418}
407419
408- func (p SQLProvider ) RepoExists (ctx context.Context , user string , token * oauth2.Token , name string ) (bool , error ) {
409- db , err := OpenDB ()
420+ func (p * SQLProvider ) RepoExists (ctx context.Context , user string , token * oauth2.Token , name string ) (bool , error ) {
421+ db , err := p . getDB ()
410422 if err != nil {
411423 return false , err
412424 }
413- defer db .Close ()
414425
415426 var count int
416427 err = db .QueryRowContext (ctx , "SELECT COUNT(1) FROM bookmarks WHERE user=?" , user ).Scan (& count )
417428 return count > 0 , err
418429}
419430
420- func (p SQLProvider ) CreateUser (ctx context.Context , user , password string ) error {
421- db , err := OpenDB ()
431+ func (p * SQLProvider ) CreateUser (ctx context.Context , user , password string ) error {
432+ db , err := p . getDB ()
422433 if err != nil {
423434 return err
424435 }
425- defer db .Close ()
426436
427437 var count int
428438 if err := db .QueryRowContext (ctx , "SELECT COUNT(1) FROM passwords WHERE user=?" , user ).Scan (& count ); err != nil {
@@ -440,12 +450,11 @@ func (p SQLProvider) CreateUser(ctx context.Context, user, password string) erro
440450 return err
441451}
442452
443- func (p SQLProvider ) SetPassword (ctx context.Context , user , password string ) error {
444- db , err := OpenDB ()
453+ func (p * SQLProvider ) SetPassword (ctx context.Context , user , password string ) error {
454+ db , err := p . getDB ()
445455 if err != nil {
446456 return err
447457 }
448- defer db .Close ()
449458
450459 hash , err := bcrypt .GenerateFromPassword ([]byte (password ), bcrypt .DefaultCost )
451460 if err != nil {
@@ -462,12 +471,11 @@ func (p SQLProvider) SetPassword(ctx context.Context, user, password string) err
462471 return nil
463472}
464473
465- func (p SQLProvider ) CheckPassword (ctx context.Context , user , password string ) (bool , error ) {
466- db , err := OpenDB ()
474+ func (p * SQLProvider ) CheckPassword (ctx context.Context , user , password string ) (bool , error ) {
475+ db , err := p . getDB ()
467476 if err != nil {
468477 return false , err
469478 }
470- defer db .Close ()
471479
472480 var hash []byte
473481 err = db .QueryRowContext (ctx , "SELECT hash FROM passwords WHERE user=?" , user ).Scan (& hash )
0 commit comments