@@ -19,6 +19,8 @@ package pulsar
1919
2020import (
2121 "context"
22+ "errors"
23+ "fmt"
2224 "sync"
2325 "sync/atomic"
2426 "time"
@@ -33,9 +35,9 @@ type subscription struct {
3335}
3436
3537type transaction struct {
36- sync.Mutex
38+ mu sync.Mutex
3739 txnID TxnID
38- state TxnState
40+ state atomic. Int32
3941 tcClient * transactionCoordinatorClient
4042 registerPartitions map [string ]bool
4143 registerAckSubscriptions map [subscription ]bool
@@ -54,96 +56,106 @@ type transaction struct {
5456 // 1. When the transaction is committed or aborted, a bool will be read from opsFlow chan.
5557 // 2. When the opsCount increment from 0 to 1, a bool will be read from opsFlow chan.
5658 opsFlow chan bool
57- opsCount int32
59+ opsCount atomic. Int32
5860 opTimeout time.Duration
5961 log log.Logger
6062}
6163
6264func newTransaction (id TxnID , tcClient * transactionCoordinatorClient , timeout time.Duration ) * transaction {
6365 transaction := & transaction {
6466 txnID : id ,
65- state : TxnOpen ,
6667 registerPartitions : make (map [string ]bool ),
6768 registerAckSubscriptions : make (map [subscription ]bool ),
6869 opsFlow : make (chan bool , 1 ),
69- opTimeout : 5 * time . Second ,
70+ opTimeout : tcClient . client . operationTimeout ,
7071 tcClient : tcClient ,
7172 }
72- //This means there are not pending requests with this transaction. The transaction can be committed or aborted.
73+ transaction .state .Store (int32 (TxnOpen ))
74+ // This means there are not pending requests with this transaction. The transaction can be committed or aborted.
7375 transaction .opsFlow <- true
7476 go func () {
75- //Set the state of the transaction to timeout after timeout
77+ // Set the state of the transaction to timeout after timeout
7678 <- time .After (timeout )
77- atomic . CompareAndSwapInt32 (( * int32 )( & transaction .state ), int32 (TxnOpen ), int32 (TxnTimeout ))
79+ transaction .state . CompareAndSwap ( int32 (TxnOpen ), int32 (TxnTimeout ))
7880 }()
7981 transaction .log = tcClient .log .SubLogger (log.Fields {})
8082 return transaction
8183}
8284
8385func (txn * transaction ) GetState () TxnState {
84- return txn .state
86+ return TxnState ( txn .state . Load ())
8587}
8688
87- func (txn * transaction ) Commit (_ context.Context ) error {
88- if ! (atomic . CompareAndSwapInt32 (( * int32 )( & txn .state ), int32 (TxnOpen ), int32 (TxnCommitting )) ||
89- txn . state == TxnCommitting ) {
90- return newError (InvalidStatus , "Expect transaction state is TxnOpen but " + txn . state . string ( ))
89+ func (txn * transaction ) Commit (ctx context.Context ) error {
90+ if ! (txn .state . CompareAndSwap ( int32 (TxnOpen ), int32 (TxnCommitting ))) {
91+ txnState := txn . state . Load ()
92+ return newError (InvalidStatus , txnStateErrorMessage ( TxnOpen , TxnState ( txnState ) ))
9193 }
9294
93- //Wait for all operations to complete
95+ // Wait for all operations to complete
9496 select {
9597 case <- txn .opsFlow :
98+ case <- ctx .Done ():
99+ txn .state .Store (int32 (TxnOpen ))
100+ return ctx .Err ()
96101 case <- time .After (txn .opTimeout ):
102+ txn .state .Store (int32 (TxnTimeout ))
97103 return newError (TimeoutError , "There are some operations that are not completed after the timeout." )
98104 }
99- //Send commit transaction command to transaction coordinator
105+ // Send commit transaction command to transaction coordinator
100106 err := txn .tcClient .endTxn (& txn .txnID , pb .TxnAction_COMMIT )
101107 if err == nil {
102- atomic . StoreInt32 (( * int32 )( & txn .state ), int32 (TxnCommitted ))
108+ txn .state . Store ( int32 (TxnCommitted ))
103109 } else {
104- if e , ok := err .(* Error ); ok && (e .Result () == TransactionNoFoundError || e .Result () == InvalidStatus ) {
105- atomic .StoreInt32 ((* int32 )(& txn .state ), int32 (TxnError ))
110+ var e * Error
111+ if errors .As (err , & e ) && (e .Result () == TransactionNoFoundError || e .Result () == InvalidStatus ) {
112+ txn .state .Store (int32 (TxnError ))
106113 return err
107114 }
108115 txn .opsFlow <- true
109116 }
110117 return err
111118}
112119
113- func (txn * transaction ) Abort (_ context.Context ) error {
114- if ! (atomic . CompareAndSwapInt32 (( * int32 )( & txn .state ), int32 (TxnOpen ), int32 (TxnAborting )) ||
115- txn . state == TxnAborting ) {
116- return newError (InvalidStatus , "Expect transaction state is TxnOpen but " + txn . state . string ( ))
120+ func (txn * transaction ) Abort (ctx context.Context ) error {
121+ if ! (txn .state . CompareAndSwap ( int32 (TxnOpen ), int32 (TxnAborting ))) {
122+ txnState := txn . state . Load ()
123+ return newError (InvalidStatus , txnStateErrorMessage ( TxnOpen , TxnState ( txnState ) ))
117124 }
118125
119- //Wait for all operations to complete
126+ // Wait for all operations to complete
120127 select {
121128 case <- txn .opsFlow :
129+ case <- ctx .Done ():
130+ txn .state .Store (int32 (TxnOpen ))
131+ return ctx .Err ()
122132 case <- time .After (txn .opTimeout ):
133+ txn .state .Store (int32 (TxnTimeout ))
123134 return newError (TimeoutError , "There are some operations that are not completed after the timeout." )
124135 }
125- //Send abort transaction command to transaction coordinator
136+ // Send abort transaction command to transaction coordinator
126137 err := txn .tcClient .endTxn (& txn .txnID , pb .TxnAction_ABORT )
127138 if err == nil {
128- atomic . StoreInt32 (( * int32 )( & txn .state ), int32 (TxnAborted ))
139+ txn .state . Store ( int32 (TxnAborted ))
129140 } else {
130- if e , ok := err .( * Error ); ok && ( e . Result () == TransactionNoFoundError || e . Result () == InvalidStatus ) {
131- atomic . StoreInt32 (( * int32 )( & txn . state ), int32 ( TxnError ))
132- } else {
133- txn . opsFlow <- true
141+ var e * Error
142+ if errors . As ( err , & e ) && ( e . Result () == TransactionNoFoundError || e . Result () == InvalidStatus ) {
143+ txn . state . Store ( int32 ( TxnError ))
144+ return err
134145 }
146+ txn .opsFlow <- true
135147 }
136148 return err
137149}
138150
139151func (txn * transaction ) registerSendOrAckOp () error {
140- if atomic . AddInt32 ( & txn .opsCount , 1 ) == 1 {
141- //There are new operations that not completed
152+ if txn .opsCount . Add ( 1 ) == 1 {
153+ // There are new operations that were not completed
142154 select {
143155 case <- txn .opsFlow :
144156 return nil
145157 case <- time .After (txn .opTimeout ):
146- if _ , err := txn .checkIfOpen (); err != nil {
158+ if err := txn .verifyOpen (); err != nil {
147159 return err
148160 }
149161 return newError (TimeoutError , "Failed to get the semaphore to register the send/ack operation" )
@@ -154,23 +166,22 @@ func (txn *transaction) registerSendOrAckOp() error {
154166
155167func (txn * transaction ) endSendOrAckOp (err error ) {
156168 if err != nil {
157- atomic . StoreInt32 (( * int32 )( & txn .state ), int32 (TxnError ))
169+ txn .state . Store ( int32 (TxnError ))
158170 }
159- if atomic . AddInt32 ( & txn .opsCount , - 1 ) == 0 {
160- //This means there are not pending send/ack requests
171+ if txn .opsCount . Add ( - 1 ) == 0 {
172+ // This means there are no pending send/ack requests
161173 txn .opsFlow <- true
162174 }
163175}
164176
165177func (txn * transaction ) registerProducerTopic (topic string ) error {
166- isOpen , err := txn .checkIfOpen ()
167- if ! isOpen {
178+ if err := txn .verifyOpen (); err != nil {
168179 return err
169180 }
170181 _ , ok := txn .registerPartitions [topic ]
171182 if ! ok {
172- txn .Lock ()
173- defer txn .Unlock ()
183+ txn .mu . Lock ()
184+ defer txn .mu . Unlock ()
174185 if _ , ok = txn .registerPartitions [topic ]; ! ok {
175186 err := txn .tcClient .addPublishPartitionToTxn (& txn .txnID , []string {topic })
176187 if err != nil {
@@ -183,8 +194,7 @@ func (txn *transaction) registerProducerTopic(topic string) error {
183194}
184195
185196func (txn * transaction ) registerAckTopic (topic string , subName string ) error {
186- isOpen , err := txn .checkIfOpen ()
187- if ! isOpen {
197+ if err := txn .verifyOpen (); err != nil {
188198 return err
189199 }
190200 sub := subscription {
@@ -193,8 +203,8 @@ func (txn *transaction) registerAckTopic(topic string, subName string) error {
193203 }
194204 _ , ok := txn .registerAckSubscriptions [sub ]
195205 if ! ok {
196- txn .Lock ()
197- defer txn .Unlock ()
206+ txn .mu . Lock ()
207+ defer txn .mu . Unlock ()
198208 if _ , ok = txn .registerAckSubscriptions [sub ]; ! ok {
199209 err := txn .tcClient .addSubscriptionToTxn (& txn .txnID , topic , subName )
200210 if err != nil {
@@ -210,14 +220,15 @@ func (txn *transaction) GetTxnID() TxnID {
210220 return txn .txnID
211221}
212222
213- func (txn * transaction ) checkIfOpen () (bool , error ) {
214- if txn .state == TxnOpen {
215- return true , nil
223+ func (txn * transaction ) verifyOpen () error {
224+ txnState := txn .state .Load ()
225+ if txnState != int32 (TxnOpen ) {
226+ return newError (InvalidStatus , txnStateErrorMessage (TxnOpen , TxnState (txnState )))
216227 }
217- return false , newError ( InvalidStatus , "Expect transaction state is TxnOpen but " + txn . state . string ())
228+ return nil
218229}
219230
220- func (state TxnState ) string () string {
231+ func (state TxnState ) String () string {
221232 switch state {
222233 case TxnOpen :
223234 return "TxnOpen"
@@ -237,3 +248,8 @@ func (state TxnState) string() string {
237248 return "Unknown"
238249 }
239250}
251+
252+ //nolint:unparam
253+ func txnStateErrorMessage (expected , actual TxnState ) string {
254+ return fmt .Sprintf ("Expected transaction state: %s, actual: %s" , expected , actual )
255+ }
0 commit comments