Skip to content

Commit 780e539

Browse files
committed
ovsdb: add internal receive loop for notifications and callbacks
1 parent 1aa633e commit 780e539

File tree

3 files changed

+240
-26
lines changed

3 files changed

+240
-26
lines changed

ovsdb/client.go

Lines changed: 135 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@ import (
1818
"bytes"
1919
"encoding/json"
2020
"fmt"
21+
"io"
2122
"log"
2223
"net"
24+
"strings"
25+
"sync"
26+
"sync/atomic"
2327

2428
"github.com/digitalocean/go-openvswitch/ovsdb/internal/jsonrpc"
2529
)
@@ -28,6 +32,20 @@ import (
2832
type Client struct {
2933
c *jsonrpc.Conn
3034
ll *log.Logger
35+
36+
// Incremented atomically when sending RPCs.
37+
rpcID *int64
38+
39+
// Callbacks for RPC responses.
40+
cbMu sync.RWMutex
41+
callbacks map[int]chan rpcResponse
42+
43+
wg *sync.WaitGroup
44+
}
45+
46+
type rpcResponse struct {
47+
Result json.RawMessage
48+
Error error
3149
}
3250

3351
// An OptionFunc is a function which can configure a Client.
@@ -60,14 +78,41 @@ func New(conn net.Conn, options ...OptionFunc) (*Client, error) {
6078
}
6179
}
6280

81+
// Set up RPC request IDs.
82+
var rpcID int64
83+
client.rpcID = &rpcID
84+
85+
// Set up the JSON-RPC connection.
6386
client.c = jsonrpc.NewConn(conn, client.ll)
6487

88+
// Set up callbacks.
89+
client.callbacks = make(map[int]chan rpcResponse)
90+
91+
// Start up any background routines.
92+
var wg sync.WaitGroup
93+
wg.Add(1)
94+
95+
// Handle all incoming RPC responses and notifications.
96+
go func() {
97+
defer wg.Done()
98+
client.listen()
99+
}()
100+
101+
client.wg = &wg
102+
65103
return client, nil
66104
}
67105

106+
// requestID returns the next available request ID for an RPC.
107+
func (c *Client) requestID() int {
108+
return int(atomic.AddInt64(c.rpcID, 1))
109+
}
110+
68111
// Close closes a Client's connection.
69112
func (c *Client) Close() error {
70-
return c.c.Close()
113+
err := c.c.Close()
114+
c.wg.Wait()
115+
return err
71116
}
72117

73118
// ListDatabases returns the name of all databases known to the OVSDB server.
@@ -82,6 +127,11 @@ func (c *Client) ListDatabases() ([]string, error) {
82127

83128
// rpc performs a single RPC request, and checks the response for errors.
84129
func (c *Client) rpc(method string, out interface{}, args ...interface{}) error {
130+
// Unmarshal results into empty struct if no out specified.
131+
if out == nil {
132+
out = &struct{}{}
133+
}
134+
85135
// Captures any OVSDB errors.
86136
r := result{
87137
Reply: out,
@@ -90,10 +140,24 @@ func (c *Client) rpc(method string, out interface{}, args ...interface{}) error
90140
req := jsonrpc.Request{
91141
Method: method,
92142
Params: args,
93-
// Let the client handle the request ID.
143+
ID: c.requestID(),
144+
}
145+
146+
// Add callback for this RPC ID to return results via channel.
147+
ch := make(chan rpcResponse, 0)
148+
c.addCallback(req.ID, ch)
149+
150+
if err := c.c.Send(req); err != nil {
151+
return err
152+
}
153+
154+
// Wait for callback to fire.
155+
res := <-ch
156+
if err := res.Error; err != nil {
157+
return err
94158
}
95159

96-
if err := c.c.Execute(req, &r); err != nil {
160+
if err := json.Unmarshal(res.Result, &r); err != nil {
97161
return err
98162
}
99163

@@ -105,6 +169,70 @@ func (c *Client) rpc(method string, out interface{}, args ...interface{}) error
105169
return nil
106170
}
107171

172+
// listen starts an RPC receive loop that can return RPC results to
173+
// clients via a callback.
174+
func (c *Client) listen() {
175+
for {
176+
res, err := c.c.Receive()
177+
if err != nil {
178+
// EOF or closed connection means time to stop serving.
179+
if err == io.EOF || strings.Contains(err.Error(), "use of closed network") {
180+
return
181+
}
182+
183+
// For any other connection errors, just keep trying.
184+
continue
185+
}
186+
187+
// TODO(mdlayher): deal with RPC notifications.
188+
189+
// Handle any JSON-RPC top-level errors.
190+
if err := res.Err(); err != nil {
191+
c.doCallback(*res.ID, rpcResponse{
192+
Error: err,
193+
})
194+
continue
195+
}
196+
197+
// Return RPC results via callback.
198+
c.doCallback(*res.ID, rpcResponse{
199+
Result: res.Result,
200+
})
201+
}
202+
}
203+
204+
// addCallback registers a callback for an RPC response for the specified ID,
205+
// and accepts a channel to return the results on.
206+
func (c *Client) addCallback(id int, ch chan rpcResponse) {
207+
c.cbMu.Lock()
208+
defer c.cbMu.Unlock()
209+
210+
if _, ok := c.callbacks[id]; ok {
211+
// This ID was already registered.
212+
panicf("OVSDB callback with ID %d already registered", id)
213+
}
214+
215+
c.callbacks[id] = ch
216+
}
217+
218+
// doCallback performs a callback for an RPC response and clears the
219+
// callback on completion.
220+
func (c *Client) doCallback(id int, res rpcResponse) {
221+
c.cbMu.Lock()
222+
defer c.cbMu.Unlock()
223+
224+
ch, ok := c.callbacks[id]
225+
if !ok {
226+
// Nobody is listening to this callback.
227+
panicf("OVSDB callback with ID %d has no listeners", id)
228+
}
229+
230+
// Return result, clean up channel, and remove this callback.
231+
ch <- res
232+
close(ch)
233+
delete(c.callbacks, id)
234+
}
235+
108236
// A result is used to unmarshal JSON-RPC results, and to check for any errors.
109237
type result struct {
110238
Reply interface{}
@@ -143,3 +271,7 @@ type Error struct {
143271
func (e *Error) Error() string {
144272
return fmt.Sprintf("%s: %s: %s", e.Err, e.Details, e.Syntax)
145273
}
274+
275+
func panicf(format string, a ...interface{}) {
276+
panic(fmt.Sprintf(format, a...))
277+
}

ovsdb/client_integration_test.go

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,62 @@
1515
package ovsdb_test
1616

1717
import (
18-
"os"
18+
"fmt"
19+
"sync"
1920
"testing"
2021

2122
"github.com/digitalocean/go-openvswitch/ovsdb"
2223
"github.com/google/go-cmp/cmp"
2324
)
2425

2526
func TestClientIntegration(t *testing.T) {
26-
// Assume the standard Linux location for the socket.
27-
const sock = "/var/run/openvswitch/db.sock"
28-
if _, err := os.Open(sock); err != nil {
29-
t.Skipf("could not access %q: %v", sock, err)
30-
}
31-
32-
c, err := ovsdb.Dial("unix", sock)
33-
if err != nil {
34-
t.Fatalf("failed to dial: %v", err)
35-
}
27+
c := dialOVSDB(t)
3628
defer c.Close()
3729

3830
t.Run("databases", func(t *testing.T) {
3931
testClientDatabases(t, c)
4032
})
4133
}
4234

35+
func TestClientIntegrationConcurrent(t *testing.T) {
36+
c := dialOVSDB(t)
37+
defer c.Close()
38+
39+
const n = 512
40+
41+
// Wait for all goroutines to start before performing RPCs,
42+
// wait for them all to exit before ending the test.
43+
var startWG, doneWG sync.WaitGroup
44+
startWG.Add(n)
45+
doneWG.Add(n)
46+
47+
// Block all goroutines until they're done spinning up.
48+
sigC := make(chan struct{}, 0)
49+
50+
for i := 0; i < n; i++ {
51+
go func(c *ovsdb.Client) {
52+
// Block goroutines until all are spun up.
53+
startWG.Done()
54+
<-sigC
55+
56+
for j := 0; j < 4; j++ {
57+
_, err := c.ListDatabases()
58+
if err != nil {
59+
panic(fmt.Sprintf("failed to query concurrently: %v", err))
60+
}
61+
}
62+
63+
doneWG.Done()
64+
}(c)
65+
}
66+
67+
// Unblock all goroutines once they're all spun up, and wait
68+
// for them all to finish reading.
69+
startWG.Wait()
70+
close(sigC)
71+
doneWG.Wait()
72+
}
73+
4374
func testClientDatabases(t *testing.T, c *ovsdb.Client) {
4475
dbs, err := c.ListDatabases()
4576
if err != nil {
@@ -52,3 +83,16 @@ func testClientDatabases(t *testing.T, c *ovsdb.Client) {
5283
t.Fatalf("unexpected databases (-want +got):\n%s", diff)
5384
}
5485
}
86+
87+
func dialOVSDB(t *testing.T) *ovsdb.Client {
88+
t.Helper()
89+
90+
// Assume the standard Linux location for the socket.
91+
const sock = "/var/run/openvswitch/db.sock"
92+
c, err := ovsdb.Dial("unix", sock)
93+
if err != nil {
94+
t.Skipf("could not access %q: %v", sock, err)
95+
}
96+
97+
return c
98+
}

ovsdb/client_test.go

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,45 @@
1515
package ovsdb_test
1616

1717
import (
18+
"encoding/json"
1819
"fmt"
20+
"log"
21+
"os"
1922
"testing"
2023

2124
"github.com/digitalocean/go-openvswitch/ovsdb"
2225
"github.com/digitalocean/go-openvswitch/ovsdb/internal/jsonrpc"
2326
"github.com/google/go-cmp/cmp"
2427
)
2528

26-
func TestClientError(t *testing.T) {
29+
func TestClientJSONRPCError(t *testing.T) {
2730
const str = "some error"
2831

29-
c, done := testClient(t, func(_ jsonrpc.Request) jsonrpc.Response {
32+
c, _, done := testClient(t, func(_ jsonrpc.Request) jsonrpc.Response {
3033
return jsonrpc.Response{
31-
ID: 1,
32-
Result: &ovsdb.Error{
34+
ID: intPtr(1),
35+
Error: str,
36+
}
37+
})
38+
defer done()
39+
40+
_, err := c.ListDatabases()
41+
if err == nil {
42+
t.Fatal("expected an error, but none occurred")
43+
}
44+
}
45+
46+
func TestClientOVSDBError(t *testing.T) {
47+
const str = "some error"
48+
49+
c, _, done := testClient(t, func(_ jsonrpc.Request) jsonrpc.Response {
50+
return jsonrpc.Response{
51+
ID: intPtr(1),
52+
Result: mustMarshalJSON(t, &ovsdb.Error{
3353
Err: str,
3454
Details: "malformed",
3555
Syntax: "{}",
36-
},
56+
}),
3757
}
3858
})
3959
defer done()
@@ -56,7 +76,7 @@ func TestClientError(t *testing.T) {
5676
func TestClientListDatabases(t *testing.T) {
5777
want := []string{"Open_vSwitch", "test"}
5878

59-
c, done := testClient(t, func(req jsonrpc.Request) jsonrpc.Response {
79+
c, _, done := testClient(t, func(req jsonrpc.Request) jsonrpc.Response {
6080
if diff := cmp.Diff("list_dbs", req.Method); diff != "" {
6181
panicf("unexpected RPC method (-want +got):\n%s", diff)
6282
}
@@ -66,8 +86,8 @@ func TestClientListDatabases(t *testing.T) {
6686
}
6787

6888
return jsonrpc.Response{
69-
ID: 1,
70-
Result: want,
89+
ID: intPtr(1),
90+
Result: mustMarshalJSON(t, want),
7191
}
7292
})
7393
defer done()
@@ -82,17 +102,35 @@ func TestClientListDatabases(t *testing.T) {
82102
}
83103
}
84104

85-
func testClient(t *testing.T, fn jsonrpc.TestFunc) (*ovsdb.Client, func()) {
105+
func testClient(t *testing.T, fn jsonrpc.TestFunc) (*ovsdb.Client, chan<- *jsonrpc.Response, func()) {
86106
t.Helper()
87107

88-
conn, done := jsonrpc.TestNetConn(t, fn)
108+
conn, notifC, done := jsonrpc.TestNetConn(t, fn)
89109

90-
c, err := ovsdb.New(conn)
110+
c, err := ovsdb.New(conn, ovsdb.Debug(log.New(os.Stderr, "", 0)))
91111
if err != nil {
92112
t.Fatalf("failed to dial: %v", err)
93113
}
94114

95-
return c, done
115+
return c, notifC, func() {
116+
_ = c.Close()
117+
done()
118+
}
119+
}
120+
121+
func mustMarshalJSON(t *testing.T, v interface{}) []byte {
122+
t.Helper()
123+
124+
b, err := json.Marshal(v)
125+
if err != nil {
126+
t.Fatalf("failed to marshal JSON: %v", err)
127+
}
128+
129+
return b
130+
}
131+
132+
func intPtr(i int) *int {
133+
return &i
96134
}
97135

98136
func panicf(format string, a ...interface{}) {

0 commit comments

Comments
 (0)