Skip to content

Commit 3a13b0d

Browse files
authored
p2p/protocols, p2p/testing; conditional propagation of context (ethersphere#1648)
* p2p/protocols, p2p/testing; conditional propagagation of context * p2p/protocols: NewPeer should allow empty spec for testing
1 parent 265f0fd commit 3a13b0d

File tree

4 files changed

+132
-117
lines changed

4 files changed

+132
-117
lines changed

p2p/protocols/context.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package protocols
2+
3+
import (
4+
"bufio"
5+
"bytes"
6+
"context"
7+
"io/ioutil"
8+
9+
"github.com/ethereum/go-ethereum/p2p"
10+
"github.com/ethereum/go-ethereum/rlp"
11+
"github.com/ethersphere/swarm/spancontext"
12+
opentracing "github.com/opentracing/opentracing-go"
13+
)
14+
15+
// msgWithContext is used to propagate marshalled context alongside message payloads
16+
type msgWithContext struct {
17+
Context []byte
18+
Msg []byte
19+
}
20+
21+
func encodeWithContext(ctx context.Context, msg interface{}) (interface{}, int, error) {
22+
var b bytes.Buffer
23+
writer := bufio.NewWriter(&b)
24+
tracer := opentracing.GlobalTracer()
25+
sctx := spancontext.FromContext(ctx)
26+
if sctx != nil {
27+
err := tracer.Inject(
28+
sctx,
29+
opentracing.Binary,
30+
writer)
31+
if err != nil {
32+
return nil, 0, err
33+
}
34+
}
35+
writer.Flush()
36+
msgBytes, err := rlp.EncodeToBytes(msg)
37+
if err != nil {
38+
return nil, 0, err
39+
}
40+
41+
return &msgWithContext{
42+
Context: b.Bytes(),
43+
Msg: msgBytes,
44+
}, len(msgBytes), nil
45+
}
46+
47+
func decodeWithContext(msg p2p.Msg) (context.Context, []byte, error) {
48+
var wmsg msgWithContext
49+
err := msg.Decode(&wmsg)
50+
if err != nil {
51+
return nil, nil, err
52+
}
53+
54+
ctx := context.Background()
55+
56+
if len(wmsg.Context) == 0 {
57+
return ctx, wmsg.Msg, nil
58+
}
59+
60+
tracer := opentracing.GlobalTracer()
61+
sctx, err := tracer.Extract(opentracing.Binary, bytes.NewReader(wmsg.Context))
62+
if err != nil {
63+
return nil, nil, err
64+
}
65+
ctx = spancontext.WithContext(ctx, sctx)
66+
return ctx, wmsg.Msg, nil
67+
}
68+
69+
func encodeWithoutContext(ctx context.Context, msg interface{}) (interface{}, int, error) {
70+
return msg, 0, nil
71+
}
72+
73+
func decodeWithoutContext(msg p2p.Msg) (context.Context, []byte, error) {
74+
b, err := ioutil.ReadAll(msg.Payload)
75+
if err != nil {
76+
return nil, nil, err
77+
}
78+
return context.Background(), b, nil
79+
}

p2p/protocols/protocol.go

Lines changed: 40 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ devp2p subprotocols by abstracting away code standardly shared by protocols.
2929
package protocols
3030

3131
import (
32-
"bufio"
33-
"bytes"
3432
"context"
3533
"fmt"
3634
"io"
@@ -42,9 +40,7 @@ import (
4240
"github.com/ethereum/go-ethereum/metrics"
4341
"github.com/ethereum/go-ethereum/p2p"
4442
"github.com/ethereum/go-ethereum/rlp"
45-
"github.com/ethersphere/swarm/spancontext"
4643
"github.com/ethersphere/swarm/tracing"
47-
opentracing "github.com/opentracing/opentracing-go"
4844
)
4945

5046
// error codes used by this protocol scheme
@@ -115,13 +111,6 @@ func errorf(code int, format string, params ...interface{}) *Error {
115111
}
116112
}
117113

118-
// WrappedMsg is used to propagate marshalled context alongside message payloads
119-
type WrappedMsg struct {
120-
Context []byte
121-
Size uint32
122-
Payload []byte
123-
}
124-
125114
//For accounting, the design is to allow the Spec to describe which and how its messages are priced
126115
//To access this functionality, we provide a Hook interface which will call accounting methods
127116
//NOTE: there could be more such (horizontal) hooks in the future
@@ -157,6 +146,10 @@ type Spec struct {
157146
initOnce sync.Once
158147
codes map[reflect.Type]uint64
159148
types map[uint64]reflect.Type
149+
150+
// if the protocol does not allow extending the p2p msg to propagate context
151+
// even if context not disabled, context will propagate only tracing is enabled
152+
DisableContext bool
160153
}
161154

162155
func (s *Spec) init() {
@@ -208,17 +201,27 @@ type Peer struct {
208201
*p2p.Peer // the p2p.Peer object representing the remote
209202
rw p2p.MsgReadWriter // p2p.MsgReadWriter to send messages to and read messages from
210203
spec *Spec
204+
encode func(context.Context, interface{}) (interface{}, int, error)
205+
decode func(p2p.Msg) (context.Context, []byte, error)
211206
}
212207

213208
// NewPeer constructs a new peer
214209
// this constructor is called by the p2p.Protocol#Run function
215210
// the first two arguments are the arguments passed to p2p.Protocol.Run function
216211
// the third argument is the Spec describing the protocol
217-
func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer {
212+
func NewPeer(peer *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer {
213+
encode := encodeWithContext
214+
decode := decodeWithContext
215+
if spec == nil || spec.DisableContext || !tracing.Enabled {
216+
encode = encodeWithoutContext
217+
decode = decodeWithoutContext
218+
}
218219
return &Peer{
219-
Peer: p,
220-
rw: rw,
221-
spec: spec,
220+
Peer: peer,
221+
rw: rw,
222+
spec: spec,
223+
encode: encode,
224+
decode: decode,
222225
}
223226
}
224227

@@ -234,7 +237,6 @@ func (p *Peer) Run(handler func(ctx context.Context, msg interface{}) error) err
234237
metrics.GetOrRegisterCounter("peer.handleincoming.error", nil).Inc(1)
235238
log.Error("peer.handleIncoming", "err", err)
236239
}
237-
238240
return err
239241
}
240242
}
@@ -256,51 +258,32 @@ func (p *Peer) Send(ctx context.Context, msg interface{}) error {
256258
metrics.GetOrRegisterCounter("peer.send", nil).Inc(1)
257259
metrics.GetOrRegisterCounter(fmt.Sprintf("peer.send.%T", msg), nil).Inc(1)
258260

259-
var b bytes.Buffer
260-
if tracing.Enabled {
261-
writer := bufio.NewWriter(&b)
262-
263-
tracer := opentracing.GlobalTracer()
264-
265-
sctx := spancontext.FromContext(ctx)
266-
267-
if sctx != nil {
268-
err := tracer.Inject(
269-
sctx,
270-
opentracing.Binary,
271-
writer)
272-
if err != nil {
273-
return err
274-
}
275-
}
276-
277-
writer.Flush()
261+
code, found := p.spec.GetCode(msg)
262+
if !found {
263+
return errorf(ErrInvalidMsgType, "%v", code)
278264
}
279265

280-
r, err := rlp.EncodeToBytes(msg)
266+
wmsg, size, err := p.encode(ctx, msg)
281267
if err != nil {
282268
return err
283269
}
284270

285-
wmsg := WrappedMsg{
286-
Context: b.Bytes(),
287-
Size: uint32(len(r)),
288-
Payload: r,
271+
// if size is not set by the wrapper, need to serialise
272+
if size == 0 {
273+
r, err := rlp.EncodeToBytes(msg)
274+
if err != nil {
275+
return err
276+
}
277+
size = len(r)
289278
}
290-
291-
//if the accounting hook is set, call it
279+
// if the accounting hook is set, call it
292280
if p.spec.Hook != nil {
293-
err := p.spec.Hook.Send(p, wmsg.Size, msg)
281+
err = p.spec.Hook.Send(p, uint32(size), msg)
294282
if err != nil {
295-
p.Drop()
296283
return err
297284
}
298285
}
299286

300-
code, found := p.spec.GetCode(msg)
301-
if !found {
302-
return errorf(ErrInvalidMsgType, "%v", code)
303-
}
304287
return p2p.Send(p.rw, code, wmsg)
305288
}
306289

@@ -324,44 +307,23 @@ func (p *Peer) handleIncoming(handle func(ctx context.Context, msg interface{})
324307
return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize)
325308
}
326309

327-
// unmarshal wrapped msg, which might contain context
328-
var wmsg WrappedMsg
329-
err = msg.Decode(&wmsg)
330-
if err != nil {
331-
log.Error(err.Error())
332-
return err
333-
}
334-
335-
ctx := context.Background()
336-
337-
// if tracing is enabled and the context coming within the request is
338-
// not empty, try to unmarshal it
339-
if tracing.Enabled && len(wmsg.Context) > 0 {
340-
var sctx opentracing.SpanContext
341-
342-
tracer := opentracing.GlobalTracer()
343-
sctx, err = tracer.Extract(
344-
opentracing.Binary,
345-
bytes.NewReader(wmsg.Context))
346-
if err != nil {
347-
log.Error(err.Error())
348-
return err
349-
}
350-
351-
ctx = spancontext.WithContext(ctx, sctx)
352-
}
353-
354310
val, ok := p.spec.NewMsg(msg.Code)
355311
if !ok {
356312
return errorf(ErrInvalidMsgCode, "%v", msg.Code)
357313
}
358-
if err := rlp.DecodeBytes(wmsg.Payload, val); err != nil {
314+
315+
ctx, msgBytes, err := p.decode(msg)
316+
if err != nil {
317+
return errorf(ErrDecode, "%v err=%v", msg.Code, err)
318+
}
319+
320+
if err := rlp.DecodeBytes(msgBytes, val); err != nil {
359321
return errorf(ErrDecode, "<= %v: %v", msg, err)
360322
}
361323

362-
//if the accounting hook is set, call it
324+
// if the accounting hook is set, call it
363325
if p.spec.Hook != nil {
364-
err := p.spec.Hook.Receive(p, wmsg.Size, val)
326+
err := p.spec.Hook.Receive(p, uint32(len(msgBytes)), val)
365327
if err != nil {
366328
return err
367329
}

p2p/protocols/protocol_test.go

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,7 @@ func TestProtocolHook(t *testing.T) {
249249
runFunc := func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
250250
peer := NewPeer(p, rw, spec)
251251
ctx := context.TODO()
252-
err := peer.Send(ctx, &dummyMsg{
253-
Content: "handshake"})
254-
252+
err := peer.Send(ctx, &dummyMsg{Content: "handshake"})
255253
if err != nil {
256254
t.Fatal(err)
257255
}
@@ -281,6 +279,7 @@ func TestProtocolHook(t *testing.T) {
281279
if err != nil {
282280
t.Fatal(err)
283281
}
282+
284283
testHook.mu.Lock()
285284
if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "handshake" {
286285
t.Fatal("Expected msg to be set, but it is not")
@@ -291,8 +290,8 @@ func TestProtocolHook(t *testing.T) {
291290
if testHook.peer == nil {
292291
t.Fatal("Expected peer to be set, is nil")
293292
}
294-
if peerId := testHook.peer.ID(); peerId != tester.Nodes[0].ID() && peerId != tester.Nodes[1].ID() {
295-
t.Fatalf("Expected peer ID to be set correctly, but it is not (got %v, exp %v or %v", peerId, tester.Nodes[0].ID(), tester.Nodes[1].ID())
293+
if peerID := testHook.peer.ID(); peerID != tester.Nodes[0].ID() && peerID != tester.Nodes[1].ID() {
294+
t.Fatalf("Expected peer ID to be set correctly, but it is not (got %v, exp %v or %v", peerID, tester.Nodes[0].ID(), tester.Nodes[1].ID())
296295
}
297296
if testHook.size != 11 { //11 is the length of the encoded message
298297
t.Fatalf("Expected size to be %d, but it is %d ", 1, testHook.size)
@@ -309,11 +308,10 @@ func TestProtocolHook(t *testing.T) {
309308
},
310309
})
311310

312-
<-testHook.waitC
313-
314311
if err != nil {
315312
t.Fatal(err)
316313
}
314+
<-testHook.waitC
317315

318316
testHook.mu.Lock()
319317
if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "response" {
@@ -600,24 +598,15 @@ func (d *dummyRW) WriteMsg(msg p2p.Msg) error {
600598
}
601599

602600
func (d *dummyRW) ReadMsg() (p2p.Msg, error) {
603-
enc := bytes.NewReader(d.getDummyMsg())
601+
r, err := rlp.EncodeToBytes(d.msg)
602+
if err != nil {
603+
return p2p.Msg{}, err
604+
}
605+
enc := bytes.NewReader(r)
604606
return p2p.Msg{
605607
Code: d.code,
606608
Size: d.size,
607609
Payload: enc,
608610
ReceivedAt: time.Now(),
609611
}, nil
610612
}
611-
612-
func (d *dummyRW) getDummyMsg() []byte {
613-
r, _ := rlp.EncodeToBytes(d.msg)
614-
var b bytes.Buffer
615-
wmsg := WrappedMsg{
616-
Context: b.Bytes(),
617-
Size: uint32(len(r)),
618-
Payload: r,
619-
}
620-
rr, _ := rlp.EncodeToBytes(wmsg)
621-
d.size = uint32(len(rr))
622-
return rr
623-
}

0 commit comments

Comments
 (0)