Skip to content
This repository was archived by the owner on Dec 22, 2025. It is now read-only.

Commit b4a5096

Browse files
authored
Support dumping to AWS RDS Postgres servers (#125)
* remove and re-add foreign keys when dumping to AWS RDS postgres databases * add --no-comments to pg_dump schema call
1 parent d726dda commit b4a5096

File tree

5 files changed

+68
-12
lines changed

5 files changed

+68
-12
lines changed

cmd/steal.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type (
2929

3030
from string
3131
to string
32+
toRDS bool
3233
concurrency int
3334
readOpts connOpts
3435
writeOpts connOpts
@@ -65,6 +66,7 @@ func NewStealCmd() *cobra.Command {
6566
persistentFlags.StringVarP(&opts.configPath, "config", "c", config.DefaultConfigFileName, "Path to config file")
6667
persistentFlags.StringVarP(&opts.from, "from", "f", "mysql://root:root@tcp(localhost:3306)/klepto", "Database dsn to steal from")
6768
persistentFlags.StringVarP(&opts.to, "to", "t", "os://stdout/", "Database to output to (default writes to stdOut)")
69+
persistentFlags.BoolVar(&opts.toRDS, "to-rds", false, "If the output server is an AWS RDS server")
6870
persistentFlags.IntVar(&opts.concurrency, "concurrency", runtime.NumCPU(), "Sets the amount of dumps to be performed concurrently")
6971
persistentFlags.DurationVar(&opts.readOpts.timeout, "read-timeout", 5*time.Minute, "Sets the timeout for read operations")
7072
persistentFlags.DurationVar(&opts.readOpts.maxConnLifetime, "read-conn-lifetime", 0, "Sets the maximum amount of time a connection may be reused on the read database")
@@ -99,6 +101,7 @@ func RunSteal(opts *StealOptions) (err error) {
99101
source = anonymiser.NewAnonymiser(source, opts.cfgTables)
100102
target, err := dumper.NewDumper(dumper.ConnOpts{
101103
DSN: opts.to,
104+
IsRDS: opts.toRDS,
102105
Timeout: opts.writeOpts.timeout,
103106
MaxConnLifetime: opts.writeOpts.maxConnLifetime,
104107
MaxConns: opts.writeOpts.maxConns,

pkg/dumper/dumper.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ type (
3232
ConnOpts struct {
3333
// DSN is the connection address.
3434
DSN string
35+
// IsRDS identifies if the server is an AWS RDS server
36+
IsRDS bool
3537
// Timeout is the timeout for dump operations.
3638
Timeout time.Duration
3739
// MaxConnLifetime is the maximum amount of time a connection may be reused on the read database.

pkg/dumper/postgres/dumper.go

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,26 @@ import (
1616
)
1717

1818
type (
19+
foreignKeyInfo struct {
20+
tableName string
21+
constraintName string
22+
constraintDefinition string
23+
}
24+
1925
pgDumper struct {
20-
conn *sql.DB
21-
reader reader.Reader
26+
conn *sql.DB
27+
reader reader.Reader
28+
isRDS bool
29+
foreignKeys []foreignKeyInfo
2230
}
2331
)
2432

2533
// NewDumper returns a new postgres dumper.
26-
func NewDumper(conn *sql.DB, rdr reader.Reader) dumper.Dumper {
34+
func NewDumper(opts dumper.ConnOpts, conn *sql.DB, rdr reader.Reader) dumper.Dumper {
2735
return engine.New(rdr, &pgDumper{
2836
conn: conn,
2937
reader: rdr,
38+
isRDS: opts.IsRDS,
3039
})
3140
}
3241

@@ -72,26 +81,67 @@ func (d *pgDumper) DumpTable(tableName string, rowChan <-chan database.Row) erro
7281
// PreDumpTables Disable triggers on all tables to avoid foreign key constraints
7382
func (d *pgDumper) PreDumpTables(tables []string) error {
7483
// We can't use `SET session_replication_role = replica` because multiple connections and stuff
75-
for _, tbl := range tables {
76-
query := fmt.Sprintf("ALTER TABLE %s DISABLE TRIGGER ALL", strconv.Quote(tbl))
77-
if _, err := d.conn.Exec(query); err != nil {
78-
return errors.Wrapf(err, "Failed to disable triggers for %s", tbl)
84+
// For RDS databases, the superuser does not have the required permission to call
85+
// DISABLE TRIGGER ALL, so manually remove and re-add all Foreign Keys
86+
if !d.isRDS {
87+
log.Debug("Disabling triggers")
88+
for _, tbl := range tables {
89+
query := fmt.Sprintf("ALTER TABLE %s DISABLE TRIGGER ALL", strconv.Quote(tbl))
90+
if _, err := d.conn.Exec(query); err != nil {
91+
return errors.Wrapf(err, "Failed to disable triggers for %s", tbl)
92+
}
7993
}
94+
return nil
8095
}
8196

97+
log.Debug("Removing foreign keys")
98+
query := `SELECT conrelid::regclass::varchar tableName,
99+
conname constraintName,
100+
pg_catalog.pg_get_constraintdef(r.oid, true) constraintDefinition
101+
FROM pg_catalog.pg_constraint r
102+
WHERE r.contype = 'f'
103+
AND r.connamespace = (SELECT n.oid FROM pg_namespace n WHERE n.nspname = current_schema())
104+
`
105+
rows, err := d.conn.Query(query)
106+
if err != nil {
107+
return errors.Wrapf(err, "Failed to query ForeignKeys")
108+
}
109+
defer rows.Close()
110+
for rows.Next() {
111+
var fk foreignKeyInfo
112+
if err := rows.Scan(&fk.tableName, &fk.constraintName, &fk.constraintDefinition); err != nil {
113+
return errors.Wrapf(err, "Failed to load ForeignKeyInfo")
114+
}
115+
query := fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", strconv.Quote(fk.tableName), strconv.Quote(fk.constraintName))
116+
if _, err := d.conn.Exec(query); err != nil {
117+
return errors.Wrapf(err, "Failed to frop contraint %s.%s", fk.tableName, fk.constraintName)
118+
}
119+
d.foreignKeys = append(d.foreignKeys, fk)
120+
}
82121
return nil
83122
}
84123

85124
// PostDumpTables enable triggers on all tables to enforce foreign key constraints
86125
func (d *pgDumper) PostDumpTables(tables []string) error {
87126
// We can't use `SET session_replication_role = DEFAULT` because multiple connections and stuff
88-
for _, tbl := range tables {
89-
query := fmt.Sprintf("ALTER TABLE %s ENABLE TRIGGER ALL", strconv.Quote(tbl))
90-
if _, err := d.conn.Exec(query); err != nil {
91-
return errors.Wrapf(err, "Failed to anble triggers for %s", tbl)
127+
if !d.isRDS {
128+
log.Debug("Reenabling triggers")
129+
for _, tbl := range tables {
130+
query := fmt.Sprintf("ALTER TABLE %s ENABLE TRIGGER ALL", strconv.Quote(tbl))
131+
if _, err := d.conn.Exec(query); err != nil {
132+
return errors.Wrapf(err, "Failed to enable triggers for %s", tbl)
133+
}
92134
}
135+
return nil
93136
}
94137

138+
log.Debug("Recreating foreign keys")
139+
for _, fk := range d.foreignKeys {
140+
query := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s", strconv.Quote(fk.tableName), strconv.Quote(fk.constraintName), fk.constraintDefinition)
141+
if _, err := d.conn.Exec(query); err != nil {
142+
return errors.Wrapf(err, "Failed to re-create ForeignKey %s.%s", fk.tableName, fk.constraintName)
143+
}
144+
}
95145
return nil
96146
}
97147

pkg/dumper/postgres/postgres.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func (m *driver) NewConnection(opts dumper.ConnOpts, rdr reader.Reader) (dumper.
2424
conn.SetMaxIdleConns(opts.MaxIdleConns)
2525
conn.SetConnMaxLifetime(opts.MaxConnLifetime)
2626

27-
return NewDumper(conn, rdr), nil
27+
return NewDumper(opts, conn, rdr), nil
2828
}
2929

3030
func init() {

pkg/reader/postgres/pg_dump.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ func (p *PgDump) GetStructure() (string, error) {
3838
"--schema-only",
3939
"--no-privileges",
4040
"--no-owner",
41+
"--no-comments",
4142
)
4243

4344
logger.Debug("loading schema for table")

0 commit comments

Comments
 (0)