Skip to content

Commit 810f1e3

Browse files
committed
go/p2p/protocol: Replace global registry with parametric
Avoid global state to enable writting unit tests with multiple peers in the same process.
1 parent c39388d commit 810f1e3

File tree

3 files changed

+22
-22
lines changed

3 files changed

+22
-22
lines changed

go/p2p/p2p.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ type p2p struct {
8383
registerAddresses []multiaddr.Multiaddr
8484
topics map[string]*topicHandler
8585

86+
protocolRegistry *protocol.Registry
87+
8688
logger *logging.Logger
8789
}
8890

@@ -281,7 +283,7 @@ func (p *p2p) Publish(_ context.Context, topic string, msg any) {
281283

282284
// Implements api.Service.
283285
func (p *p2p) RegisterHandler(topic string, handler api.Handler) {
284-
protocol.ValidateTopicID(topic)
286+
p.protocolRegistry.ValidateTopicID(topic)
285287

286288
p.Lock()
287289
defer p.Unlock()
@@ -336,7 +338,7 @@ func (p *p2p) PeerManager() api.PeerManager {
336338

337339
// Implements api.Service.
338340
func (p *p2p) RegisterProtocolServer(srv rpc.Server) {
339-
protocol.ValidateProtocolID(srv.Protocol())
341+
p.protocolRegistry.ValidateProtocolID(srv.Protocol())
340342

341343
p.host.SetStreamHandler(srv.Protocol(), srv.HandleStream)
342344

@@ -436,6 +438,7 @@ func New(identity *identity.Identity, chainContext string, store *persistent.Com
436438
pubsub: pubsub,
437439
registerAddresses: cfg.Addresses,
438440
topics: make(map[string]*topicHandler),
441+
protocolRegistry: protocol.NewRegistry(),
439442
logger: logger,
440443
}, nil
441444
}

go/p2p/protocol/protocol.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,33 +12,32 @@ import (
1212
"github.com/oasisprotocol/oasis-core/go/p2p/api"
1313
)
1414

15-
type protocolRegistry struct {
15+
// Registry is responsible for ensuring unique protocol ids.
16+
type Registry struct {
1617
mu sync.Mutex
1718
protocols map[core.ProtocolID]struct{}
1819
}
1920

20-
func newProtocolRegistry() *protocolRegistry {
21-
return &protocolRegistry{
21+
func NewRegistry() *Registry {
22+
return &Registry{
2223
protocols: make(map[core.ProtocolID]struct{}),
2324
}
2425
}
2526

26-
var registry = newProtocolRegistry()
27-
2827
// ValidateProtocolID panics if the protocol id is not unique.
29-
func ValidateProtocolID(p core.ProtocolID) {
30-
registry.mu.Lock()
31-
defer registry.mu.Unlock()
28+
func (r *Registry) ValidateProtocolID(p core.ProtocolID) {
29+
r.mu.Lock()
30+
defer r.mu.Unlock()
3231

33-
if _, ok := registry.protocols[p]; ok {
32+
if _, ok := r.protocols[p]; ok {
3433
panic(fmt.Sprintf("p2p/protocol: protocol or topic with name '%s' already exists", p))
3534
}
36-
registry.protocols[p] = struct{}{}
35+
r.protocols[p] = struct{}{}
3736
}
3837

3938
// ValidateTopicID panics if the topic id is not unique.
40-
func ValidateTopicID(topic string) {
41-
ValidateProtocolID(core.ProtocolID(topic))
39+
func (r *Registry) ValidateTopicID(topic string) {
40+
r.ValidateProtocolID(core.ProtocolID(topic))
4241
}
4342

4443
// NewProtocolID generates a protocol identifier for a consensus P2P protocol.

go/p2p/protocol/protocol_test.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,10 @@ func TestProtocolID(t *testing.T) {
5656
require.Equal(expected, NewTopicIDForRuntime(chainContext, runtimeID, kind, version))
5757
})
5858

59-
registry = newProtocolRegistry()
60-
6159
t.Run("ValidateProtocolID", func(_ *testing.T) {
62-
ValidateProtocolID("protocol-1")
63-
ValidateProtocolID("protocol-2")
60+
r := NewRegistry()
61+
r.ValidateProtocolID("protocol-1")
62+
r.ValidateProtocolID("protocol-2")
6463
})
6564

6665
t.Run("ValidateProtocolID panics", func(t *testing.T) {
@@ -69,9 +68,8 @@ func TestProtocolID(t *testing.T) {
6968
t.Errorf("validate protocol id should fail")
7069
}
7170
}()
72-
ValidateProtocolID("protocol")
73-
ValidateProtocolID("protocol")
71+
r := NewRegistry()
72+
r.ValidateProtocolID("protocol")
73+
r.ValidateProtocolID("protocol")
7474
})
75-
76-
registry = newProtocolRegistry()
7775
}

0 commit comments

Comments
 (0)