diff --git a/internal/dumper/loader.go b/internal/dumper/loader.go index c11228e1..9d9bf626 100644 --- a/internal/dumper/loader.go +++ b/internal/dumper/loader.go @@ -304,32 +304,68 @@ func (l *Loader) restoreTableSchema(overwrite bool, tables []string, conn *Conne return err } + // Drop table once before processing queries (if overwrite is enabled) + if overwrite { + l.log.Info( + "drop(overwrite.is.true)", + zap.String("database", db), + zap.String("table ", tbl), + ) + + if l.cfg.ShowDetails { + l.cfg.Printer.Println("Dropping Existing Table (if it exists): " + printer.BoldBlue(name)) + } + dropQuery := fmt.Sprintf("DROP TABLE IF EXISTS %s", name) + err = conn.Execute(dropQuery) + if err != nil { + return err + } + } + + // Execute each valid SQL statement (skip comments) for _, query := range queries { - if !strings.HasPrefix(query, "/*") && query != "" { - if overwrite { - l.log.Info( - "drop(overwrite.is.true)", - zap.String("database", db), - zap.String("table ", tbl), - ) + // Skip empty queries and block comments + trimmedQuery := strings.TrimSpace(query) + if trimmedQuery == "" || strings.HasPrefix(trimmedQuery, "/*") { + continue + } - if l.cfg.ShowDetails { - l.cfg.Printer.Println("Dropping Existing Table (if it exists): " + printer.BoldBlue(name)) - } - dropQuery := fmt.Sprintf("DROP TABLE IF EXISTS %s", name) - err = conn.Execute(dropQuery) - if err != nil { - return err - } + // Filter out line comments (--) but keep the rest of the query + var cleanedLines []string + for _, line := range strings.Split(query, "\n") { + trimmedLine := strings.TrimSpace(line) + // Skip empty lines and comment-only lines + if trimmedLine != "" && !strings.HasPrefix(trimmedLine, "--") { + cleanedLines = append(cleanedLines, line) } + } - if l.cfg.ShowDetails { + // Skip if no non-comment content remains + if len(cleanedLines) == 0 { + continue + } + + // Reconstruct the query without comment lines + cleanedQuery := strings.Join(cleanedLines, "\n") + trimmedCleanedQuery := strings.TrimSpace(cleanedQuery) + + if l.cfg.ShowDetails { + // Detect query type and provide appropriate output + upperQuery := strings.ToUpper(trimmedCleanedQuery) + if strings.HasPrefix(upperQuery, "CREATE TABLE") { l.cfg.Printer.Printf("Creating Table: %s (Table %d of %d)\n", printer.BoldBlue(name), (idx + 1), numberOfTables) + } else if strings.HasPrefix(upperQuery, "ALTER TABLE") { + l.cfg.Printer.Printf("Altering Table: %s (Table %d of %d)\n", printer.BoldBlue(name), (idx + 1), numberOfTables) + l.cfg.Printer.Printf("Query: %s\n", cleanedQuery) + } else { + // For any other query type, show what's being executed + l.cfg.Printer.Printf("Executing Query for Table: %s (Table %d of %d)\n", printer.BoldBlue(name), (idx + 1), numberOfTables) + l.cfg.Printer.Printf("Query: %s\n", cleanedQuery) } - err = conn.Execute(query) - if err != nil { - return err - } + } + err = conn.Execute(cleanedQuery) + if err != nil { + return err } } l.log.Info("restoring schema", diff --git a/internal/dumper/loader_test.go b/internal/dumper/loader_test.go index e584dbad..4182434f 100644 --- a/internal/dumper/loader_test.go +++ b/internal/dumper/loader_test.go @@ -2,6 +2,7 @@ package dumper import ( "context" + "os" "testing" qt "github.com/frankban/quicktest" @@ -47,3 +48,186 @@ func TestLoader(t *testing.T) { err = loader.Run(context.Background()) c.Assert(err, qt.IsNil) } + +func TestRestoreTableSchema_WithComments(t *testing.T) { + tests := []struct { + name string + schemaContent string + setupQueries []string + description string + }{ + { + name: "schema with line comments at end", + schemaContent: `CREATE TABLE example_table ( + id INT PRIMARY KEY +); +-- This is a comment +-- This is another comment`, + setupQueries: []string{ + "CREATE TABLE example_table (\n id INT PRIMARY KEY\n)", + }, + description: "Should execute CREATE TABLE and skip trailing comments", + }, + { + name: "schema with ALTER TABLE after comments", + schemaContent: `CREATE TABLE example_table ( + id INT PRIMARY KEY, + name VARCHAR(100) +); +-- This is a comment +-- This is another comment +ALTER TABLE example_table + ADD INDEX idx_name (name);`, + setupQueries: []string{ + "CREATE TABLE example_table (\n id INT PRIMARY KEY,\n name VARCHAR(100)\n)", + }, + description: "Should execute CREATE TABLE and ALTER TABLE, skipping comments in between", + }, + { + name: "schema with block comments", + schemaContent: `/* This is a block comment */ +CREATE TABLE example_table ( + id INT PRIMARY KEY +);`, + setupQueries: []string{ + "CREATE TABLE example_table (\n id INT PRIMARY KEY\n)", + }, + description: "Should skip block comments and execute CREATE TABLE", + }, + { + name: "schema with interspersed comments", + schemaContent: `CREATE TABLE example_table ( + id INT PRIMARY KEY +); +-- Comment between statements +ALTER TABLE example_table ADD COLUMN name VARCHAR(100); +-- Another comment +ALTER TABLE example_table ADD INDEX idx_id (id);`, + setupQueries: []string{ + "CREATE TABLE example_table (\n id INT PRIMARY KEY\n)", + }, + description: "Should execute all SQL statements and skip all comment lines", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + + log := xlog.NewStdLog(xlog.Level(xlog.ERROR)) + fakedbs := driver.NewTestHandler(log) + server, err := driver.MockMysqlServer(log, fakedbs) + c.Assert(err, qt.IsNil) + defer server.Close() + + address := server.Addr() + + // Set up mock expectations + fakedbs.AddQueryPattern("use .*", &sqltypes.Result{}) + fakedbs.AddQueryPattern("set foreign_key_checks=.*", &sqltypes.Result{}) + fakedbs.AddQueryPattern("drop table .*", &sqltypes.Result{}) + fakedbs.AddQueryPattern("alter table .*", &sqltypes.Result{}) + + // Add expected queries - if these aren't executed, the test will fail + for _, query := range tt.setupQueries { + fakedbs.AddQuery(query, &sqltypes.Result{}) + } + + // Create test schema file + tempDir := c.TempDir() + schemaFile := tempDir + "/testdb.test_table-schema.sql" + err = os.WriteFile(schemaFile, []byte(tt.schemaContent), 0644) + c.Assert(err, qt.IsNil) + + // Create loader + cfg := &Config{ + Database: "testdb", + Outdir: tempDir, + User: "mock", + Password: "mock", + Threads: 1, + Address: address, + IntervalMs: 500, + OverwriteTables: true, + ShowDetails: false, + Debug: false, + } + loader, err := NewLoader(cfg) + c.Assert(err, qt.IsNil) + + // Create connection pool + pool, err := NewPool(loader.log, cfg.Threads, cfg.Address, cfg.User, cfg.Password, cfg.SessionVars, "") + c.Assert(err, qt.IsNil) + defer pool.Close() + + conn := pool.Get() + defer pool.Put(conn) + + // Execute restoreTableSchema - should not return error if all expected queries are executed + err = loader.restoreTableSchema(cfg.OverwriteTables, []string{schemaFile}, conn) + c.Assert(err, qt.IsNil, qt.Commentf("%s: failed to restore table schema", tt.description)) + }) + } +} + +func TestRestoreTableSchema_DropTableCalledOnce(t *testing.T) { + c := qt.New(t) + + log := xlog.NewStdLog(xlog.Level(xlog.ERROR)) + fakedbs := driver.NewTestHandler(log) + server, err := driver.MockMysqlServer(log, fakedbs) + c.Assert(err, qt.IsNil) + defer server.Close() + + address := server.Addr() + + // Set up mock expectations + fakedbs.AddQueryPattern("use .*", &sqltypes.Result{}) + fakedbs.AddQueryPattern("set foreign_key_checks=.*", &sqltypes.Result{}) + + // Add DROP TABLE query only once - if it's called twice, the second call will fail + // because there's no matching handler for it + fakedbs.AddQuery("DROP TABLE IF EXISTS `testdb`.`test_table`", &sqltypes.Result{}) + fakedbs.AddQuery("CREATE TABLE test_table (\n id INT PRIMARY KEY\n)", &sqltypes.Result{}) + + // Create test schema file with comments at the end (the original bug scenario) + tempDir := c.TempDir() + schemaFile := tempDir + "/testdb.test_table-schema.sql" + schemaContent := `CREATE TABLE test_table ( + id INT PRIMARY KEY +); +-- This is a comment +-- This is another comment` + err = os.WriteFile(schemaFile, []byte(schemaContent), 0644) + c.Assert(err, qt.IsNil) + + // Create loader + cfg := &Config{ + Database: "testdb", + Outdir: tempDir, + User: "mock", + Password: "mock", + Threads: 1, + Address: address, + IntervalMs: 500, + OverwriteTables: true, + ShowDetails: false, + Debug: false, + } + loader, err := NewLoader(cfg) + c.Assert(err, qt.IsNil) + + // Create connection pool + pool, err := NewPool(loader.log, cfg.Threads, cfg.Address, cfg.User, cfg.Password, cfg.SessionVars, "") + c.Assert(err, qt.IsNil) + defer pool.Close() + + conn := pool.Get() + defer pool.Put(conn) + + // Execute restoreTableSchema + // If DROP TABLE is called more than once, the test will fail because there's + // only one handler registered for it + err = loader.restoreTableSchema(cfg.OverwriteTables, []string{schemaFile}, conn) + c.Assert(err, qt.IsNil, qt.Commentf("DROP TABLE should be called exactly once. If called multiple times, this test will fail.")) +}