Skip to content

Commit 90fdf91

Browse files
committed
Sanitize and blob values
1 parent 523f634 commit 90fdf91

File tree

4 files changed

+67
-5
lines changed

4 files changed

+67
-5
lines changed

dump.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ INSERT INTO {{ .Name }} VALUES {{ .Values }};
7777
UNLOCK TABLES;
7878
`
7979

80-
const footerTmpl = `
81-
/*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */;
80+
const footerTmpl = `/*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */;
8281
8382
/*!40101 SET SQL_MODE=@OLD_SQL_MODE */;
8483
/*!40014 SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS */;
@@ -282,7 +281,9 @@ func createTableValues(db *sql.DB, name string) (string, error) {
282281
types := make([]reflect.Type, len(tt))
283282
for i, tp := range tt {
284283
st := tp.ScanType()
285-
if st == nil || st.Kind() == reflect.Slice {
284+
if tp.DatabaseTypeName() == "BLOB" {
285+
types[i] = reflect.TypeOf(sql.RawBytes{})
286+
} else if st == nil || st.Kind() == reflect.Slice {
286287
types[i] = reflect.TypeOf(sql.NullString{})
287288
} else if st.Kind() == reflect.Int ||
288289
st.Kind() == reflect.Int8 ||
@@ -311,7 +312,7 @@ func createTableValues(db *sql.DB, name string) (string, error) {
311312
dataStrings[key] = "null"
312313
} else if s, ok := value.(*sql.NullString); ok {
313314
if s.Valid {
314-
dataStrings[key] = "'" + strings.Replace(s.String, "\n", "\\n", -1) + "'"
315+
dataStrings[key] = "'" + sanitize(s.String) + "'"
315316
} else {
316317
dataStrings[key] = "NULL"
317318
}
@@ -321,6 +322,12 @@ func createTableValues(db *sql.DB, name string) (string, error) {
321322
} else {
322323
dataStrings[key] = "NULL"
323324
}
325+
} else if s, ok := value.(*sql.RawBytes); ok {
326+
if len(*s) == 0 {
327+
dataStrings[key] = "NULL"
328+
} else {
329+
dataStrings[key] = "_binary '" + sanitize(string(*s)) + "'"
330+
}
324331
} else {
325332
dataStrings[key] = fmt.Sprint("'", value, "'")
326333
}

dump_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,6 @@ LOCK TABLES \Test_Table\ WRITE;
325325
INSERT INTO \Test_Table\ VALUES ('1',null,'Test Name 1'),('2','[email protected]','Test Name 2');
326326
/*!40000 ALTER TABLE \Test_Table\ ENABLE KEYS */;
327327
UNLOCK TABLES;
328-
329328
/*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */;
330329
331330
/*!40101 SET SQL_MODE=@OLD_SQL_MODE */;

sanitize.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package mysqldump
2+
3+
import "strings"
4+
5+
var lazyMySQLReplacer *strings.Replacer
6+
7+
func mysqlReplacer() *strings.Replacer {
8+
if lazyMySQLReplacer == nil {
9+
lazyMySQLReplacer = strings.NewReplacer(
10+
"\x00", "\\0",
11+
"'", "\\'",
12+
"\"", "\\\"",
13+
"\b", "\\b",
14+
"\n", "\\n",
15+
"\r", "\\r",
16+
// "\t", "\\t",
17+
"\x1A", "\\Z", // ASCII 26 == x1A
18+
"\\", "\\\\",
19+
"%", "\\%",
20+
// "_", "\\_",
21+
)
22+
}
23+
return lazyMySQLReplacer
24+
}
25+
26+
// MySQL sanitizes mysql based on
27+
// https://dev.mysql.com/doc/refman/8.0/en/string-literals.html table 9.1
28+
// needs to be placed in either a single or a double quoted string
29+
func sanitize(input string) string {
30+
return mysqlReplacer().Replace(input)
31+
}

sanitize_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package mysqldump
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
)
7+
8+
func TestForSQLInjection(t *testing.T) {
9+
examples := [][]string{
10+
/** Query ** Input ** Expected **/
11+
{"SELECT * WHERE field = '%s';", "test", "SELECT * WHERE field = 'test';"},
12+
{"'%s'", "'; DROP TABLES `test`;", "'\\'; DROP TABLES `test`;'"},
13+
{"'%s'", "'+(SELECT name FROM users LIMIT 1)+'", "'\\'+(SELECT name FROM users LIMIT 1)+\\''"},
14+
{"SELECT '%s'", "\x00x633A5C626F6F742E696E69", "SELECT '\\0x633A5C626F6F742E696E69'"},
15+
{"WHERE PASSWORD('%s')", "') OR 1=1--", "WHERE PASSWORD('\\') OR 1=1--')"},
16+
}
17+
var query string
18+
for _, example := range examples {
19+
query = fmt.Sprintf(example[0], sanitize(example[1]))
20+
21+
if example[2] != query {
22+
t.Fatalf("expected %#v, got %#v", example[2], query)
23+
}
24+
}
25+
}

0 commit comments

Comments
 (0)