@@ -241,8 +241,9 @@ func (n *connector) Connect(ctx context.Context) (ret driver.Conn, err error) {
241
241
}
242
242
}()
243
243
244
- old := c .Conn .SetInterrupt (ctx )
245
- defer c .Conn .SetInterrupt (old )
244
+ if old := c .Conn .SetInterrupt (ctx ); old != ctx {
245
+ defer c .Conn .SetInterrupt (old )
246
+ }
246
247
247
248
if ! n .pragmas {
248
249
err = c .Conn .BusyTimeout (time .Minute )
@@ -362,8 +363,9 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
362
363
c .txReset = `; PRAGMA query_only=` + string (c .readOnly )
363
364
}
364
365
365
- old := c .Conn .SetInterrupt (ctx )
366
- defer c .Conn .SetInterrupt (old )
366
+ if old := c .Conn .SetInterrupt (ctx ); old != ctx {
367
+ defer c .Conn .SetInterrupt (old )
368
+ }
367
369
368
370
err := c .Conn .Exec (txBegin )
369
371
if err != nil {
@@ -382,8 +384,10 @@ func (c *conn) Commit() error {
382
384
383
385
func (c * conn ) Rollback () error {
384
386
// ROLLBACK even if interrupted.
385
- old := c .Conn .SetInterrupt (context .Background ())
386
- defer c .Conn .SetInterrupt (old )
387
+ ctx := context .Background ()
388
+ if old := c .Conn .SetInterrupt (ctx ); old != ctx {
389
+ defer c .Conn .SetInterrupt (old )
390
+ }
387
391
return c .Conn .Exec (`ROLLBACK` + c .txReset )
388
392
}
389
393
@@ -393,8 +397,9 @@ func (c *conn) Prepare(query string) (driver.Stmt, error) {
393
397
}
394
398
395
399
func (c * conn ) PrepareContext (ctx context.Context , query string ) (driver.Stmt , error ) {
396
- old := c .Conn .SetInterrupt (ctx )
397
- defer c .Conn .SetInterrupt (old )
400
+ if old := c .Conn .SetInterrupt (ctx ); old != ctx {
401
+ defer c .Conn .SetInterrupt (old )
402
+ }
398
403
399
404
s , tail , err := c .Conn .Prepare (query )
400
405
if err != nil {
@@ -419,8 +424,9 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
419
424
return resultRowsAffected (0 ), nil
420
425
}
421
426
422
- old := c .Conn .SetInterrupt (ctx )
423
- defer c .Conn .SetInterrupt (old )
427
+ if old := c .Conn .SetInterrupt (ctx ); old != ctx {
428
+ defer c .Conn .SetInterrupt (old )
429
+ }
424
430
425
431
err := c .Conn .Exec (query )
426
432
if err != nil {
@@ -483,8 +489,10 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
483
489
return nil , err
484
490
}
485
491
486
- old := s .Stmt .Conn ().SetInterrupt (ctx )
487
- defer s .Stmt .Conn ().SetInterrupt (old )
492
+ c := s .Stmt .Conn ()
493
+ if old := c .SetInterrupt (ctx ); old != ctx {
494
+ defer c .SetInterrupt (old )
495
+ }
488
496
489
497
err = errors .Join (
490
498
s .Stmt .Exec (),
@@ -493,7 +501,7 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
493
501
return nil , err
494
502
}
495
503
496
- return newResult (s . Stmt . Conn () ), nil
504
+ return newResult (c ), nil
497
505
}
498
506
499
507
func (s * stmt ) QueryContext (ctx context.Context , args []driver.NamedValue ) (driver.Rows , error ) {
@@ -678,13 +686,14 @@ func (r *rows) scanType(index int) scantype {
678
686
679
687
func (r * rows ) loadColumnMetadata () {
680
688
if r .nulls == nil {
689
+ c := r .Stmt .Conn ()
681
690
count := r .Stmt .ColumnCount ()
682
691
nulls := make ([]bool , count )
683
692
types := make ([]string , count )
684
693
scans := make ([]scantype , count )
685
694
for i := range nulls {
686
695
if col := r .Stmt .ColumnOriginName (i ); col != "" {
687
- types [i ], _ , nulls [i ], _ , _ , _ = r . Stmt . Conn () .TableColumnMetadata (
696
+ types [i ], _ , nulls [i ], _ , _ , _ = c .TableColumnMetadata (
688
697
r .Stmt .ColumnDatabaseName (i ),
689
698
r .Stmt .ColumnTableName (i ),
690
699
col )
@@ -762,8 +771,10 @@ func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
762
771
}
763
772
764
773
func (r * rows ) Next (dest []driver.Value ) error {
765
- old := r .Stmt .Conn ().SetInterrupt (r .ctx )
766
- defer r .Stmt .Conn ().SetInterrupt (old )
774
+ c := r .Stmt .Conn ()
775
+ if old := c .SetInterrupt (r .ctx ); old != r .ctx {
776
+ defer c .SetInterrupt (old )
777
+ }
767
778
768
779
if ! r .Stmt .Step () {
769
780
if err := r .Stmt .Err (); err != nil {
0 commit comments