diff --git a/.gitignore b/.gitignore
index e099c55..c9d6cb5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 random_data_load*
+*.swp
 vendor/
 bin/*
diff --git a/generator/generator.go b/generator/generator.go
new file mode 100644
index 0000000..100fa8d
--- /dev/null
+++ b/generator/generator.go
@@ -0,0 +1,335 @@
+package generator
+
+import (
+	"database/sql"
+	"fmt"
+	"net/url"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/Percona-Lab/mysql_random_data_load/internal/getters"
+	"github.com/Percona-Lab/mysql_random_data_load/tableparser"
+	"github.com/gosuri/uiprogress"
+	log "github.com/sirupsen/logrus"
+)
+
+type Getter interface {
+	Value() interface{}
+	Quote() string
+	String() string
+}
+type InsertValues []Getter
+type insertFunction func(*sql.DB, string, chan int, chan bool, *sync.WaitGroup)
+
+var (
+	maxValues = map[string]int64{
+		"tinyint":   0xF,
+		"smallint":  0xFF,
+		"mediumint": 0x7FFFF,
+		"int":       0x7FFFFFFF,
+		"integer":   0x7FFFFFFF,
+		"float":     0x7FFFFFFF,
+		"decimal":   0x7FFFFFFF,
+		"double":    0x7FFFFFFF,
+		"bigint":    0x7FFFFFFFFFFFFFFF,
+	}
+)
+
+func Run(db *sql.DB, table *tableparser.Table, bar *uiprogress.Bar, sem chan bool,
+	rowValues InsertValues, count, bulkSize int, insertFunc insertFunction, newLineOnEachRow bool) (int, error) {
+	if count == 0 {
+		return 0, nil
+	}
+	var wg sync.WaitGroup
+	insertQuery := GenerateInsertStmt(table)
+	rowsChan := make(chan []Getter, 1000)
+	okRowsChan := countRowsOK(count, bar)
+
+	go GenerateInsertData(count*bulkSize, rowValues, rowsChan)
+	defaultSeparator1 := ""
+	if newLineOnEachRow {
+		defaultSeparator1 = "\n"
+	}
+
+	i := 0
+	rowsCount := 0
+	sep1, sep2 := defaultSeparator1, ""
+
+	for i < count {
+		rowData := <-rowsChan
+		rowsCount++
+		insertQuery += sep1 + " ("
+		for _, field := range rowData {
+			insertQuery += sep2 + field.Quote()
+			sep2 = ", "
+		}
+		insertQuery += ")"
+		sep1 = ", "
+		if newLineOnEachRow {
+			sep1 += "\n"
+		}
+		sep2 = ""
+		if rowsCount < bulkSize {
+			continue
+		}
+
+		insertQuery += ";\n"
+		<-sem
+		wg.Add(1)
+		go insertFunc(db, insertQuery, okRowsChan, sem, &wg)
+
+		insertQuery = GenerateInsertStmt(table)
+		sep1, sep2 = defaultSeparator1, ""
+		rowsCount = 0
+		i++
+	}
+
+	wg.Wait()
+	okCount := <-okRowsChan
+	return okCount, nil
+}
+
+func RunInsert(db *sql.DB, insertQuery string, resultsChan chan int, sem chan bool, wg *sync.WaitGroup) {
+	result, err := db.Exec(insertQuery)
+	if err != nil {
+		log.Debugf("Cannot run insert: %s", err)
+		resultsChan <- 0
+		sem <- true
+		wg.Done()
+		return
+	}
+
+	rowsAffected, err := result.RowsAffected()
+	if err != nil {
+		log.Errorf("Cannot get rows affected after insert: %s", err)
+	}
+	resultsChan <- int(rowsAffected)
+	sem <- true
+	wg.Done()
+}
+
+// GenerateInsertData generates 'count' rows and sends each one into rowsChan,
+// one element per row, where each element holds one value generator per column.
+// For example, to load 6 rows using bulk inserts having 2 rows per insert, like this:
+// INSERT INTO table (f1, f2, f3) VALUES (?, ?, ?), (?, ?, ?)
+//
+// this function is called with count = 6 and puts 6 elements into rowsChan,
+// which Run then groups 2 at a time into each INSERT:
+// rowsChan <- [ v1-1, v1-2, v1-3 ]
+// rowsChan <- [ v2-1, v2-2, v2-3 ]
+// ... and so on, up to [ v6-1, v6-2, v6-3 ]
+//
+func GenerateInsertData(count int, values InsertValues, rowsChan chan []Getter) {
+	for i := 0; i < count; i++ {
+		insertRow := make([]Getter, 0, len(values))
+		for _, val := range values {
+			insertRow = append(insertRow, val)
+		}
+		rowsChan <- insertRow
+	}
+}
+
+func GenerateInsertStmt(table *tableparser.Table) string {
+	fields := getFieldNames(table.Fields)
+	query := fmt.Sprintf("INSERT IGNORE INTO %s.%s (%s) VALUES ",
+		backticks(table.Schema),
+		backticks(table.Name),
+		strings.Join(fields, ","),
+	)
+	return query
+}
+
+func getFieldNames(fields []tableparser.Field) []string {
+	var fieldNames []string
+	for _, field := range fields {
+		if !isSupportedType(field.DataType) {
+			continue
+		}
+		if !field.IsNullable && field.ColumnKey == "PRI" &&
+			strings.Contains(field.Extra, "auto_increment") {
+			continue
+		}
+		fieldNames = append(fieldNames, backticks(field.ColumnName))
+	}
+	return fieldNames
+}
+
+func backticks(val string) string {
+	if strings.HasPrefix(val, "`") && strings.HasSuffix(val, "`") {
+		return url.QueryEscape(val)
+	}
+	return "`" + url.QueryEscape(val) + "`"
+}
+
+func isSupportedType(fieldType string) bool {
+	supportedTypes := map[string]bool{
+		"tinyint":    true,
+		"smallint":   true,
+		"mediumint":  true,
+		"int":        true,
+		"integer":    true,
+		"bigint":     true,
+		"float":      true,
+		"decimal":    true,
+		"double":     true,
+		"char":       true,
+		"varchar":    true,
+		"date":       true,
+		"datetime":   true,
+		"timestamp":  true,
+		"time":       true,
+		"year":       true,
+		"tinyblob":   true,
+		"tinytext":   true,
+		"blob":       true,
+		"text":       true,
+		"mediumblob": true,
+		"mediumtext": true,
+		"longblob":   true,
+		"longtext":   true,
+		"binary":     true,
+		"varbinary":  true,
+		"enum":       true,
+		"set":        true,
+	}
+	_, ok := supportedTypes[fieldType]
+	return ok
+}
+
+// countRowsOK starts a goroutine that keeps track of how many rows were
+// actually inserted by the bulk inserts: since one or more rows could generate
+// duplicate keys, the number of inserted rows is not always equal to the
+// number of rows in the bulk insert.
+func countRowsOK(count int, bar *uiprogress.Bar) chan int {
+	var totalOk int
+	resultsChan := make(chan int, 10000)
+	go func() {
+		for i := 0; i < count; i++ {
+			okCount := <-resultsChan
+			for j := 0; j < okCount; j++ {
+				bar.Incr()
+			}
+			totalOk += okCount
+		}
+		resultsChan <- totalOk
+	}()
+	return resultsChan
+}
+
"integer", "bigint": + values = append(values, getters.NewRandomInt(field.ColumnName, maxValue, field.IsNullable)) + case "float", "decimal", "double": + values = append(values, getters.NewRandomDecimal(field.ColumnName, + field.NumericPrecision.Int64-field.NumericScale.Int64, field.IsNullable)) + case "char", "varchar": + values = append(values, getters.NewRandomString(field.ColumnName, + field.CharacterMaximumLength.Int64, field.IsNullable)) + case "date": + values = append(values, getters.NewRandomDate(field.ColumnName, field.IsNullable)) + case "datetime", "timestamp": + values = append(values, getters.NewRandomDateTime(field.ColumnName, field.IsNullable)) + case "tinyblob", "tinytext", "blob", "text", "mediumtext", "mediumblob", "longblob", "longtext": + values = append(values, getters.NewRandomString(field.ColumnName, + field.CharacterMaximumLength.Int64, field.IsNullable)) + case "time": + values = append(values, getters.NewRandomTime(field.IsNullable)) + case "year": + values = append(values, getters.NewRandomIntRange(field.ColumnName, int64(time.Now().Year()-1), + int64(time.Now().Year()), field.IsNullable)) + case "enum", "set": + values = append(values, getters.NewRandomEnum(field.SetEnumVals, field.IsNullable)) + case "binary", "varbinary": + values = append(values, getters.NewRandomString(field.ColumnName, field.CharacterMaximumLength.Int64, field.IsNullable)) + default: + log.Printf("cannot get field type: %s: %s\n", field.ColumnName, field.DataType) + } + } + + return values +} + +func getSamples(conn *sql.DB, schema, table, field string, samples int64, dataType string) ([]interface{}, error) { + var count int64 + var query string + + queryCount := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s`", schema, table) + if err := conn.QueryRow(queryCount).Scan(&count); err != nil { + return nil, fmt.Errorf("cannot get count for table %q: %s", table, err) + } + + if count < samples { + query = fmt.Sprintf("SELECT `%s` FROM `%s`.`%s`", field, schema, table) + } else { + query = fmt.Sprintf("SELECT `%s` FROM `%s`.`%s` WHERE RAND() <= .3 LIMIT %d", + field, schema, table, samples) + } + + rows, err := conn.Query(query) + if err != nil { + return nil, fmt.Errorf("cannot get samples: %s, %s", query, err) + } + defer rows.Close() + + values := []interface{}{} + + for rows.Next() { + var err error + var val interface{} + + switch dataType { + case "tinyint", "smallint", "mediumint", "int", "integer", "bigint", "year": + var v int64 + err = rows.Scan(&v) + val = v + case "char", "varchar", "blob", "text", "mediumtext", + "mediumblob", "longblob", "longtext": + var v string + err = rows.Scan(&v) + val = v + case "binary", "varbinary": + var v []rune + err = rows.Scan(&v) + val = v + case "float", "decimal", "double": + var v float64 + err = rows.Scan(&v) + val = v + case "date", "time", "datetime", "timestamp": + var v time.Time + err = rows.Scan(&v) + val = v + } + if err != nil { + return nil, fmt.Errorf("cannot scan sample: %s", err) + } + values = append(values, val) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("cannot get samples: %s", err) + } + return values, nil +} diff --git a/generator/generator_test.go b/generator/generator_test.go new file mode 100644 index 0000000..e0694b8 --- /dev/null +++ b/generator/generator_test.go @@ -0,0 +1,70 @@ +package generator + +import ( + "fmt" + "reflect" + "sync" + "testing" + "time" + + "github.com/Percona-Lab/mysql_random_data_load/internal/getters" + "github.com/Percona-Lab/mysql_random_data_load/tableparser" + tu 
"github.com/Percona-Lab/mysql_random_data_load/testutils" +) + +func TestGenerateInsertData(t *testing.T) { + wantRows := 3 + + values := []Getter{ + getters.NewRandomInt("f1", 100, false), + getters.NewRandomString("f2", 10, false), + getters.NewRandomDate("f3", false), + } + + rowsChan := make(chan []Getter, 100) + count := 0 + wg := &sync.WaitGroup{} + wg.Add(1) + + go func() { + for { + select { + case <-time.After(10 * time.Millisecond): + wg.Done() + return + case row := <-rowsChan: + if reflect.TypeOf(row[0]).String() != "*getters.RandomInt" { + fmt.Printf("Expected '*getters.RandomInt' for field [0], got %q\n", reflect.TypeOf(row[0]).String()) + t.Fail() + } + if reflect.TypeOf(row[1]).String() != "*getters.RandomString" { + fmt.Printf("Expected '*getters.RandomString' for field [1], got %q\n", reflect.TypeOf(row[1]).String()) + t.Fail() + } + if reflect.TypeOf(row[2]).String() != "*getters.RandomDate" { + fmt.Printf("Expected '*getters.RandomDate' for field [2], got %q\n", reflect.TypeOf(row[2]).String()) + t.Fail() + } + count++ + } + } + }() + + GenerateInsertData(wantRows, values, rowsChan) + + wg.Wait() + tu.Assert(t, count == 3, "Invalid number of rows") +} + +func TestGenerateInsertStmt(t *testing.T) { + var table *tableparser.Table + tu.LoadJson(t, "sakila.film.json", &table) + want := "INSERT IGNORE INTO `sakila`.`film` " + + "(`title`,`description`,`release_year`,`language_id`," + + "`original_language_id`,`rental_duration`,`rental_rate`," + + "`length`,`replacement_cost`,`rating`,`special_features`," + + "`last_update`) VALUES " + + query := GenerateInsertStmt(table) + tu.Equals(t, want, query) +} diff --git a/testdata/sakila.film.json b/generator/testdata/sakila.film.json similarity index 100% rename from testdata/sakila.film.json rename to generator/testdata/sakila.film.json diff --git a/main.go b/main.go index 9ff7f83..a5f63c0 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package main import ( "database/sql" "fmt" - "net/url" "os" "os/user" "runtime" @@ -11,7 +10,7 @@ import ( "sync" "time" - "github.com/Percona-Lab/mysql_random_data_load/internal/getters" + "github.com/Percona-Lab/mysql_random_data_load/generator" "github.com/Percona-Lab/mysql_random_data_load/tableparser" "github.com/go-ini/ini" "github.com/go-sql-driver/mysql" @@ -58,17 +57,6 @@ var ( opts *cliOptions validFunctions = []string{"int", "string", "date", "date_in_range"} - maxValues = map[string]int64{ - "tinyint": 0XF, - "smallint": 0xFF, - "mediumint": 0x7FFFF, - "int": 0x7FFFFFFF, - "integer": 0x7FFFFFFF, - "float": 0x7FFFFFFF, - "decimal": 0x7FFFFFFF, - "double": 0x7FFFFFFF, - "bigint": 0x7FFFFFFFFFFFFFFF, - } Version = "0.0.0." 
diff --git a/main.go b/main.go
index 9ff7f83..a5f63c0 100644
--- a/main.go
+++ b/main.go
@@ -3,7 +3,6 @@ package main
 import (
 	"database/sql"
 	"fmt"
-	"net/url"
 	"os"
 	"os/user"
 	"runtime"
@@ -11,7 +10,7 @@ import (
 	"sync"
 	"time"
 
-	"github.com/Percona-Lab/mysql_random_data_load/internal/getters"
+	"github.com/Percona-Lab/mysql_random_data_load/generator"
 	"github.com/Percona-Lab/mysql_random_data_load/tableparser"
 	"github.com/go-ini/ini"
 	"github.com/go-sql-driver/mysql"
@@ -58,17 +57,6 @@ var (
 	opts *cliOptions
 
 	validFunctions = []string{"int", "string", "date", "date_in_range"}
-	maxValues      = map[string]int64{
-		"tinyint":   0XF,
-		"smallint":  0xFF,
-		"mediumint": 0x7FFFF,
-		"int":       0x7FFFFFFF,
-		"integer":   0x7FFFFFFF,
-		"float":     0x7FFFFFFF,
-		"decimal":   0x7FFFFFFF,
-		"double":    0x7FFFFFFF,
-		"bigint":    0x7FFFFFFFFFFFFFFF,
-	}
 	Version = "0.0.0."
 	Commit  = ""
@@ -77,14 +65,6 @@ var (
 	GoVersion = "1.9.2"
 )
 
-type getter interface {
-	Value() interface{}
-	Quote() string
-	String() string
-}
-type insertValues []getter
-type insertFunction func(*sql.DB, string, chan int, chan bool, *sync.WaitGroup)
-
 const (
 	defaultMySQLConfigSection = "client"
 	defaultConfigFile         = "~/.my.cnf"
@@ -187,10 +167,10 @@ func main() {
 	count := *opts.Rows / *opts.BulkSize
 	remainder := *opts.Rows - count**opts.BulkSize
 	semaphores := makeSemaphores(*opts.MaxThreads)
-	rowValues := makeValueFuncs(db, table.Fields)
+	rowValues := generator.MakeValueFuncs(db, table.Fields)
 	log.Debugf("Must run %d bulk inserts having %d rows each", count, *opts.BulkSize)
 
-	runInsertFunc := runInsert
+	runInsertFunc := generator.RunInsert
 	if *opts.Print {
 		*opts.MaxThreads = 1
 		*opts.NoProgress = true
@@ -208,14 +188,14 @@ func main() {
 		uiprogress.Start()
 	}
 
-	okCount, err := run(db, table, bar, semaphores, rowValues, count, *opts.BulkSize, runInsertFunc, newLineOnEachRow)
+	okCount, err := generator.Run(db, table, bar, semaphores, rowValues, count, *opts.BulkSize, runInsertFunc, newLineOnEachRow)
 	if err != nil {
 		log.Errorln(err)
 	}
 	var okrCount, okiCount int // remainder & individual inserts OK count
 	if remainder > 0 {
 		log.Debugf("Must run 1 extra bulk insert having %d rows, to complete %d rows", remainder, *opts.Rows)
-		okrCount, err = run(db, table, bar, semaphores, rowValues, 1, remainder, runInsertFunc, newLineOnEachRow)
+		okrCount, err = generator.Run(db, table, bar, semaphores, rowValues, 1, remainder, runInsertFunc, newLineOnEachRow)
 		if err != nil {
 			log.Errorln(err)
 		}
@@ -229,7 +209,7 @@ func main() {
 		log.Debugf("Running extra %d individual inserts (duplicated keys?)", *opts.Rows-totalOkCount)
 	}
 	for totalOkCount < *opts.Rows && retries < *opts.MaxRetries {
-		okiCount, err = run(db, table, bar, semaphores, rowValues, *opts.Rows-totalOkCount, 1, runInsertFunc, newLineOnEachRow)
+		okiCount, err = generator.Run(db, table, bar, semaphores, rowValues, *opts.Rows-totalOkCount, 1, runInsertFunc, newLineOnEachRow)
 		if err != nil {
 			log.Errorf("Cannot run extra insert: %s", err)
 		}
@@ -245,60 +225,6 @@ func main() {
 	db.Close()
 }
 
-func run(db *sql.DB, table *tableparser.Table, bar *uiprogress.Bar, sem chan bool,
-	rowValues insertValues, count, bulkSize int, insertFunc insertFunction, newLineOnEachRow bool) (int, error) {
-	if count == 0 {
-		return 0, nil
-	}
-	var wg sync.WaitGroup
-	insertQuery := generateInsertStmt(table)
-	rowsChan := make(chan []getter, 1000)
-	okRowsChan := countRowsOK(count, bar)
-
-	go generateInsertData(count*bulkSize, rowValues, rowsChan)
-	defaultSeparator1 := ""
-	if newLineOnEachRow {
-		defaultSeparator1 = "\n"
-	}
-
-	i := 0
-	rowsCount := 0
-	sep1, sep2 := defaultSeparator1, ""
-
-	for i < count {
-		rowData := <-rowsChan
-		rowsCount++
-		insertQuery += sep1 + " ("
-		for _, field := range rowData {
-			insertQuery += sep2 + field.Quote()
-			sep2 = ", "
-		}
-		insertQuery += ")"
-		sep1 = ", "
-		if newLineOnEachRow {
-			sep1 += "\n"
-		}
-		sep2 = ""
-		if rowsCount < bulkSize {
-			continue
-		}
-
-		insertQuery += ";\n"
-		<-sem
-		wg.Add(1)
-		go insertFunc(db, insertQuery, okRowsChan, sem, &wg)
-
-		insertQuery = generateInsertStmt(table)
-		sep1, sep2 = defaultSeparator1, ""
-		rowsCount = 0
-		i++
-	}
-
-	wg.Wait()
-	okCount := <-okRowsChan
-	return okCount, nil
-}
-
 func makeSemaphores(count int) chan bool {
 	sem := make(chan bool, count)
 	for i := 0; i < count; i++ {
@@ -307,251 +233,6 @@ func makeSemaphores(count int) chan bool {
 	return sem
 }
 
-// This go-routine keeps track of how many rows were actually inserted
-// by the bulk inserts since one or more rows could generate duplicated
-// keys so, not allways the number of inserted rows = number of rows in
-// the bulk insert
-
-func countRowsOK(count int, bar *uiprogress.Bar) chan int {
-	var totalOk int
-	resultsChan := make(chan int, 10000)
-	go func() {
-		for i := 0; i < count; i++ {
-			okCount := <-resultsChan
-			for j := 0; j < okCount; j++ {
-				bar.Incr()
-			}
-			totalOk += okCount
-		}
-		resultsChan <- totalOk
-	}()
-	return resultsChan
-}
-
-// generateInsertData will generate 'rows' items, where each item in the channel has 'bulkSize' rows.
-// For example:
-// We need to load 6 rows using a bulk insert having 2 rows per insert, like this:
-// INSERT INTO table (f1, f2, f3) VALUES (?, ?, ?), (?, ?, ?)
-//
-// This function will put into rowsChan 3 elements, each one having the values for 2 rows:
-// rowsChan <- [ v1-1, v1-2, v1-3, v2-1, v2-2, v2-3 ]
-// rowsChan <- [ v3-1, v3-2, v3-3, v4-1, v4-2, v4-3 ]
-// rowsChan <- [ v1-5, v5-2, v5-3, v6-1, v6-2, v6-3 ]
-//
-func generateInsertData(count int, values insertValues, rowsChan chan []getter) {
-	for i := 0; i < count; i++ {
-		insertRow := make([]getter, 0, len(values))
-		for _, val := range values {
-			insertRow = append(insertRow, val)
-		}
-		rowsChan <- insertRow
-	}
-}
-
-func generateInsertStmt(table *tableparser.Table) string {
-	fields := getFieldNames(table.Fields)
-	query := fmt.Sprintf("INSERT IGNORE INTO %s.%s (%s) VALUES ",
-		backticks(table.Schema),
-		backticks(table.Name),
-		strings.Join(fields, ","),
-	)
-	return query
-}
-
-func runInsert(db *sql.DB, insertQuery string, resultsChan chan int, sem chan bool, wg *sync.WaitGroup) {
-	result, err := db.Exec(insertQuery)
-	if err != nil {
-		log.Debugf("Cannot run insert: %s", err)
-		resultsChan <- 0
-		sem <- true
-		wg.Done()
-		return
-	}
-
-	rowsAffected, err := result.RowsAffected()
-	if err != nil {
-		log.Errorf("Cannot get rows affected after insert: %s", err)
-	}
-	resultsChan <- int(rowsAffected)
-	sem <- true
-	wg.Done()
-}
-
-// makeValueFuncs returns an array of functions to generate all the values needed for a single row
-func makeValueFuncs(conn *sql.DB, fields []tableparser.Field) insertValues {
-	var values []getter
-	for _, field := range fields {
-		if !field.IsNullable && field.ColumnKey == "PRI" && strings.Contains(field.Extra, "auto_increment") {
-			continue
-		}
-		if field.Constraint != nil {
-			samples, err := getSamples(conn, field.Constraint.ReferencedTableSchema,
-				field.Constraint.ReferencedTableName,
-				field.Constraint.ReferencedColumnName,
-				100, field.DataType)
-			if err != nil {
-				log.Printf("cannot get samples for field %q: %s\n", field.ColumnName, err)
-				continue
-			}
-			values = append(values, getters.NewRandomSample(field.ColumnName, samples, field.IsNullable))
-			continue
-		}
-		maxValue := maxValues["bigint"]
-		if m, ok := maxValues[field.DataType]; ok {
-			maxValue = m
-		}
-		switch field.DataType {
-		case "tinyint", "smallint", "mediumint", "int", "integer", "bigint":
-			values = append(values, getters.NewRandomInt(field.ColumnName, maxValue, field.IsNullable))
-		case "float", "decimal", "double":
-			values = append(values, getters.NewRandomDecimal(field.ColumnName,
-				field.NumericPrecision.Int64-field.NumericScale.Int64, field.IsNullable))
-		case "char", "varchar":
-			values = append(values, getters.NewRandomString(field.ColumnName,
-				field.CharacterMaximumLength.Int64, field.IsNullable))
-		case "date":
-			values = append(values, getters.NewRandomDate(field.ColumnName, field.IsNullable))
-		case "datetime", "timestamp":
-			values = append(values, getters.NewRandomDateTime(field.ColumnName, field.IsNullable))
-		case "tinyblob", "tinytext", "blob", "text", "mediumtext", "mediumblob", "longblob", "longtext":
-			values = append(values, getters.NewRandomString(field.ColumnName,
-				field.CharacterMaximumLength.Int64, field.IsNullable))
-		case "time":
-			values = append(values, getters.NewRandomTime(field.IsNullable))
-		case "year":
-			values = append(values, getters.NewRandomIntRange(field.ColumnName, int64(time.Now().Year()-1),
-				int64(time.Now().Year()), field.IsNullable))
-		case "enum", "set":
-			values = append(values, getters.NewRandomEnum(field.SetEnumVals, field.IsNullable))
-		case "binary", "varbinary":
-			values = append(values, getters.NewRandomBinary(field.ColumnName, field.CharacterMaximumLength.Int64, field.IsNullable))
-		default:
-			log.Printf("cannot get field type: %s: %s\n", field.ColumnName, field.DataType)
-		}
-	}
-
-	return values
-}
-
-func getFieldNames(fields []tableparser.Field) []string {
-	var fieldNames []string
-	for _, field := range fields {
-		if !isSupportedType(field.DataType) {
-			continue
-		}
-		if !field.IsNullable && field.ColumnKey == "PRI" &&
-			strings.Contains(field.Extra, "auto_increment") {
-			continue
-		}
-		fieldNames = append(fieldNames, backticks(field.ColumnName))
-	}
-	return fieldNames
-}
-
-func getSamples(conn *sql.DB, schema, table, field string, samples int64, dataType string) ([]interface{}, error) {
-	var count int64
-	var query string
-
-	queryCount := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s`", schema, table)
-	if err := conn.QueryRow(queryCount).Scan(&count); err != nil {
-		return nil, fmt.Errorf("cannot get count for table %q: %s", table, err)
-	}
-
-	if count < samples {
-		query = fmt.Sprintf("SELECT `%s` FROM `%s`.`%s`", field, schema, table)
-	} else {
-		query = fmt.Sprintf("SELECT `%s` FROM `%s`.`%s` WHERE RAND() <= .3 LIMIT %d",
-			field, schema, table, samples)
-	}
-
-	rows, err := conn.Query(query)
-	if err != nil {
-		return nil, fmt.Errorf("cannot get samples: %s, %s", query, err)
-	}
-	defer rows.Close()
-
-	values := []interface{}{}
-
-	for rows.Next() {
-		var err error
-		var val interface{}
-
-		switch dataType {
-		case "tinyint", "smallint", "mediumint", "int", "integer", "bigint", "year":
-			var v int64
-			err = rows.Scan(&v)
-			val = v
-		case "char", "varchar", "blob", "text", "mediumtext",
-			"mediumblob", "longblob", "longtext":
-			var v string
-			err = rows.Scan(&v)
-			val = v
-		case "binary", "varbinary":
-			var v []rune
-			err = rows.Scan(&v)
-			val = v
-		case "float", "decimal", "double":
-			var v float64
-			err = rows.Scan(&v)
-			val = v
-		case "date", "time", "datetime", "timestamp":
-			var v time.Time
-			err = rows.Scan(&v)
-			val = v
-		}
-		if err != nil {
-			return nil, fmt.Errorf("cannot scan sample: %s", err)
-		}
-		values = append(values, val)
-	}
-	if err := rows.Err(); err != nil {
-		return nil, fmt.Errorf("cannot get samples: %s", err)
-	}
-	return values, nil
-}
-
-func backticks(val string) string {
-	if strings.HasPrefix(val, "`") && strings.HasSuffix(val, "`") {
-		return url.QueryEscape(val)
-	}
-	return "`" + url.QueryEscape(val) + "`"
-}
-
"year": true, - "tinyblob": true, - "tinytext": true, - "blob": true, - "text": true, - "mediumblob": true, - "mediumtext": true, - "longblob": true, - "longtext": true, - "binary": true, - "varbinary": true, - "enum": true, - "set": true, - } - _, ok := supportedTypes[fieldType] - return ok -} - func processCliParams() (*cliOptions, error) { app := kingpin.New("mysql_random_data_loader", "MySQL Random Data Loader") diff --git a/main_test.go b/main_test.go index 7f6c601..55cbc34 100644 --- a/main_test.go +++ b/main_test.go @@ -1,14 +1,8 @@ package main import ( - "fmt" - "reflect" - "sync" "testing" - "time" - "github.com/Percona-Lab/mysql_random_data_load/internal/getters" - "github.com/Percona-Lab/mysql_random_data_load/tableparser" tu "github.com/Percona-Lab/mysql_random_data_load/testutils" ) @@ -22,60 +16,3 @@ func TestGetSamples(t *testing.T) { tu.Assert(t, int64(len(samples)) == wantRows, "Wrong number of samples. Have %d, want 100.", len(samples)) } - -func TestGenerateInsertData(t *testing.T) { - wantRows := 3 - - values := []getter{ - getters.NewRandomInt("f1", 100, false), - getters.NewRandomString("f2", 10, false), - getters.NewRandomDate("f3", false), - } - - rowsChan := make(chan []getter, 100) - count := 0 - wg := &sync.WaitGroup{} - wg.Add(1) - - go func() { - for { - select { - case <-time.After(10 * time.Millisecond): - wg.Done() - return - case row := <-rowsChan: - if reflect.TypeOf(row[0]).String() != "*getters.RandomInt" { - fmt.Printf("Expected '*getters.RandomInt' for field [0], got %q\n", reflect.TypeOf(row[0]).String()) - t.Fail() - } - if reflect.TypeOf(row[1]).String() != "*getters.RandomString" { - fmt.Printf("Expected '*getters.RandomString' for field [1], got %q\n", reflect.TypeOf(row[1]).String()) - t.Fail() - } - if reflect.TypeOf(row[2]).String() != "*getters.RandomDate" { - fmt.Printf("Expected '*getters.RandomDate' for field [2], got %q\n", reflect.TypeOf(row[2]).String()) - t.Fail() - } - count++ - } - } - }() - - generateInsertData(wantRows, values, rowsChan) - - wg.Wait() - tu.Assert(t, count == 3, "Invalid number of rows") -} - -func TestGenerateInsertStmt(t *testing.T) { - var table *tableparser.Table - tu.LoadJson(t, "sakila.film.json", &table) - want := "INSERT IGNORE INTO `sakila`.`film` " + - "(`title`,`description`,`release_year`,`language_id`," + - "`original_language_id`,`rental_duration`,`rental_rate`," + - "`length`,`replacement_cost`,`rating`,`special_features`," + - "`last_update`) VALUES " - - query := generateInsertStmt(table) - tu.Equals(t, want, query) -}