Skip to content

Commit 83192a7

Browse files
authored
Merge pull request #15 from digitalocean/mdl-context
ovsdb: implement cancelation via context
2 parents 836fee0 + 0932e22 commit 83192a7

File tree

7 files changed

+157
-43
lines changed

7 files changed

+157
-43
lines changed

ovsdb/client.go

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
package ovsdb
1616

1717
import (
18-
"encoding/json"
18+
"context"
1919
"fmt"
2020
"io"
2121
"log"
@@ -42,11 +42,6 @@ type Client struct {
4242
wg *sync.WaitGroup
4343
}
4444

45-
type rpcResponse struct {
46-
Result json.RawMessage
47-
Error error
48-
}
49-
5045
// An OptionFunc is a function which can configure a Client.
5146
type OptionFunc func(c *Client) error
5247

@@ -115,16 +110,21 @@ func (c *Client) Close() error {
115110
}
116111

117112
// rpc performs a single RPC request, and checks the response for errors.
118-
func (c *Client) rpc(method string, out interface{}, args ...interface{}) error {
113+
func (c *Client) rpc(ctx context.Context, method string, out interface{}, args ...interface{}) error {
114+
// Was the context canceled before sending the RPC?
115+
select {
116+
case <-ctx.Done():
117+
return ctx.Err()
118+
default:
119+
}
120+
119121
// Unmarshal results into empty struct if no out specified.
120122
if out == nil {
121123
out = &struct{}{}
122124
}
123125

124126
// Captures any OVSDB errors.
125-
r := result{
126-
Reply: out,
127-
}
127+
r := result{Reply: out}
128128

129129
req := jsonrpc.Request{
130130
Method: method,
@@ -133,29 +133,23 @@ func (c *Client) rpc(method string, out interface{}, args ...interface{}) error
133133
}
134134

135135
// Add callback for this RPC ID to return results via channel.
136-
ch := make(chan rpcResponse, 0)
136+
ch := make(chan rpcResponse, 1)
137+
defer close(ch)
137138
c.addCallback(req.ID, ch)
138139

139140
if err := c.c.Send(req); err != nil {
140141
return err
141142
}
142143

143-
// Wait for callback to fire.
144-
res := <-ch
145-
if err := res.Error; err != nil {
146-
return err
144+
// Await RPC completion or cancelation.
145+
select {
146+
case <-ctx.Done():
147+
// RPC canceled; clean up callback.
148+
return c.cancelCallback(ctx, req.ID)
149+
case res := <-ch:
150+
// RPC complete.
151+
return rpcResult(res, &r)
147152
}
148-
149-
if err := json.Unmarshal(res.Result, &r); err != nil {
150-
return err
151-
}
152-
153-
// OVSDB server returned an error, return it.
154-
if r.Err != nil {
155-
return r.Err
156-
}
157-
158-
return nil
159153
}
160154

161155
// listen starts an RPC receive loop that can return RPC results to
@@ -216,12 +210,23 @@ func (c *Client) doCallback(id int, res rpcResponse) {
216210
return
217211
}
218212

219-
// Return result, clean up channel, and remove this callback.
213+
// Return result and remove this callback.
220214
ch <- res
221-
close(ch)
222215
delete(c.callbacks, id)
223216
}
224217

218+
// cancelCallback is invoked when an RPC is canceled by its context.
219+
func (c *Client) cancelCallback(ctx context.Context, id int) error {
220+
// RPC canceled; acquire the callback mutex and clean up the callback
221+
// for this RPC.
222+
c.cbMu.Lock()
223+
defer c.cbMu.Unlock()
224+
225+
delete(c.callbacks, id)
226+
227+
return ctx.Err()
228+
}
229+
225230
func panicf(format string, a ...interface{}) {
226231
panic(fmt.Sprintf(format, a...))
227232
}

ovsdb/client_integration_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package ovsdb_test
1616

1717
import (
18+
"context"
1819
"fmt"
1920
"sync"
2021
"testing"
@@ -54,7 +55,7 @@ func TestClientIntegrationConcurrent(t *testing.T) {
5455
<-sigC
5556

5657
for j := 0; j < 4; j++ {
57-
_, err := c.ListDatabases()
58+
_, err := c.ListDatabases(context.Background())
5859
if err != nil {
5960
panic(fmt.Sprintf("failed to query concurrently: %v", err))
6061
}
@@ -72,7 +73,7 @@ func TestClientIntegrationConcurrent(t *testing.T) {
7273
}
7374

7475
func testClientDatabases(t *testing.T, c *ovsdb.Client) {
75-
dbs, err := c.ListDatabases()
76+
dbs, err := c.ListDatabases(context.Background())
7677
if err != nil {
7778
t.Fatalf("failed to list databases: %v", err)
7879
}

ovsdb/client_test.go

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
package ovsdb_test
1616

1717
import (
18+
"context"
1819
"encoding/json"
1920
"fmt"
2021
"log"
2122
"os"
2223
"testing"
24+
"time"
2325

2426
"github.com/digitalocean/go-openvswitch/ovsdb"
2527
"github.com/digitalocean/go-openvswitch/ovsdb/internal/jsonrpc"
@@ -37,7 +39,7 @@ func TestClientJSONRPCError(t *testing.T) {
3739
})
3840
defer done()
3941

40-
_, err := c.ListDatabases()
42+
_, err := c.ListDatabases(context.Background())
4143
if err == nil {
4244
t.Fatal("expected an error, but none occurred")
4345
}
@@ -58,7 +60,7 @@ func TestClientOVSDBError(t *testing.T) {
5860
})
5961
defer done()
6062

61-
_, err := c.ListDatabases()
63+
_, err := c.ListDatabases(context.Background())
6264
if err == nil {
6365
t.Fatal("expected an error, but none occurred")
6466
}
@@ -88,11 +90,60 @@ func TestClientBadCallback(t *testing.T) {
8890
ID: intPtr(10),
8991
}
9092

91-
if _, err := c.ListDatabases(); err != nil {
93+
if _, err := c.ListDatabases(context.Background()); err != nil {
9294
t.Fatalf("unexpected error: %v", err)
9395
}
9496
}
9597

98+
func TestClientContextCancelBeforeRPC(t *testing.T) {
99+
// Context canceled before RPC even begins.
100+
ctx, cancel := context.WithCancel(context.Background())
101+
cancel()
102+
103+
c, _, done := testClient(t, func(_ jsonrpc.Request) jsonrpc.Response {
104+
return jsonrpc.Response{
105+
ID: intPtr(1),
106+
Result: mustMarshalJSON(t, []string{"foo"}),
107+
}
108+
})
109+
defer done()
110+
111+
_, err := c.ListDatabases(ctx)
112+
if err != context.Canceled {
113+
t.Fatalf("expected context canceled error: %v", err)
114+
}
115+
}
116+
117+
func TestClientContextCancelDuringRPC(t *testing.T) {
118+
if testing.Short() {
119+
t.Skip("skipping during short test run")
120+
}
121+
122+
// Context canceled during long RPC.
123+
ctx, cancel := context.WithCancel(context.Background())
124+
defer cancel()
125+
126+
c, _, done := testClient(t, func(_ jsonrpc.Request) jsonrpc.Response {
127+
// RPC canceled; RPC server still processing.
128+
// TODO(mdlayher): try to do something smarter than sleeping in a test.
129+
cancel()
130+
<-ctx.Done()
131+
132+
time.Sleep(500 * time.Millisecond)
133+
134+
return jsonrpc.Response{
135+
ID: intPtr(1),
136+
Result: mustMarshalJSON(t, []string{"foo"}),
137+
}
138+
})
139+
defer done()
140+
141+
_, err := c.ListDatabases(ctx)
142+
if err != context.Canceled {
143+
t.Fatalf("expected context canceled error: %v", err)
144+
}
145+
}
146+
96147
func testClient(t *testing.T, fn jsonrpc.TestFunc) (*ovsdb.Client, chan<- *jsonrpc.Response, func()) {
97148
t.Helper()
98149

ovsdb/internal/jsonrpc/testconn.go

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func TestNetConn(t *testing.T, fn TestFunc) (net.Conn, chan<- *Response, func())
5757
}
5858

5959
var wg sync.WaitGroup
60-
wg.Add(2)
60+
wg.Add(1)
6161

6262
notifC := make(chan *Response, 16)
6363

@@ -67,7 +67,7 @@ func TestNetConn(t *testing.T, fn TestFunc) (net.Conn, chan<- *Response, func())
6767
// Accept a single connection.
6868
c, err := l.Accept()
6969
if err != nil {
70-
if strings.Contains(err.Error(), "use of closed network") {
70+
if isNetworkCloseError(err) {
7171
return
7272
}
7373

@@ -76,14 +76,28 @@ func TestNetConn(t *testing.T, fn TestFunc) (net.Conn, chan<- *Response, func())
7676
defer c.Close()
7777

7878
dec := json.NewDecoder(c)
79+
80+
var encMu sync.RWMutex
7981
enc := json.NewEncoder(c)
8082

8183
// Push RPC notifications to the client.
84+
var notifWG sync.WaitGroup
85+
notifWG.Add(1)
86+
defer notifWG.Wait()
87+
8288
go func() {
83-
defer wg.Done()
89+
defer notifWG.Done()
8490

8591
for n := range notifC {
86-
if err := enc.Encode(n); err != nil {
92+
encMu.Lock()
93+
err := enc.Encode(n)
94+
encMu.Unlock()
95+
96+
if err != nil {
97+
if isNetworkCloseError(err) {
98+
return
99+
}
100+
87101
panicf("failed to encode notification: %v", err)
88102
}
89103
}
@@ -93,15 +107,24 @@ func TestNetConn(t *testing.T, fn TestFunc) (net.Conn, chan<- *Response, func())
93107
for {
94108
var req Request
95109
if err := dec.Decode(&req); err != nil {
96-
if err == io.EOF {
110+
if isNetworkCloseError(err) {
97111
return
98112
}
99113

100-
panicf("failed to decode request: %v", err)
114+
panicf("failed to decode request: %#v", err)
101115
}
102116

103117
res := fn(req)
104-
if err := enc.Encode(res); err != nil {
118+
119+
encMu.Lock()
120+
err := enc.Encode(res)
121+
encMu.Unlock()
122+
123+
if err != nil {
124+
if isNetworkCloseError(err) {
125+
return
126+
}
127+
105128
panicf("failed to encode response: %v", err)
106129
}
107130
}
@@ -124,3 +147,9 @@ func TestNetConn(t *testing.T, fn TestFunc) (net.Conn, chan<- *Response, func())
124147
func panicf(format string, a ...interface{}) {
125148
panic(fmt.Sprintf(format, a...))
126149
}
150+
151+
func isNetworkCloseError(err error) bool {
152+
return err == io.EOF ||
153+
strings.Contains(err.Error(), "use of closed network") ||
154+
strings.Contains(err.Error(), "connection reset by peer")
155+
}

ovsdb/result.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,31 @@ type result struct {
2626
Err *Error
2727
}
2828

29+
// An rpcResponse is a response used in RPC callbacks.
30+
type rpcResponse struct {
31+
Result json.RawMessage
32+
Error error
33+
}
34+
35+
// rpcResult handles any errors from an rpcResponse and unmarshals results into
36+
// a result.
37+
func rpcResult(res rpcResponse, r *result) error {
38+
if err := res.Error; err != nil {
39+
return err
40+
}
41+
42+
if err := json.Unmarshal(res.Result, &r); err != nil {
43+
return err
44+
}
45+
46+
// OVSDB server returned an error, return it.
47+
if r.Err != nil {
48+
return r.Err
49+
}
50+
51+
return nil
52+
}
53+
2954
// errPrefix is a prefix that occurs if an error is present in a JSON-RPC response.
3055
var errPrefix = []byte(`{"error":`)
3156

ovsdb/rpc.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414

1515
package ovsdb
1616

17+
import "context"
18+
1719
// ListDatabases returns the name of all databases known to the OVSDB server.
18-
func (c *Client) ListDatabases() ([]string, error) {
20+
func (c *Client) ListDatabases(ctx context.Context) ([]string, error) {
1921
var dbs []string
20-
if err := c.rpc("list_dbs", &dbs); err != nil {
22+
if err := c.rpc(ctx, "list_dbs", &dbs); err != nil {
2123
return nil, err
2224
}
2325

ovsdb/rpc_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package ovsdb_test
1616

1717
import (
18+
"context"
1819
"testing"
1920

2021
"github.com/digitalocean/go-openvswitch/ovsdb/internal/jsonrpc"
@@ -40,7 +41,7 @@ func TestClientListDatabases(t *testing.T) {
4041
})
4142
defer done()
4243

43-
dbs, err := c.ListDatabases()
44+
dbs, err := c.ListDatabases(context.Background())
4445
if err != nil {
4546
t.Fatalf("failed to list databases: %v", err)
4647
}

0 commit comments

Comments
 (0)