Skip to content

Commit 935fe3f

Browse files
author
Divjot Arora
authored
GODRIVER-1394 Export TransactionState from session.Client (#518)
1 parent 15ac5cf commit 935fe3f

File tree

2 files changed

+55
-37
lines changed

2 files changed

+55
-37
lines changed

x/mongo/driver/session/client_session.go

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,36 @@ const (
4949
Implicit
5050
)
5151

52-
// State indicates the state of the FSM.
53-
type state uint8
52+
// TransactionState indicates the state of the transactions FSM.
53+
type TransactionState uint8
5454

5555
// Client Session states
5656
const (
57-
None state = iota
57+
None TransactionState = iota
5858
Starting
5959
InProgress
6060
Committed
6161
Aborted
6262
)
6363

64+
// String implements the fmt.Stringer interface.
65+
func (s TransactionState) String() string {
66+
switch s {
67+
case None:
68+
return "none"
69+
case Starting:
70+
return "starting"
71+
case InProgress:
72+
return "in progress"
73+
case Committed:
74+
return "committed"
75+
case Aborted:
76+
return "aborted"
77+
default:
78+
return "unknown"
79+
}
80+
}
81+
6482
// Client is a session for clients to run commands.
6583
type Client struct {
6684
*Server
@@ -89,10 +107,10 @@ type Client struct {
89107
transactionWc *writeconcern.WriteConcern
90108
transactionMaxCommitTime *time.Duration
91109

92-
pool *Pool
93-
state state
94-
PinnedServer *description.Server
95-
RecoveryToken bson.Raw
110+
pool *Pool
111+
TransactionState TransactionState
112+
PinnedServer *description.Server
113+
RecoveryToken bson.Raw
96114
}
97115

98116
func getClusterTime(clusterTime bson.Raw) (uint32, uint32) {
@@ -242,29 +260,29 @@ func (c *Client) EndSession() {
242260

243261
// TransactionInProgress returns true if the client session is in an active transaction.
244262
func (c *Client) TransactionInProgress() bool {
245-
return c.state == InProgress
263+
return c.TransactionState == InProgress
246264
}
247265

248266
// TransactionStarting returns true if the client session is starting a transaction.
249267
func (c *Client) TransactionStarting() bool {
250-
return c.state == Starting
268+
return c.TransactionState == Starting
251269
}
252270

253271
// TransactionRunning returns true if the client session has started the transaction
254272
// and it hasn't been committed or aborted
255273
func (c *Client) TransactionRunning() bool {
256-
return c != nil && (c.state == Starting || c.state == InProgress)
274+
return c != nil && (c.TransactionState == Starting || c.TransactionState == InProgress)
257275
}
258276

259277
// TransactionCommitted returns true of the client session just committed a transaciton.
260278
func (c *Client) TransactionCommitted() bool {
261-
return c.state == Committed
279+
return c.TransactionState == Committed
262280
}
263281

264282
// CheckStartTransaction checks to see if allowed to start transaction and returns
265283
// an error if not allowed
266284
func (c *Client) CheckStartTransaction() error {
267-
if c.state == InProgress || c.state == Starting {
285+
if c.TransactionState == InProgress || c.TransactionState == Starting {
268286
return ErrTransactInProgress
269287
}
270288
return nil
@@ -309,17 +327,17 @@ func (c *Client) StartTransaction(opts *TransactionOptions) error {
309327
return ErrUnackWCUnsupported
310328
}
311329

312-
c.state = Starting
330+
c.TransactionState = Starting
313331
c.PinnedServer = nil
314332
return nil
315333
}
316334

317335
// CheckCommitTransaction checks to see if allowed to commit transaction and returns
318336
// an error if not allowed.
319337
func (c *Client) CheckCommitTransaction() error {
320-
if c.state == None {
338+
if c.TransactionState == None {
321339
return ErrNoTransactStarted
322-
} else if c.state == Aborted {
340+
} else if c.TransactionState == Aborted {
323341
return ErrCommitAfterAbort
324342
}
325343
return nil
@@ -332,7 +350,7 @@ func (c *Client) CommitTransaction() error {
332350
if err != nil {
333351
return err
334352
}
335-
c.state = Committed
353+
c.TransactionState = Committed
336354
return nil
337355
}
338356

@@ -351,11 +369,11 @@ func (c *Client) UpdateCommitTransactionWriteConcern() {
351369
// CheckAbortTransaction checks to see if allowed to abort transaction and returns
352370
// an error if not allowed.
353371
func (c *Client) CheckAbortTransaction() error {
354-
if c.state == None {
372+
if c.TransactionState == None {
355373
return ErrNoTransactStarted
356-
} else if c.state == Committed {
374+
} else if c.TransactionState == Committed {
357375
return ErrAbortAfterCommit
358-
} else if c.state == Aborted {
376+
} else if c.TransactionState == Aborted {
359377
return ErrAbortTwice
360378
}
361379
return nil
@@ -368,7 +386,7 @@ func (c *Client) AbortTransaction() error {
368386
if err != nil {
369387
return err
370388
}
371-
c.state = Aborted
389+
c.TransactionState = Aborted
372390
c.clearTransactionOpts()
373391
return nil
374392
}
@@ -379,15 +397,15 @@ func (c *Client) ApplyCommand(desc description.Server) {
379397
// Do not change state if committing after already committed
380398
return
381399
}
382-
if c.state == Starting {
383-
c.state = InProgress
400+
if c.TransactionState == Starting {
401+
c.TransactionState = InProgress
384402
// If this is in a transaction and the server is a mongos, pin it
385403
if desc.Kind == description.Mongos {
386404
c.PinnedServer = &desc
387405
}
388-
} else if c.state == Committed || c.state == Aborted {
406+
} else if c.TransactionState == Committed || c.TransactionState == Aborted {
389407
c.clearTransactionOpts()
390-
c.state = None
408+
c.TransactionState = None
391409
}
392410
}
393411

x/mongo/driver/session/client_session_test.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212

1313
"github.com/stretchr/testify/require"
1414
"go.mongodb.org/mongo-driver/bson/primitive"
15-
"go.mongodb.org/mongo-driver/internal/testutil/helpers"
15+
testhelpers "go.mongodb.org/mongo-driver/internal/testutil/helpers"
1616
"go.mongodb.org/mongo-driver/mongo/description"
1717
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
1818
"go.mongodb.org/mongo-driver/x/mongo/driver/uuid"
@@ -134,14 +134,14 @@ func TestClientSession(t *testing.T) {
134134
t.Errorf("expected error, got %v", err)
135135
}
136136

137-
if sess.state != None {
138-
t.Errorf("incorrect session state, expected None, received %v", sess.state)
137+
if sess.TransactionState != None {
138+
t.Errorf("incorrect session state, expected None, received %v", sess.TransactionState)
139139
}
140140

141141
err = sess.StartTransaction(nil)
142142
require.Nil(t, err, "error starting transaction: %s", err)
143-
if sess.state != Starting {
144-
t.Errorf("incorrect session state, expected Starting, received %v", sess.state)
143+
if sess.TransactionState != Starting {
144+
t.Errorf("incorrect session state, expected Starting, received %v", sess.TransactionState)
145145
}
146146

147147
err = sess.StartTransaction(nil)
@@ -150,8 +150,8 @@ func TestClientSession(t *testing.T) {
150150
}
151151

152152
sess.ApplyCommand(description.Server{Kind: description.Standalone})
153-
if sess.state != InProgress {
154-
t.Errorf("incorrect session state, expected InProgress, received %v", sess.state)
153+
if sess.TransactionState != InProgress {
154+
t.Errorf("incorrect session state, expected InProgress, received %v", sess.TransactionState)
155155
}
156156

157157
err = sess.StartTransaction(nil)
@@ -161,8 +161,8 @@ func TestClientSession(t *testing.T) {
161161

162162
err = sess.CommitTransaction()
163163
require.Nil(t, err, "error committing transaction: %s", err)
164-
if sess.state != Committed {
165-
t.Errorf("incorrect session state, expected Committed, received %v", sess.state)
164+
if sess.TransactionState != Committed {
165+
t.Errorf("incorrect session state, expected Committed, received %v", sess.TransactionState)
166166
}
167167

168168
err = sess.AbortTransaction()
@@ -172,14 +172,14 @@ func TestClientSession(t *testing.T) {
172172

173173
err = sess.StartTransaction(nil)
174174
require.Nil(t, err, "error starting transaction: %s", err)
175-
if sess.state != Starting {
176-
t.Errorf("incorrect session state, expected Starting, received %v", sess.state)
175+
if sess.TransactionState != Starting {
176+
t.Errorf("incorrect session state, expected Starting, received %v", sess.TransactionState)
177177
}
178178

179179
err = sess.AbortTransaction()
180180
require.Nil(t, err, "error aborting transaction: %s", err)
181-
if sess.state != Aborted {
182-
t.Errorf("incorrect session state, expected Aborted, received %v", sess.state)
181+
if sess.TransactionState != Aborted {
182+
t.Errorf("incorrect session state, expected Aborted, received %v", sess.TransactionState)
183183
}
184184

185185
err = sess.AbortTransaction()

0 commit comments

Comments
 (0)