Skip to content

Commit 3207f3d

Browse files
authored
return-txn-current-if-exists (#432)
* return-txn-current-if-exists * prevent ro -> rw
1 parent d104685 commit 3207f3d

File tree

2 files changed

+89
-154
lines changed

2 files changed

+89
-154
lines changed

gorm/transaction.go

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -102,28 +102,34 @@ func (t *Transaction) AddAfterCommitHook(hooks ...func(context.Context)) {
102102
t.afterCommitHook = append(t.afterCommitHook, hooks...)
103103
}
104104

105-
// getReadOnlyDBInstance returns the read only db txn if RO DB available otherwise it returns read/write db txn
105+
// getReadOnlyDBInstance returns current db txn if exists or the read only db txn if RO DB available otherwise it returns read/write db txn
106+
// If current db txn exists, prevents usage of RO db txn for RW operations
106107
func getReadOnlyDBTxn(ctx context.Context, opts *databaseOptions, txn *Transaction) (*gorm.DB, error) {
107108
var db *gorm.DB
108109
switch {
110+
case txn.current != nil:
111+
// Prevent usage of RO db txn for RW operations
112+
if txn.currentOpts.database == dbReadOnly && opts.database == dbReadWrite {
113+
return nil, ErrCtxDBOptMismatch
114+
}
115+
if txn.currentOpts.txOpts != nil && opts.txOpts != nil {
116+
// Prevent converting tx opts from RO to RW
117+
if txn.currentOpts.txOpts.ReadOnly && !opts.txOpts.ReadOnly {
118+
return nil, ErrCtxTxnOptMismatch
119+
}
120+
}
121+
// Return existing db txn; new opts not applied
122+
return txn.current, nil
109123
case txn.parentRO == nil:
110124
return getReadWriteDBTxn(ctx, opts, txn)
111-
case opts.txOpts != nil && txn.currentOpts.txOpts != nil:
112-
if *opts.txOpts != *txn.currentOpts.txOpts {
113-
return nil, ErrCtxTxnOptMismatch
114-
}
115125
case opts.txOpts != nil:
116-
// We should error in two cases 1. We should error if read-only DB requested with read-write txn
117-
// 2. If no txn options provided in previous call but provided in subsequent call
118-
if !opts.txOpts.ReadOnly || txn.currentOpts.database != dbNotSet {
126+
// Return error if read-only DB requested with read-write txn
127+
if !opts.txOpts.ReadOnly {
119128
return nil, ErrCtxTxnOptMismatch
120129
}
121130
txnOpts := *opts.txOpts
122131
txn.currentOpts.txOpts = &txnOpts
123132
}
124-
if txn.current != nil {
125-
return txn.current, nil
126-
}
127133
db = txn.beginReadOnlyWithContextAndOptions(ctx, txn.currentOpts.txOpts)
128134
if db.Error != nil {
129135
return nil, db.Error
@@ -134,27 +140,23 @@ func getReadOnlyDBTxn(ctx context.Context, opts *databaseOptions, txn *Transacti
134140
return db, nil
135141
}
136142

137-
// getReadWriteDBTxn returns the read/write db txn
143+
// getReadWriteDBTxn returns current RW db txn if exist otherwise a RW db txn
138144
func getReadWriteDBTxn(ctx context.Context, opts *databaseOptions, txn *Transaction) (*gorm.DB, error) {
139145
var db *gorm.DB
140146
switch {
147+
case txn.current != nil:
148+
// Prevent use of existing RO db txn for RW operations
149+
if txn.currentOpts.database == dbReadOnly {
150+
return nil, ErrCtxDBOptMismatch
151+
}
152+
// Return existing db txn; new opts not applied
153+
return txn.current, nil
141154
case txn.parent == nil:
142155
return nil, ErrCtxTxnNoDB
143-
case opts.txOpts != nil && txn.currentOpts.txOpts != nil:
144-
if *opts.txOpts != *txn.currentOpts.txOpts {
145-
return nil, ErrCtxTxnOptMismatch
146-
}
147156
case opts.txOpts != nil:
148-
// We should return error If no txn options provided in previous call but provided in subsequent call
149-
if txn.currentOpts.database != dbNotSet {
150-
return nil, ErrCtxTxnOptMismatch
151-
}
152157
txnOpts := *opts.txOpts
153158
txn.currentOpts.txOpts = &txnOpts
154159
}
155-
if txn.current != nil {
156-
return txn.current, nil
157-
}
158160
db = txn.beginWithContextAndOptions(ctx, txn.currentOpts.txOpts)
159161
if db.Error != nil {
160162
return nil, db.Error
@@ -178,14 +180,8 @@ func BeginFromContext(ctx context.Context, options ...DatabaseOption) (*gorm.DB,
178180
opts := toDatabaseOptions(options...)
179181
switch opts.database {
180182
case dbReadOnly:
181-
if txn.currentOpts.database == dbReadWrite && txn.parentRO != nil {
182-
return nil, ErrCtxDBOptMismatch
183-
}
184183
return getReadOnlyDBTxn(ctx, opts, txn)
185184
case dbReadWrite:
186-
if txn.currentOpts.database == dbReadOnly {
187-
return nil, ErrCtxDBOptMismatch
188-
}
189185
return getReadWriteDBTxn(ctx, opts, txn)
190186
default:
191187
// This is the case to handle when no database options provided

gorm/transaction_test.go

Lines changed: 64 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -499,27 +499,17 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) {
499499
t.Errorf("failed to begin transaction for read-write db - %s", err)
500500
}
501501
test.withOpts = readOnly
502-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
503-
if err != ErrCtxDBOptMismatch {
504-
t.Error("begin transaction should fail with an error DBOptionsMismatch")
505-
}
506-
test.withOpts = readWrite
507-
txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
502+
txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
508503
if err != nil {
509504
t.Error("Received an error beginning transaction")
510505
}
511-
if txn3 == nil {
506+
if txn2 == nil {
512507
t.Error("Did not receive a transaction from context")
513508
}
514-
// Case: Transaction begin is idempotent
515-
if txn1 != txn3 {
509+
// Case: withOpts are ignored if txn is already open, current txn is returned
510+
if txn1 != txn2 {
516511
t.Error("Got a different txn than was opened before")
517512
}
518-
test.txOpts = &sql.TxOptions{}
519-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
520-
if err != ErrCtxTxnOptMismatch {
521-
t.Error("begin transaction should fail with an error TxnOptionsMismatch")
522-
}
523513
} else {
524514
txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
525515
if err != nil {
@@ -532,54 +522,16 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) {
532522
t.Errorf("failed to begin transaction for read-write db - %s", err)
533523
}
534524
test.txOpts.ReadOnly = true
535-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
536-
if err != ErrCtxTxnOptMismatch {
537-
t.Error("begin transaction should fail with an error TxOptionsMismatch")
538-
}
539-
test.txOpts.ReadOnly = false
540-
test.txOpts.Isolation = sql.LevelSerializable
541-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
542-
if err != ErrCtxTxnOptMismatch {
543-
t.Error("begin transaction should fail with an error TxOptionsMismatch")
544-
}
545-
test.txOpts.Isolation = sql.LevelDefault
546-
test.withOpts = readOnly
547-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
548-
if err != ErrCtxDBOptMismatch {
549-
t.Error("begin transaction should fail with an error DBOptionsMismatch")
550-
}
551-
test.withOpts = readWrite
552-
txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
553-
if err != nil {
554-
t.Error("Received an error beginning transaction")
555-
}
556-
if txn3 == nil {
557-
t.Error("Did not receive a transaction from context")
558-
}
559-
// Case: Transaction begin is idempotent
560-
if txn1 != txn3 {
561-
t.Error("Got a different txn than was opened before")
562-
}
563-
test.txOpts.ReadOnly = true
564-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
565-
if err != ErrCtxTxnOptMismatch {
566-
t.Error("begin transaction should fail with an error TxOptionsMismatch")
567-
}
568-
test.txOpts.ReadOnly = false
569-
test.txOpts.Isolation = sql.LevelSerializable
570-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
571-
if err != ErrCtxTxnOptMismatch {
572-
t.Error("begin transaction should fail with an error TxOptionsMismatch")
573-
}
574-
txn4, err := beginFromContextWithOptions(ctx, test.withOpts, nil)
525+
test.txOpts.Isolation = sql.LevelRepeatableRead
526+
txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
575527
if err != nil {
576528
t.Error("Received an error beginning transaction")
577529
}
578-
if txn4 == nil {
530+
if txn2 == nil {
579531
t.Error("Did not receive a transaction from context")
580532
}
581-
// Case: Transaction begin is idempotent
582-
if txn1 != txn4 {
533+
// Case: txOpts are ignored if txn is already open, current txn is returned
534+
if txn1 != txn2 {
583535
t.Error("Got a different txn than was opened before")
584536
}
585537
}
@@ -639,6 +591,7 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) {
639591
if err := dbROMock.ExpectationsWereMet(); err != nil {
640592
t.Errorf("failed to begin transaction for read-only db - %s", err)
641593
}
594+
// Case: noOptions allowed, current txn is returned
642595
test.withOpts = noOptions
643596
txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
644597
if err != nil {
@@ -647,20 +600,14 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) {
647600
if txn2 == nil {
648601
t.Error("Did not receive a transaction from context")
649602
}
650-
// Case: Transaction begin is idempotent
651603
if txn1 != txn2 {
652604
t.Error("Got a different txn than was opened before")
653605
}
606+
// Case: DB opts mismatch, RO -> RW not allowed
654607
test.withOpts = readWrite
655608
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
656609
if err != ErrCtxDBOptMismatch {
657-
t.Error("begin transaction should fail with an error DBOptionsMismatch")
658-
}
659-
test.withOpts = noOptions
660-
test.txOpts = &sql.TxOptions{}
661-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
662-
if err != ErrCtxTxnOptMismatch {
663-
t.Error("begin transaction should fail with an error TxnOptionsMismatch")
610+
t.Error("begin transaction should fail with an error DBOptMismatch")
664611
}
665612
} else {
666613
_, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
@@ -678,19 +625,8 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) {
678625
if err := dbROMock.ExpectationsWereMet(); err != nil {
679626
t.Errorf("failed to begin transaction for read-only db - %s", err)
680627
}
681-
test.txOpts.ReadOnly = false
682-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
683-
if err != ErrCtxTxnOptMismatch {
684-
t.Error("begin transaction should fail with an error TxOptionsMismatch")
685-
}
628+
// Case: txOpts read-only match, current txn is returned
686629
test.txOpts.ReadOnly = true
687-
test.txOpts.Isolation = sql.LevelSerializable
688-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
689-
if err != ErrCtxTxnOptMismatch {
690-
t.Error("begin transaction should fail with an error TxOptionsMismatch")
691-
}
692-
test.txOpts.Isolation = sql.LevelDefault
693-
test.withOpts = noOptions
694630
txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
695631
if err != nil {
696632
t.Error("Received an error beginning transaction")
@@ -701,20 +637,18 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) {
701637
if txn1 != txn2 {
702638
t.Error("Got a different txn than was opened before")
703639
}
704-
txn3, err := beginFromContextWithOptions(ctx, test.withOpts, nil)
705-
if err != nil {
706-
t.Error("Received an error beginning transaction")
707-
}
708-
if txn3 == nil {
709-
t.Error("Did not receive a transaction from context")
710-
}
711-
if txn1 != txn3 {
712-
t.Error("Got a different txn than was opened before")
640+
// Case: txOpts mismatch, RO -> RW not allowed
641+
test.txOpts.ReadOnly = false
642+
test.txOpts.Isolation = sql.LevelSerializable
643+
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
644+
if err != ErrCtxTxnOptMismatch {
645+
t.Error("begin transaction should fail with an error TxOptionsMismatch")
713646
}
647+
// Case: DB opts mismatch, RO -> RW not allowed
714648
test.withOpts = readWrite
715649
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
716650
if err != ErrCtxDBOptMismatch {
717-
t.Error("begin transaction should fail with an error DBOptionsMismatch")
651+
t.Error("begin transaction should fail with an error DBOptMismatch")
718652
}
719653
}
720654
})
@@ -772,27 +706,57 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) {
772706
if err := mock.ExpectationsWereMet(); err != nil {
773707
t.Errorf("failed to begin transaction for read-write db - %s", err)
774708
}
709+
// Case: DB opts mismatch, RW -> RO allowed, current txn is returned
775710
test.withOpts = readOnly
776-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
777-
if err != ErrCtxDBOptMismatch {
778-
t.Error("begin transaction should fail with an error DBOptionsMismatch")
711+
txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
712+
if err != nil {
713+
t.Error("Received an error beginning transaction")
779714
}
780-
test.withOpts = noOptions
715+
if txn2 == nil {
716+
t.Error("Did not receive a transaction from context")
717+
}
718+
719+
if txn1 != txn2 {
720+
t.Error("Got a different txn than was opened before")
721+
}
722+
// Case: txOpts mismatch, nil -> empty opts allowed, current txn is returned
723+
test.txOpts = &sql.TxOptions{}
781724
txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
782725
if err != nil {
783726
t.Error("Received an error beginning transaction")
784727
}
785728
if txn3 == nil {
786729
t.Error("Did not receive a transaction from context")
787730
}
788-
// Case: Transaction begin is idempotent
731+
789732
if txn1 != txn3 {
790733
t.Error("Got a different txn than was opened before")
791734
}
792-
test.txOpts = &sql.TxOptions{}
793-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
794-
if err != ErrCtxTxnOptMismatch {
795-
t.Error("begin transaction should fail with an error TxOptionsMismatch")
735+
// Case: txOpts mismatch, nil -> RO allowed, current txn is returned
736+
test.txOpts = &sql.TxOptions{ReadOnly: true}
737+
txn4, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
738+
if err != nil {
739+
t.Error("Received an error beginning transaction")
740+
}
741+
if txn4 == nil {
742+
t.Error("Did not receive a transaction from context")
743+
}
744+
745+
if txn1 != txn4 {
746+
t.Error("Got a different txn than was opened before")
747+
}
748+
// Case: txOpts mismatch, nil -> RW allowed, current txn is returned
749+
test.txOpts = &sql.TxOptions{ReadOnly: false}
750+
txn5, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
751+
if err != nil {
752+
t.Error("Received an error beginning transaction")
753+
}
754+
if txn5 == nil {
755+
t.Error("Did not receive a transaction from context")
756+
}
757+
758+
if txn1 != txn5 {
759+
t.Error("Got a different txn than was opened before")
796760
}
797761
} else {
798762
txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
@@ -805,47 +769,22 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) {
805769
if err := mock.ExpectationsWereMet(); err != nil {
806770
t.Errorf("failed to begin transaction for read-write db - %s", err)
807771
}
772+
// Case: txOpts mismatch, RW -> RO allowed, current txn is returned
808773
test.txOpts.ReadOnly = true
809-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
810-
if err != ErrCtxTxnOptMismatch {
811-
t.Error("begin transaction should fail with an error TxOptionsMismatch")
812-
}
813-
test.txOpts.ReadOnly = false
814774
test.txOpts.Isolation = sql.LevelSerializable
815-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
816-
if err != ErrCtxTxnOptMismatch {
817-
t.Error("begin transaction should fail with an error TxOptionsMismatch")
818-
}
819-
test.txOpts.Isolation = sql.LevelDefault
820-
test.withOpts = readOnly
821-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
822-
if err != ErrCtxDBOptMismatch {
823-
t.Error("begin transaction should fail with an error DBOptionsMismatch")
824-
}
825-
test.withOpts = noOptions
826775
txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
827776
if err != nil {
828777
t.Error("Received an error beginning transaction")
829778
}
830779
if txn2 == nil {
831780
t.Error("Did not receive a transaction from context")
832781
}
833-
// Case: Transaction begin is idempotent
834782
if txn1 != txn2 {
835783
t.Error("Got a different txn than was opened before")
836784
}
837-
test.txOpts.ReadOnly = true
838-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
839-
if err != ErrCtxTxnOptMismatch {
840-
t.Error("begin transaction should fail with an error TxOptionsMismatch")
841-
}
842-
test.txOpts.ReadOnly = false
843-
test.txOpts.Isolation = sql.LevelSerializable
844-
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
845-
if err != ErrCtxTxnOptMismatch {
846-
t.Error("begin transaction should fail with an error TxOptionsMismatch")
847-
}
848-
txn3, err := beginFromContextWithOptions(ctx, test.withOpts, nil)
785+
// Case: DB opts mismatch, RW -> RO allowed, current txn is returned
786+
test.withOpts = readOnly
787+
txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
849788
if err != nil {
850789
t.Error("Received an error beginning transaction")
851790
}

0 commit comments

Comments
 (0)