Skip to content

Commit c4993a5

Browse files
author
Benjamin Rewis
committed
Stop treating context errors as network errors where possible. (#1045)
GODRIVER-2468 GODRIVER-1965
1 parent 16c2ea0 commit c4993a5

File tree

6 files changed

+52
-60
lines changed

6 files changed

+52
-60
lines changed

mongo/session.go

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,6 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi
200200
default:
201201
}
202202

203-
// End if context has timed out or been canceled, as retrying has no chance of success.
204-
if ctx.Err() != nil {
205-
return res, err
206-
}
207203
if errorHasLabel(err, driver.TransientTransactionError) {
208204
continue
209205
}
@@ -218,10 +214,9 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi
218214
CommitLoop:
219215
for {
220216
err = s.CommitTransaction(ctx)
221-
// End when error is nil (transaction has been committed), or when context has timed out or been
222-
// canceled, as retrying has no chance of success.
223-
if err == nil || ctx.Err() != nil {
224-
return res, err
217+
// End when error is nil, as transaction has been committed.
218+
if err == nil {
219+
return res, nil
225220
}
226221

227222
select {

x/mongo/driver/operation.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -529,15 +529,17 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error {
529529
serviceID: startedInfo.serviceID,
530530
}
531531

532-
// Check if there's enough time to perform a round trip before the Context deadline. If ctx is
533-
// a Timeout Context, use the 90th percentile RTT as a threshold. Otherwise, use the minimum observed
534-
// RTT.
535-
if deadline, ok := ctx.Deadline(); ok {
532+
// Check for possible context error. If no context error, check if there's enough time to perform a
533+
// round trip before the Context deadline. If ctx is a Timeout Context, use the 90th percentile RTT
534+
// as a threshold. Otherwise, use the minimum observed RTT.
535+
if ctx.Err() != nil {
536+
err = ctx.Err()
537+
} else if deadline, ok := ctx.Deadline(); ok {
536538
if internal.IsTimeoutContext(ctx) && time.Now().Add(srvr.RTT90()).After(deadline) {
537539
err = internal.WrapErrorf(ErrDeadlineWouldBeExceeded,
538540
"Remaining timeout %v applied from Timeout is less than 90th percentile RTT", time.Until(deadline))
539541
} else if time.Now().Add(srvr.MinRTT()).After(deadline) {
540-
err = op.networkError(context.DeadlineExceeded)
542+
err = context.DeadlineExceeded
541543
}
542544
}
543545

x/mongo/driver/operation_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,48 @@ func TestOperation(t *testing.T) {
605605
assert.Nil(t, err, "ExecuteExhaust error: %v", err)
606606
assert.True(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be true")
607607
})
608+
t.Run("context deadline exceeded not marked as TransientTransactionError", func(t *testing.T) {
609+
conn := new(mockConnection)
610+
// Create a context that's already timed out.
611+
ctx, cancel := context.WithDeadline(context.Background(), time.Unix(893934480, 0))
612+
defer cancel()
613+
614+
op := Operation{
615+
Database: "foobar",
616+
Deployment: SingleConnectionDeployment{C: conn},
617+
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
618+
dst = bsoncore.AppendInt32Element(dst, "ping", 1)
619+
return dst, nil
620+
},
621+
}
622+
623+
err := op.Execute(ctx, nil)
624+
assert.NotNil(t, err, "expected an error from Execute(), got nil")
625+
// Assert that error is just context deadline exceeded and is therefore not a driver.Error marked
626+
// with the TransientTransactionError label.
627+
assert.Equal(t, err, context.DeadlineExceeded, "expected context.DeadlineExceeded error, got %v", err)
628+
})
629+
t.Run("canceled context not marked as TransientTransactionError", func(t *testing.T) {
630+
conn := new(mockConnection)
631+
// Create a context and cancel it immediately.
632+
ctx, cancel := context.WithCancel(context.Background())
633+
cancel()
634+
635+
op := Operation{
636+
Database: "foobar",
637+
Deployment: SingleConnectionDeployment{C: conn},
638+
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
639+
dst = bsoncore.AppendInt32Element(dst, "ping", 1)
640+
return dst, nil
641+
},
642+
}
643+
644+
err := op.Execute(ctx, nil)
645+
assert.NotNil(t, err, "expected an error from Execute(), got nil")
646+
// Assert that error is just context canceled and is therefore not a driver.Error marked with
647+
// the TransientTransactionError label.
648+
assert.Equal(t, err, context.Canceled, "expected context.Canceled error, got %v", err)
649+
})
608650
}
609651

610652
func createExhaustServerResponse(response bsoncore.Document, moreToCome bool) []byte {

x/mongo/driver/topology/connection.go

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,6 @@ func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
331331
if atomic.LoadInt64(&c.state) != connConnected {
332332
return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
333333
}
334-
select {
335-
case <-ctx.Done():
336-
return ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to write"}
337-
default:
338-
}
339334

340335
var deadline time.Time
341336
if c.writeTimeout != 0 {
@@ -388,14 +383,6 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e
388383
return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
389384
}
390385

391-
select {
392-
case <-ctx.Done():
393-
// We closeConnection the connection because we don't know if there is an unread message on the wire.
394-
c.close()
395-
return nil, ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to read"}
396-
default:
397-
}
398-
399386
var deadline time.Time
400387
if c.readTimeout != 0 {
401388
deadline = time.Now().Add(c.readTimeout)

x/mongo/driver/topology/connection_errors_test.go

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,5 @@ func TestConnectionErrors(t *testing.T) {
5050
err := conn.connect(ctx)
5151
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
5252
})
53-
t.Run("write error", func(t *testing.T) {
54-
ctx, cancel := context.WithCancel(context.Background())
55-
cancel()
56-
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
57-
err := conn.writeWireMessage(ctx, []byte{})
58-
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
59-
})
60-
t.Run("read error", func(t *testing.T) {
61-
ctx, cancel := context.WithCancel(context.Background())
62-
cancel()
63-
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
64-
_, err := conn.readWireMessage(ctx, []byte{})
65-
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
66-
})
6753
})
6854
}

x/mongo/driver/topology/connection_test.go

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -351,16 +351,6 @@ func TestConnection(t *testing.T) {
351351
t.Errorf("errors do not match. got %v; want %v", got, want)
352352
}
353353
})
354-
t.Run("completed context", func(t *testing.T) {
355-
ctx, cancel := context.WithCancel(context.Background())
356-
cancel()
357-
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
358-
want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to write"}
359-
got := conn.writeWireMessage(ctx, []byte{})
360-
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
361-
t.Errorf("errors do not match. got %v; want %v", got, want)
362-
}
363-
})
364354
t.Run("deadlines", func(t *testing.T) {
365355
testCases := []struct {
366356
name string
@@ -490,16 +480,6 @@ func TestConnection(t *testing.T) {
490480
t.Errorf("errors do not match. got %v; want %v", got, want)
491481
}
492482
})
493-
t.Run("completed context", func(t *testing.T) {
494-
ctx, cancel := context.WithCancel(context.Background())
495-
cancel()
496-
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
497-
want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to read"}
498-
_, got := conn.readWireMessage(ctx, []byte{})
499-
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
500-
t.Errorf("errors do not match. got %v; want %v", got, want)
501-
}
502-
})
503483
t.Run("deadlines", func(t *testing.T) {
504484
testCases := []struct {
505485
name string

0 commit comments

Comments
 (0)