Skip to content

Commit 0aabd5f

Browse files
committed
On PostgreSQL, do rollback on failure
1 parent 1057116 commit 0aabd5f

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

main.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func txRollback(tx **sql.Tx, w io.Writer) error {
8686
*tx = nil
8787
}
8888
if err == nil {
89-
fmt.Fprintln(os.Stderr, "Rollback complete.")
89+
fmt.Fprintln(w, "Rollback complete.")
9090
}
9191
return err
9292
}
@@ -118,7 +118,11 @@ func echo(spool *os.File, query string) {
118118
}
119119
}
120120

121-
func loop(ctx context.Context, conn *sql.DB) error {
121+
type Options struct {
122+
RollbackOnFail bool
123+
}
124+
125+
func loop(ctx context.Context, options *Options, conn *sql.DB) error {
122126
disabler := colorable.EnableColorsStdout(nil)
123127
defer disabler()
124128

@@ -182,6 +186,11 @@ func loop(ctx context.Context, conn *sql.DB) error {
182186
err = txBegin(ctx, conn, &tx, tee(os.Stderr, spool))
183187
if err == nil {
184188
err = doDML(ctx, tx, query, tee(os.Stdout, spool))
189+
if err != nil && options.RollbackOnFail {
190+
fmt.Fprintln(tee(os.Stderr, spool), err.Error())
191+
echo(spool, "rollback (automatically)")
192+
err = txRollback(&tx, tee(os.Stderr, spool))
193+
}
185194
}
186195
case "COMMIT":
187196
echo(spool, query)
@@ -226,7 +235,12 @@ func mains(args []string) error {
226235
}
227236
defer conn.Close()
228237

229-
return loop(context.Background(), conn)
238+
var options Options
239+
switch strings.ToUpper(args[0]) {
240+
case "POSTGRES":
241+
options.RollbackOnFail = true
242+
}
243+
return loop(context.Background(), &options, conn)
230244
}
231245

232246
func main() {

0 commit comments

Comments
 (0)