diff --git a/device/device.go b/device/device.go index 8e5572432..7deaa7c5e 100644 --- a/device/device.go +++ b/device/device.go @@ -392,11 +392,10 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.keypairs.RLock() - sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) - peer.keypairs.RUnlock() - if sendKeepalive { - peer.SendKeepalive() + if current := peer.keypairs.Current(); current != nil { + if current.created.Add(RejectAfterTime).Before(time.Now()) { + peer.SendKeepalive() + } } } device.peers.RUnlock() diff --git a/device/keypair.go b/device/keypair.go index e3540d7d4..eea75c060 100644 --- a/device/keypair.go +++ b/device/keypair.go @@ -33,15 +33,15 @@ type Keypair struct { } type Keypairs struct { - sync.RWMutex + sync.Mutex current *Keypair previous *Keypair - next atomic.Pointer[Keypair] + next *Keypair } func (kp *Keypairs) Current() *Keypair { - kp.RLock() - defer kp.RUnlock() + kp.Lock() + defer kp.Unlock() return kp.current } diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 117e960a8..405be7c68 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -581,12 +581,12 @@ func (peer *Peer) BeginSymmetricSession() error { defer keypairs.Unlock() previous := keypairs.previous - next := keypairs.next.Load() + next := keypairs.next current := keypairs.current if isInitiator { if next != nil { - keypairs.next.Store(nil) + keypairs.next = nil keypairs.previous = next device.DeleteKeypair(current) } else { @@ -595,7 +595,7 @@ func (peer *Peer) BeginSymmetricSession() error { device.DeleteKeypair(previous) keypairs.current = keypair } else { - keypairs.next.Store(keypair) + keypairs.next = keypair device.DeleteKeypair(next) keypairs.previous = nil device.DeleteKeypair(previous) @@ -607,18 +607,15 @@ func (peer *Peer) BeginSymmetricSession() error { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { keypairs := &peer.keypairs - if keypairs.next.Load() != receivedKeypair { - return false - } keypairs.Lock() defer keypairs.Unlock() - if keypairs.next.Load() != receivedKeypair { + if keypairs.next != receivedKeypair { return false } old := keypairs.previous keypairs.previous = keypairs.current peer.device.DeleteKeypair(old) - keypairs.current = keypairs.next.Load() - keypairs.next.Store(nil) + keypairs.current = keypairs.next + keypairs.next = nil return true } diff --git a/device/noise_test.go b/device/noise_test.go index 587d1e55d..278b6451d 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -148,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) { t.Fatal("failed to derive keypair for peer 2", err) } - key1 := peer1.keypairs.next.Load() + key1 := peer1.keypairs.next key2 := peer2.keypairs.current // encrypting / decryption test diff --git a/device/peer.go b/device/peer.go index 8266dacc0..904dd083c 100644 --- a/device/peer.go +++ b/device/peer.go @@ -202,10 +202,10 @@ func (peer *Peer) ZeroAndFlushAll() { keypairs.Lock() device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.current) - device.DeleteKeypair(keypairs.next.Load()) + device.DeleteKeypair(keypairs.next) keypairs.previous = nil keypairs.current = nil - keypairs.next.Store(nil) + keypairs.next = nil keypairs.Unlock() // clear handshake state @@ -232,8 +232,8 @@ func (peer *Peer) ExpireCurrentKeypairs() { if keypairs.current != nil { keypairs.current.sendNonce.Store(RejectAfterMessages) } - if next := keypairs.next.Load(); next != nil { - next.sendNonce.Store(RejectAfterMessages) + if keypairs.next != nil { + keypairs.next.sendNonce.Store(RejectAfterMessages) } keypairs.Unlock() }