Skip to content

Commit 7e2b1a6

Browse files
committed
add dsn parameter to runTests
1 parent 4e994df commit 7e2b1a6

File tree

1 file changed

+18
-23
lines changed

1 file changed

+18
-23
lines changed

driver_test.go

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ type DBTest struct {
106106
db *sql.DB
107107
}
108108

109-
func runTests(t *testing.T, name string, tests ...func(dbt *DBTest)) {
109+
func runTests(t *testing.T, name, dsn string, tests ...func(dbt *DBTest)) {
110110
if !available {
111111
t.Logf("MySQL-Server not running on %s. Skipping %s", netAddr, name)
112112
return
@@ -151,7 +151,7 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
151151
}
152152

153153
func TestRawBytesResultExceedsBuffer(t *testing.T) {
154-
runTests(t, "TestRawBytesResultExceedsBuffer", func(dbt *DBTest) {
154+
runTests(t, "TestRawBytesResultExceedsBuffer", dsn, func(dbt *DBTest) {
155155
// defaultBufSize from buffer.go
156156
expected := strings.Repeat("abc", defaultBufSize)
157157
rows := dbt.mustQuery("SELECT '" + expected + "'")
@@ -168,7 +168,7 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) {
168168
}
169169

170170
func TestCRUD(t *testing.T) {
171-
runTests(t, "TestCRUD", func(dbt *DBTest) {
171+
runTests(t, "TestCRUD", dsn, func(dbt *DBTest) {
172172
// Create Table
173173
dbt.mustExec("CREATE TABLE test (value BOOL)")
174174

@@ -260,7 +260,7 @@ func TestCRUD(t *testing.T) {
260260
}
261261

262262
func TestInt(t *testing.T) {
263-
runTests(t, "TestInt", func(dbt *DBTest) {
263+
runTests(t, "TestInt", dsn, func(dbt *DBTest) {
264264
types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}
265265
in := int64(42)
266266
var out int64
@@ -307,7 +307,7 @@ func TestInt(t *testing.T) {
307307
}
308308

309309
func TestFloat(t *testing.T) {
310-
runTests(t, "TestFloat", func(dbt *DBTest) {
310+
runTests(t, "TestFloat", dsn, func(dbt *DBTest) {
311311
types := [2]string{"FLOAT", "DOUBLE"}
312312
in := float32(42.23)
313313
var out float32
@@ -330,7 +330,7 @@ func TestFloat(t *testing.T) {
330330
}
331331

332332
func TestString(t *testing.T) {
333-
runTests(t, "TestString", func(dbt *DBTest) {
333+
runTests(t, "TestString", dsn, func(dbt *DBTest) {
334334
types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"}
335335
in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย"
336336
var out string
@@ -470,18 +470,15 @@ func TestDateTime(t *testing.T) {
470470
}
471471
}
472472

473-
oldDsn := dsn
474-
usedDsn := oldDsn + "&sql_mode=ALLOW_INVALID_DATES"
473+
timeDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
475474
for _, v := range setups {
476475
s = v
477-
dsn = usedDsn + s.dsnSuffix
478-
runTests(t, "TestDateTime", testTime)
476+
runTests(t, "TestDateTime", timeDsn+s.dsnSuffix, testTime)
479477
}
480-
dsn = oldDsn
481478
}
482479

483480
func TestNULL(t *testing.T) {
484-
runTests(t, "TestNULL", func(dbt *DBTest) {
481+
runTests(t, "TestNULL", dsn, func(dbt *DBTest) {
485482
nullStmt, err := dbt.db.Prepare("SELECT NULL")
486483
if err != nil {
487484
dbt.Fatal(err)
@@ -597,7 +594,7 @@ func TestNULL(t *testing.T) {
597594
}
598595

599596
func TestLongData(t *testing.T) {
600-
runTests(t, "TestLongData", func(dbt *DBTest) {
597+
runTests(t, "TestLongData", dsn, func(dbt *DBTest) {
601598
var maxAllowedPacketSize int
602599
err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize)
603600
if err != nil {
@@ -654,7 +651,7 @@ func TestLongData(t *testing.T) {
654651
}
655652

656653
func TestLoadData(t *testing.T) {
657-
runTests(t, "TestLoadData", func(dbt *DBTest) {
654+
runTests(t, "TestLoadData", dsn, func(dbt *DBTest) {
658655
verifyLoadDataResult := func() {
659656
rows, err := dbt.db.Query("SELECT * FROM test")
660657
if err != nil {
@@ -741,10 +738,9 @@ func TestLoadData(t *testing.T) {
741738
}
742739

743740
func TestStrict(t *testing.T) {
744-
oldDsn := dsn
745-
// to get rid of stricter modes - we want to test for warnings, not errors
746-
dsn += "&sql_mode=ALLOW_INVALID_DATES"
747-
runTests(t, "TestStrict", func(dbt *DBTest) {
741+
// ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors
742+
relaxedDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
743+
runTests(t, "TestStrict", relaxedDsn, func(dbt *DBTest) {
748744
dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")
749745

750746
var queries = [...]struct {
@@ -806,13 +802,12 @@ func TestStrict(t *testing.T) {
806802
}
807803
}
808804
})
809-
dsn = oldDsn
810805
}
811806

812807
// Special cases
813808

814809
func TestRowsClose(t *testing.T) {
815-
runTests(t, "TestRowsClose", func(dbt *DBTest) {
810+
runTests(t, "TestRowsClose", dsn, func(dbt *DBTest) {
816811
rows, err := dbt.db.Query("SELECT 1")
817812
if err != nil {
818813
dbt.Fatal(err)
@@ -837,7 +832,7 @@ func TestRowsClose(t *testing.T) {
837832
// dangling statements
838833
// http://code.google.com/p/go/issues/detail?id=3865
839834
func TestCloseStmtBeforeRows(t *testing.T) {
840-
runTests(t, "TestCloseStmtBeforeRows", func(dbt *DBTest) {
835+
runTests(t, "TestCloseStmtBeforeRows", dsn, func(dbt *DBTest) {
841836
stmt, err := dbt.db.Prepare("SELECT 1")
842837
if err != nil {
843838
dbt.Fatal(err)
@@ -878,7 +873,7 @@ func TestCloseStmtBeforeRows(t *testing.T) {
878873
// It is valid to have multiple Rows for the same Stmt
879874
// http://code.google.com/p/go/issues/detail?id=3734
880875
func TestStmtMultiRows(t *testing.T) {
881-
runTests(t, "TestStmtMultiRows", func(dbt *DBTest) {
876+
runTests(t, "TestStmtMultiRows", dsn, func(dbt *DBTest) {
882877
stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0")
883878
if err != nil {
884879
dbt.Fatal(err)
@@ -993,7 +988,7 @@ func TestConcurrent(t *testing.T) {
993988
t.Log("CONCURRENT env var not set. Skipping TestConcurrent")
994989
return
995990
}
996-
runTests(t, "TestConcurrent", func(dbt *DBTest) {
991+
runTests(t, "TestConcurrent", dsn, func(dbt *DBTest) {
997992
var max int
998993
err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max)
999994
if err != nil {

0 commit comments

Comments
 (0)