@@ -2,37 +2,61 @@ package mysql
22
33import (
44 "bytes"
5+ "crypto/tls"
6+ "crypto/x509"
57 "database/sql"
68 "fmt"
79 "io"
810 "net/url"
11+ "os"
912 "regexp"
1013 "strings"
1114
15+ "github.com/go-sql-driver/mysql" // database/sql driver
16+
1217 "github.com/amacneil/dbmate/v2/pkg/dbmate"
1318 "github.com/amacneil/dbmate/v2/pkg/dbutil"
14-
15- _ "github.com/go-sql-driver/mysql" // database/sql driver
1619)
1720
1821func init () {
1922 dbmate .RegisterDriver (NewDriver , "mysql" )
2023}
2124
25+ // sslMode refers to the --ssl-mode options in
26+ // https://dev.mysql.com/doc/refman/8.4/en/connection-options.html#option_general_ssl-mode
27+ type sslMode string
28+
29+ const (
30+ sslModeDisabled sslMode = "DISABLED"
31+ sslModePreferred sslMode = "PREFERRED"
32+ sslModeRequired sslMode = "REQUIRED"
33+ sslModeVerifyCa sslMode = "VERIFY_CA"
34+ sslModeVerifyIdentity sslMode = "VERIFY_IDENTITY"
35+ )
36+
2237// Driver provides top level database functions
2338type Driver struct {
2439 migrationsTableName string
2540 databaseURL * url.URL
2641 log io.Writer
42+
43+ sslMode sslMode
44+ sslConfErr error
45+ // Path to the file containing the certificate authority file in PEM format.
46+ caPath string
2747}
2848
2949// NewDriver initializes the driver
3050func NewDriver (config dbmate.DriverConfig ) dbmate.Driver {
31- return & Driver {
51+ driver := & Driver {
3252 migrationsTableName : config .MigrationsTableName ,
3353 databaseURL : config .DatabaseURL ,
3454 log : config .Log ,
55+ caPath : os .Getenv ("DBMATE_MYSQL_CA_PATH" ),
3556 }
57+ driver .sslConfErr = driver .configureSsl (os .Getenv ("DBMATE_MYSQL_SSL_MODE" ))
58+
59+ return driver
3660}
3761
3862func connectionString (u * url.URL ) string {
@@ -69,12 +93,68 @@ func connectionString(u *url.URL) string {
6993 return normalizedString
7094}
7195
96+ func (drv * Driver ) configureSsl (mode string ) error {
97+ switch sslMode (mode ) {
98+ case sslModeDisabled ,
99+ sslModePreferred :
100+ drv .sslMode = sslMode (mode )
101+ return nil
102+ // required?
103+ case sslModeRequired ,
104+ sslModeVerifyCa ,
105+ sslModeVerifyIdentity :
106+ drv .sslMode = sslMode (mode )
107+ case "" :
108+ drv .sslMode = sslModePreferred
109+ return nil
110+ default :
111+ return fmt .Errorf ("unknown ssl mode: %s" , mode )
112+ }
113+
114+ var tlsConf tls.Config
115+
116+ if drv .caPath != "" {
117+ caPem , err := os .ReadFile (drv .caPath )
118+ if err != nil {
119+ return fmt .Errorf ("failed to read CA file: %w" , err )
120+ }
121+
122+ rootCertPool := x509 .NewCertPool ()
123+ if ok := rootCertPool .AppendCertsFromPEM (caPem ); ! ok {
124+ return fmt .Errorf ("failed to append to root cert pool" )
125+ }
126+ tlsConf .RootCAs = rootCertPool
127+ }
128+ switch drv .sslMode {
129+ case sslModeRequired :
130+ tlsConf .InsecureSkipVerify = true
131+ case sslModeVerifyCa :
132+ case sslModeVerifyIdentity :
133+ tlsConf .ServerName = drv .databaseURL .Hostname ()
134+ }
135+
136+ err := mysql .RegisterTLSConfig ("custom" , & tlsConf )
137+ if err != nil {
138+ return fmt .Errorf ("failed to register custom TLS config: %v" , err )
139+ }
140+ query := drv .databaseURL .Query ()
141+ query .Set ("tls" , "custom" )
142+ drv .databaseURL .RawQuery = query .Encode ()
143+ return nil
144+ }
145+
72146// Open creates a new database connection
73147func (drv * Driver ) Open () (* sql.DB , error ) {
148+ if drv .sslConfErr != nil {
149+ return nil , fmt .Errorf ("failed to configure ssl: %w" , drv .sslConfErr )
150+ }
74151 return sql .Open ("mysql" , connectionString (drv .databaseURL ))
75152}
76153
77154func (drv * Driver ) openRootDB () (* sql.DB , error ) {
155+ if drv .sslConfErr != nil {
156+ return nil , fmt .Errorf ("failed to configure ssl: %w" , drv .sslConfErr )
157+ }
78158 // clone databaseURL
79159 rootURL , err := url .Parse (drv .databaseURL .String ())
80160 if err != nil {
@@ -129,8 +209,17 @@ func (drv *Driver) DropDatabase() error {
129209
130210func (drv * Driver ) mysqldumpArgs () []string {
131211 // generate CLI arguments
132- args := []string {"--opt" , "--routines" , "--no-data" ,
133- "--skip-dump-date" , "--skip-add-drop-table" }
212+ args := []string {
213+ "--opt" , "--routines" , "--no-data" ,
214+ "--skip-dump-date" , "--skip-add-drop-table" ,
215+ }
216+
217+ if drv .sslMode != sslModePreferred {
218+ args = append (args , "--ssl-mode" , string (drv .sslMode ))
219+ }
220+ if drv .caPath != "" {
221+ args = append (args , "--ssl-ca" , drv .caPath )
222+ }
134223
135224 socket := drv .databaseURL .Query ().Get ("socket" )
136225 if socket != "" {
0 commit comments