@@ -86,7 +86,7 @@ func txRollback(tx **sql.Tx, w io.Writer) error {
8686 * tx = nil
8787 }
8888 if err == nil {
89- fmt .Fprintln (os . Stderr , "Rollback complete." )
89+ fmt .Fprintln (w , "Rollback complete." )
9090 }
9191 return err
9292}
@@ -118,7 +118,11 @@ func echo(spool *os.File, query string) {
118118 }
119119}
120120
121- func loop (ctx context.Context , conn * sql.DB ) error {
121+ type Options struct {
122+ RollbackOnFail bool
123+ }
124+
125+ func loop (ctx context.Context , options * Options , conn * sql.DB ) error {
122126 disabler := colorable .EnableColorsStdout (nil )
123127 defer disabler ()
124128
@@ -182,6 +186,11 @@ func loop(ctx context.Context, conn *sql.DB) error {
182186 err = txBegin (ctx , conn , & tx , tee (os .Stderr , spool ))
183187 if err == nil {
184188 err = doDML (ctx , tx , query , tee (os .Stdout , spool ))
189+ if err != nil && options .RollbackOnFail {
190+ fmt .Fprintln (tee (os .Stderr , spool ), err .Error ())
191+ echo (spool , "rollback (automatically)" )
192+ err = txRollback (& tx , tee (os .Stderr , spool ))
193+ }
185194 }
186195 case "COMMIT" :
187196 echo (spool , query )
@@ -226,7 +235,12 @@ func mains(args []string) error {
226235 }
227236 defer conn .Close ()
228237
229- return loop (context .Background (), conn )
238+ var options Options
239+ switch strings .ToUpper (args [0 ]) {
240+ case "POSTGRES" :
241+ options .RollbackOnFail = true
242+ }
243+ return loop (context .Background (), & options , conn )
230244}
231245
232246func main () {
0 commit comments