Skip to content

Commit 754c8c0

Browse files
Support requiring SSL, and verifying CA, for MySQL
1 parent 3fe5366 commit 754c8c0

File tree

3 files changed

+204
-5
lines changed

3 files changed

+204
-5
lines changed

docker-compose.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ services:
1818
CLICKHOUSE_CLUSTER_01_TEST_URL: clickhouse://ch-cluster-01:9000/dbmate_test
1919
CLICKHOUSE_CLUSTER_02_TEST_URL: clickhouse://ch-cluster-02:9000/dbmate_test
2020
MYSQL_TEST_URL: mysql://root:root@mysql/dbmate_test
21+
DBMATE_MYSQL_SSL_MODE: DISABLED
2122
POSTGRES_TEST_URL: postgres://postgres:postgres@postgres/dbmate_test?sslmode=disable
2223
BIGQUERY_TEST_URL: bigquery://test/us-east5/dbmate_test?disable_auth=true&endpoint=http%3A%2F%2Fbigquery%3A9050
2324
SPANNER_POSTGRES_TEST_URL: spanner-postgres://spanner-emulator/dbmate_test?sslmode=disable

pkg/driver/mysql/mysql.go

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,61 @@ package mysql
22

33
import (
44
"bytes"
5+
"crypto/tls"
6+
"crypto/x509"
57
"database/sql"
68
"fmt"
79
"io"
810
"net/url"
11+
"os"
912
"regexp"
1013
"strings"
1114

15+
"github.com/go-sql-driver/mysql" // database/sql driver
16+
1217
"github.com/amacneil/dbmate/v2/pkg/dbmate"
1318
"github.com/amacneil/dbmate/v2/pkg/dbutil"
14-
15-
_ "github.com/go-sql-driver/mysql" // database/sql driver
1619
)
1720

1821
func init() {
1922
dbmate.RegisterDriver(NewDriver, "mysql")
2023
}
2124

25+
// sslMode refers to the --ssl-mode options in
26+
// https://dev.mysql.com/doc/refman/8.4/en/connection-options.html#option_general_ssl-mode
27+
type sslMode string
28+
29+
const (
30+
sslModeDisabled sslMode = "DISABLED"
31+
sslModePreferred sslMode = "PREFERRED"
32+
sslModeRequired sslMode = "REQUIRED"
33+
sslModeVerifyCa sslMode = "VERIFY_CA"
34+
sslModeVerifyIdentity sslMode = "VERIFY_IDENTITY"
35+
)
36+
2237
// Driver provides top level database functions
2338
type Driver struct {
2439
migrationsTableName string
2540
databaseURL *url.URL
2641
log io.Writer
42+
43+
sslMode sslMode
44+
sslConfErr error
45+
// Path to the file containing the certificate authority file in PEM format.
46+
caPath string
2747
}
2848

2949
// NewDriver initializes the driver
3050
func NewDriver(config dbmate.DriverConfig) dbmate.Driver {
31-
return &Driver{
51+
driver := &Driver{
3252
migrationsTableName: config.MigrationsTableName,
3353
databaseURL: config.DatabaseURL,
3454
log: config.Log,
55+
caPath: os.Getenv("DBMATE_MYSQL_CA_PATH"),
3556
}
57+
driver.sslConfErr = driver.configureSsl(os.Getenv("DBMATE_MYSQL_SSL_MODE"))
58+
59+
return driver
3660
}
3761

3862
func connectionString(u *url.URL) string {
@@ -69,12 +93,68 @@ func connectionString(u *url.URL) string {
6993
return normalizedString
7094
}
7195

96+
func (drv *Driver) configureSsl(mode string) error {
97+
switch sslMode(mode) {
98+
case sslModeDisabled,
99+
sslModePreferred:
100+
drv.sslMode = sslMode(mode)
101+
return nil
102+
// required?
103+
case sslModeRequired,
104+
sslModeVerifyCa,
105+
sslModeVerifyIdentity:
106+
drv.sslMode = sslMode(mode)
107+
case "":
108+
drv.sslMode = sslModePreferred
109+
return nil
110+
default:
111+
return fmt.Errorf("unknown ssl mode: %s", mode)
112+
}
113+
114+
var tlsConf tls.Config
115+
116+
if drv.caPath != "" {
117+
caPem, err := os.ReadFile(drv.caPath)
118+
if err != nil {
119+
return fmt.Errorf("failed to read CA file: %w", err)
120+
}
121+
122+
rootCertPool := x509.NewCertPool()
123+
if ok := rootCertPool.AppendCertsFromPEM(caPem); !ok {
124+
return fmt.Errorf("failed to append to root cert pool")
125+
}
126+
tlsConf.RootCAs = rootCertPool
127+
}
128+
switch drv.sslMode {
129+
case sslModeRequired:
130+
tlsConf.InsecureSkipVerify = true
131+
case sslModeVerifyCa:
132+
case sslModeVerifyIdentity:
133+
tlsConf.ServerName = drv.databaseURL.Hostname()
134+
}
135+
136+
err := mysql.RegisterTLSConfig("custom", &tlsConf)
137+
if err != nil {
138+
return fmt.Errorf("failed to register custom TLS config: %v", err)
139+
}
140+
query := drv.databaseURL.Query()
141+
query.Set("tls", "custom")
142+
drv.databaseURL.RawQuery = query.Encode()
143+
return nil
144+
}
145+
72146
// Open creates a new database connection
73147
func (drv *Driver) Open() (*sql.DB, error) {
148+
if drv.sslConfErr != nil {
149+
return nil, fmt.Errorf("failed to configure ssl: %w", drv.sslConfErr)
150+
}
74151
return sql.Open("mysql", connectionString(drv.databaseURL))
75152
}
76153

77154
func (drv *Driver) openRootDB() (*sql.DB, error) {
155+
if drv.sslConfErr != nil {
156+
return nil, fmt.Errorf("failed to configure ssl: %w", drv.sslConfErr)
157+
}
78158
// clone databaseURL
79159
rootURL, err := url.Parse(drv.databaseURL.String())
80160
if err != nil {
@@ -129,8 +209,17 @@ func (drv *Driver) DropDatabase() error {
129209

130210
func (drv *Driver) mysqldumpArgs() []string {
131211
// generate CLI arguments
132-
args := []string{"--opt", "--routines", "--no-data",
133-
"--skip-dump-date", "--skip-add-drop-table"}
212+
args := []string{
213+
"--opt", "--routines", "--no-data",
214+
"--skip-dump-date", "--skip-add-drop-table",
215+
}
216+
217+
if drv.sslMode != sslModePreferred {
218+
args = append(args, "--ssl-mode", string(drv.sslMode))
219+
}
220+
if drv.caPath != "" {
221+
args = append(args, "--ssl-ca", drv.caPath)
222+
}
134223

135224
socket := drv.databaseURL.Query().Get("socket")
136225
if socket != "" {

pkg/driver/mysql/mysql_test.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ func TestMySQLCreateDropDatabase(t *testing.T) {
146146
}
147147

148148
func TestMySQLDumpArgs(t *testing.T) {
149+
t.Setenv("DBMATE_MYSQL_SSL_MODE", "PREFERRED")
149150
drv := testMySQLDriver(t)
150151
drv.databaseURL = dbtest.MustParseURL(t, "mysql://bob/mydb")
151152

@@ -181,6 +182,114 @@ func TestMySQLDumpArgs(t *testing.T) {
181182
"mydb"}, drv.mysqldumpArgs())
182183
}
183184

185+
func TestMySQLDumpArgsWithSsl(t *testing.T) {
186+
t.Run("mode=DISABLED,ca=empty", func(t *testing.T) {
187+
t.Setenv("DBMATE_MYSQL_SSL_MODE", "DISABLED")
188+
drv := testMySQLDriver(t)
189+
drv.databaseURL = dbtest.MustParseURL(t, "mysql://bob/mydb")
190+
require.Equal(t, []string{
191+
"--opt",
192+
"--routines",
193+
"--no-data",
194+
"--skip-dump-date",
195+
"--skip-add-drop-table",
196+
"--ssl-mode",
197+
"DISABLED",
198+
"--host=bob",
199+
"mydb",
200+
}, drv.mysqldumpArgs())
201+
})
202+
t.Run("mode=DISABLED,ca=set", func(t *testing.T) {
203+
t.Setenv("DBMATE_MYSQL_SSL_MODE", "DISABLED")
204+
t.Setenv("DBMATE_MYSQL_CA_PATH", "/tmp/404")
205+
drv := testMySQLDriver(t)
206+
drv.databaseURL = dbtest.MustParseURL(t, "mysql://bob/mydb")
207+
require.Equal(t, []string{
208+
"--opt",
209+
"--routines",
210+
"--no-data",
211+
"--skip-dump-date",
212+
"--skip-add-drop-table",
213+
"--ssl-mode",
214+
"DISABLED",
215+
"--ssl-ca",
216+
"/tmp/404",
217+
"--host=bob",
218+
"mydb",
219+
}, drv.mysqldumpArgs())
220+
})
221+
t.Run("mode=VERIFY_IDENTITY,ca=empty", func(t *testing.T) {
222+
t.Setenv("DBMATE_MYSQL_SSL_MODE", "VERIFY_IDENTITY")
223+
drv := testMySQLDriver(t)
224+
drv.databaseURL = dbtest.MustParseURL(t, "mysql://bob/mydb")
225+
require.Equal(t, []string{
226+
"--opt",
227+
"--routines",
228+
"--no-data",
229+
"--skip-dump-date",
230+
"--skip-add-drop-table",
231+
"--ssl-mode",
232+
"VERIFY_IDENTITY",
233+
"--host=bob",
234+
"mydb",
235+
}, drv.mysqldumpArgs())
236+
})
237+
t.Run("mode=VERIFY_IDENTITY,ca=set", func(t *testing.T) {
238+
t.Setenv("DBMATE_MYSQL_SSL_MODE", "VERIFY_IDENTITY")
239+
t.Setenv("DBMATE_MYSQL_CA_PATH", "/tmp/404")
240+
drv := testMySQLDriver(t)
241+
drv.databaseURL = dbtest.MustParseURL(t, "mysql://bob/mydb")
242+
require.Equal(t, []string{
243+
"--opt",
244+
"--routines",
245+
"--no-data",
246+
"--skip-dump-date",
247+
"--skip-add-drop-table",
248+
"--ssl-mode",
249+
"VERIFY_IDENTITY",
250+
"--ssl-ca",
251+
"/tmp/404",
252+
"--host=bob",
253+
"mydb",
254+
}, drv.mysqldumpArgs())
255+
})
256+
t.Run("mode=REQUIRED,ca=empty", func(t *testing.T) {
257+
t.Setenv("DBMATE_MYSQL_SSL_MODE", "REQUIRED")
258+
drv := testMySQLDriver(t)
259+
drv.databaseURL = dbtest.MustParseURL(t, "mysql://bob/mydb")
260+
require.Equal(t, []string{
261+
"--opt",
262+
"--routines",
263+
"--no-data",
264+
"--skip-dump-date",
265+
"--skip-add-drop-table",
266+
"--ssl-mode",
267+
"REQUIRED",
268+
"--host=bob",
269+
"mydb",
270+
}, drv.mysqldumpArgs())
271+
})
272+
t.Run("mode=REQUIRED,ca=set", func(t *testing.T) {
273+
t.Setenv("DBMATE_MYSQL_SSL_MODE", "REQUIRED")
274+
t.Setenv("DBMATE_MYSQL_CA_PATH", "/tmp/404")
275+
drv := testMySQLDriver(t)
276+
drv.databaseURL = dbtest.MustParseURL(t, "mysql://bob/mydb")
277+
require.Equal(t, []string{
278+
"--opt",
279+
"--routines",
280+
"--no-data",
281+
"--skip-dump-date",
282+
"--skip-add-drop-table",
283+
"--ssl-mode",
284+
"REQUIRED",
285+
"--ssl-ca",
286+
"/tmp/404",
287+
"--host=bob",
288+
"mydb",
289+
}, drv.mysqldumpArgs())
290+
})
291+
}
292+
184293
func TestMySQLDumpSchema(t *testing.T) {
185294
drv := testMySQLDriver(t)
186295
drv.migrationsTableName = "test_migrations"

0 commit comments

Comments
 (0)