diff --git a/cmd/pg_unregister/main.go b/cmd/pg_unregister/main.go
index 93b31649..c1ac9517 100644
--- a/cmd/pg_unregister/main.go
+++ b/cmd/pg_unregister/main.go
@@ -4,11 +4,15 @@ import (
 	"context"
 	"encoding/base64"
+	"errors"
 	"fmt"
+	"log"
 	"os"
+	"time"
 
 	"github.com/fly-apps/postgres-flex/internal/flypg"
 	"github.com/fly-apps/postgres-flex/internal/flypg/admin"
 	"github.com/fly-apps/postgres-flex/internal/utils"
+	"github.com/jackc/pgx/v5"
 )
 
 func main() {
@@ -49,20 +53,48 @@ func processUnregistration(ctx context.Context) error {
 		return fmt.Errorf("failed to unregister member: %v", err)
 	}
 
-	slots, err := admin.ListReplicationSlots(ctx, conn)
-	if err != nil {
-		return fmt.Errorf("failed to list replication slots: %v", err)
+	slotName := fmt.Sprintf("repmgr_slot_%d", member.ID)
+	if err := removeReplicationSlot(ctx, conn, slotName); err != nil {
+		return err
 	}
 
-	targetSlot := fmt.Sprintf("repmgr_slot_%d", member.ID)
-	for _, slot := range slots {
-		if slot.Name == targetSlot {
-			if err := admin.DropReplicationSlot(ctx, conn, targetSlot); err != nil {
-				return fmt.Errorf("failed to drop replication slot: %v", err)
+	return nil
+}
+
+// removeReplicationSlot looks up the replication slot named slotName once per
+// second and drops it as soon as it is reported inactive. It returns nil when
+// the slot was dropped or no longer exists, ctx.Err() when ctx is done, and a
+// timeout error once 10 seconds have elapsed without the slot being dropped.
+func removeReplicationSlot(ctx context.Context, conn *pgx.Conn, slotName string) error {
+	ticker := time.NewTicker(1 * time.Second)
+	timeout := time.After(10 * time.Second)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-timeout:
+			return fmt.Errorf("timed out trying to drop replication slot %s", slotName)
+		case <-ticker.C:
+			slot, err := admin.GetReplicationSlot(ctx, conn, slotName)
+			if err != nil {
+				// No matching row means the slot is already gone; nothing to drop.
+				if errors.Is(err, pgx.ErrNoRows) {
+					return nil
+				}
+				return fmt.Errorf("failed to get replication slot %s: %w", slotName, err)
+			}
+
+			if slot.Active {
+				log.Printf("Slot %s is still active, waiting...", slotName)
+				continue
 			}
-			break
+
+			if err := admin.DropReplicationSlot(ctx, conn, slotName); err != nil {
+				return fmt.Errorf("failed to drop replication slot %s: %w", slotName, err)
+			}
+
+			return nil
 		}
 	}
-
-	return nil
 }
diff --git a/internal/flypg/admin/admin.go b/internal/flypg/admin/admin.go
index 2f1925b1..985dd483 100644
--- a/internal/flypg/admin/admin.go
+++ b/internal/flypg/admin/admin.go
@@ -130,6 +130,22 @@ type ReplicationSlot struct {
 	RetainedWalInBytes int
 }
 
+// GetReplicationSlot queries pg_replication_slots for the slot named slotName
+// and returns it. Any error from scanning the row is returned unmodified, so
+// callers can test it directly (for example against pgx.ErrNoRows).
+func GetReplicationSlot(ctx context.Context, pg *pgx.Conn, slotName string) (*ReplicationSlot, error) {
+	// Bind the slot name as $1 so it is never interpolated into the SQL text.
+	sql := "SELECT slot_name, active, wal_status, pg_wal_lsn_diff(pg_current_wal_lsn(), restart_lsn) AS retained_wal FROM pg_replication_slots where slot_name = $1;"
+	row := pg.QueryRow(ctx, sql, slotName)
+
+	var slot ReplicationSlot
+	if err := row.Scan(&slot.Name, &slot.Active, &slot.WalStatus, &slot.RetainedWalInBytes); err != nil {
+		return nil, err
+	}
+
+	return &slot, nil
+}
+
 func ListReplicationSlots(ctx context.Context, pg *pgx.Conn) ([]ReplicationSlot, error) {
 	sql := "SELECT slot_name, active, wal_status, pg_wal_lsn_diff(pg_current_wal_lsn(), restart_lsn) AS retained_wal FROM pg_replication_slots;"
 	rows, err := pg.Query(ctx, sql)
@@ -167,7 +183,6 @@ func ListReplicationSlots(ctx context.Context, pg *pgx.Conn) ([]ReplicationSlot,
 
 func DropReplicationSlot(ctx context.Context, pg *pgx.Conn, name string) error {
 	sql := fmt.Sprintf("SELECT pg_drop_replication_slot('%s');", name)
-
 	_, err := pg.Exec(ctx, sql)
 	if err != nil {
 		return err