Skip to content

Commit d2176f4

Browse files
thinkAfCodfjl
andauthored
p2p/discover: pass node instead of node ID to TALKREQ handler (#31075)
This is for the implementation of Portal Network in the Shisui client. Their handler needs access to the node object in order to send further calls to the requesting node. This is a breaking API change but it should be fine, since there are basically no known users of TALKREQ outside of Portal network. --------- Signed-off-by: thinkAfCod <[email protected]> Co-authored-by: Felix Lange <[email protected]>
1 parent 3e4fbce commit d2176f4

File tree

6 files changed

+45
-16
lines changed

6 files changed

+45
-16
lines changed

p2p/discover/v5_talk.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ const talkHandlerLaunchTimeout = 400 * time.Millisecond
3939
// Note that talk handlers are expected to come up with a response very quickly, within at
4040
// most 200ms or so. If the handler takes longer than that, the remote end may time out
4141
// and wont receive the response.
42-
type TalkRequestHandler func(enode.ID, *net.UDPAddr, []byte) []byte
42+
type TalkRequestHandler func(*enode.Node, *net.UDPAddr, []byte) []byte
4343

4444
type talkSystem struct {
4545
transport *UDPv5
@@ -72,13 +72,19 @@ func (t *talkSystem) register(protocol string, handler TalkRequestHandler) {
7272

7373
// handleRequest handles a talk request.
7474
func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire.TalkRequest) {
75+
n := t.transport.codec.SessionNode(id, addr.String())
76+
if n == nil {
77+
// The node must be contained in the session here, since we wouldn't have
78+
// received the request otherwise.
79+
panic("missing node in session")
80+
}
7581
t.mutex.Lock()
7682
handler, ok := t.handlers[req.Protocol]
7783
t.mutex.Unlock()
7884

7985
if !ok {
8086
resp := &v5wire.TalkResponse{ReqID: req.ReqID}
81-
t.transport.sendResponse(id, addr, resp)
87+
t.transport.sendResponse(n.ID(), addr, resp)
8288
return
8389
}
8490

@@ -90,9 +96,9 @@ func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire
9096
go func() {
9197
defer func() { t.slots <- struct{}{} }()
9298
udpAddr := &net.UDPAddr{IP: addr.Addr().AsSlice(), Port: int(addr.Port())}
93-
respMessage := handler(id, udpAddr, req.Message)
99+
respMessage := handler(n, udpAddr, req.Message)
94100
resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage}
95-
t.transport.sendFromAnotherThread(id, addr, resp)
101+
t.transport.sendFromAnotherThread(n.ID(), addr, resp)
96102
}()
97103
case <-timeout.C:
98104
// Couldn't get it in time, drop the request.

p2p/discover/v5_udp.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ type codecV5 interface {
6464
// CurrentChallenge returns the most recent WHOAREYOU challenge that was encoded to given node.
6565
// This will return a non-nil value if there is an active handshake attempt with the node, and nil otherwise.
6666
CurrentChallenge(id enode.ID, addr string) *v5wire.Whoareyou
67+
68+
// SessionNode returns a node that has completed the handshake.
69+
SessionNode(id enode.ID, addr string) *enode.Node
6770
}
6871

6972
// UDPv5 is the implementation of protocol version 5.

p2p/discover/v5_udp_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ func TestUDPv5_talkHandling(t *testing.T) {
492492
defer test.close()
493493

494494
var recvMessage []byte
495-
test.udp.RegisterTalkHandler("test", func(id enode.ID, addr *net.UDPAddr, message []byte) []byte {
495+
test.udp.RegisterTalkHandler("test", func(n *enode.Node, addr *net.UDPAddr, message []byte) []byte {
496496
recvMessage = message
497497
return []byte("test response")
498498
})
@@ -811,6 +811,10 @@ func (c *testCodec) Decode(input []byte, addr string) (enode.ID, *enode.Node, v5
811811
return frame.NodeID, nil, p, nil
812812
}
813813

814+
func (c *testCodec) SessionNode(id enode.ID, addr string) *enode.Node {
815+
return c.test.nodesByID[id].Node()
816+
}
817+
814818
func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p v5wire.Packet, err error) {
815819
if err = rlp.DecodeBytes(input, &frame); err != nil {
816820
return frame, nil, fmt.Errorf("invalid frame: %v", err)

p2p/discover/v5wire/encoding.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ func (c *Codec) encodeHandshakeHeader(toID enode.ID, addr string, challenge *Who
359359
}
360360

361361
// TODO: this should happen when the first authenticated message is received
362-
c.sc.storeNewSession(toID, addr, session)
362+
c.sc.storeNewSession(toID, addr, session, challenge.Node)
363363

364364
// Encode the auth header.
365365
var (
@@ -534,7 +534,7 @@ func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData
534534
}
535535

536536
// Handshake OK, drop the challenge and store the new session keys.
537-
c.sc.storeNewSession(auth.h.SrcID, fromAddr, session)
537+
c.sc.storeNewSession(auth.h.SrcID, fromAddr, session, node)
538538
c.sc.deleteHandshake(auth.h.SrcID, fromAddr)
539539
return node, msg, nil
540540
}
@@ -656,6 +656,10 @@ func (c *Codec) decryptMessage(input, nonce, headerData, readKey []byte) (Packet
656656
return DecodeMessage(msgdata[0], msgdata[1:])
657657
}
658658

659+
func (c *Codec) SessionNode(id enode.ID, addr string) *enode.Node {
660+
return c.sc.readNode(id, addr)
661+
}
662+
659663
// checkValid performs some basic validity checks on the header.
660664
// The packetLen here is the length remaining after the static header.
661665
func (h *StaticHeader) checkValid(packetLen int, protocolID [6]byte) error {

p2p/discover/v5wire/encoding_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func TestHandshake_rekey(t *testing.T) {
166166
readKey: []byte("BBBBBBBBBBBBBBBB"),
167167
writeKey: []byte("AAAAAAAAAAAAAAAA"),
168168
}
169-
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session)
169+
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n())
170170

171171
// A -> B FINDNODE (encrypted with zero keys)
172172
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{})
@@ -209,8 +209,8 @@ func TestHandshake_rekey2(t *testing.T) {
209209
readKey: []byte("CCCCCCCCCCCCCCCC"),
210210
writeKey: []byte("DDDDDDDDDDDDDDDD"),
211211
}
212-
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA)
213-
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB)
212+
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA, net.nodeB.n())
213+
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB, net.nodeA.n())
214214

215215
// A -> B FINDNODE encrypted with initKeysA
216216
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{Distances: []uint{3}})
@@ -362,8 +362,8 @@ func TestTestVectorsV5(t *testing.T) {
362362
ENRSeq: 2,
363363
},
364364
prep: func(net *handshakeTest) {
365-
net.nodeA.c.sc.storeNewSession(idB, addr, session)
366-
net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped())
365+
net.nodeA.c.sc.storeNewSession(idB, addr, session, net.nodeB.n())
366+
net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped(), net.nodeA.n())
367367
},
368368
},
369369
{
@@ -499,8 +499,8 @@ func BenchmarkV5_DecodePing(b *testing.B) {
499499
readKey: []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17},
500500
writeKey: []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134},
501501
}
502-
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session)
503-
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped())
502+
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n())
503+
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped(), net.nodeA.n())
504504
addrB := net.nodeA.addr()
505505
ping := &Ping{ReqID: []byte("reqid"), ENRSeq: 5}
506506
enc, _, err := net.nodeA.c.Encode(net.nodeB.id(), addrB, ping, nil)

p2p/discover/v5wire/session.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,12 @@ type session struct {
5454
writeKey []byte
5555
readKey []byte
5656
nonceCounter uint32
57+
node *enode.Node
5758
}
5859

5960
// keysFlipped returns a copy of s with the read and write keys flipped.
6061
func (s *session) keysFlipped() *session {
61-
return &session{s.readKey, s.writeKey, s.nonceCounter}
62+
return &session{s.readKey, s.writeKey, s.nonceCounter, s.node}
6263
}
6364

6465
func NewSessionCache(maxItems int, clock mclock.Clock) *SessionCache {
@@ -103,8 +104,19 @@ func (sc *SessionCache) readKey(id enode.ID, addr string) []byte {
103104
return nil
104105
}
105106

107+
func (sc *SessionCache) readNode(id enode.ID, addr string) *enode.Node {
108+
if s := sc.session(id, addr); s != nil {
109+
return s.node
110+
}
111+
return nil
112+
}
113+
106114
// storeNewSession stores new encryption keys in the cache.
107-
func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session) {
115+
func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session, n *enode.Node) {
116+
if n == nil {
117+
panic("nil node in storeNewSession")
118+
}
119+
s.node = n
108120
sc.sessions.Add(sessionID{id, addr}, s)
109121
}
110122

0 commit comments

Comments
 (0)