Skip to content

Commit 5d68ffd

Browse files
committed
fix: finish reading handshake on lazyConn close
1 parent 990321d commit 5d68ffd

File tree

2 files changed

+100
-3
lines changed

2 files changed

+100
-3
lines changed

lazyClient.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,7 @@ func (l *lazyClientConn[T]) Write(b []byte) (int, error) {
134134
return l.con.Write(b)
135135
}
136136

137-
// Close closes the underlying io.ReadWriteCloser
138-
//
139-
// This does not flush anything.
137+
// Close closes the underlying io.ReadWriteCloser after finishing the handshake.
140138
func (l *lazyClientConn[T]) Close() error {
141139
// As the client, we flush the handshake on close to cover an
142140
// interesting edge-case where the server only speaks a single protocol
@@ -147,6 +145,22 @@ func (l *lazyClientConn[T]) Close() error {
147145
// closed the stream for reading. I mean, we're the initiator so that's
148146
// strange... but it's still allowed
149147
_ = l.Flush()
148+
149+
// Finish reading the handshake before we close the connection/stream. This
150+
// is necessary so that the other side can finish sending its response to our
151+
// multistream header before we tell it we are done reading.
152+
//
153+
// Example:
154+
// We open a QUIC stream, write the protocol `/a`, send 1 byte of application
155+
// data, and immediately close.
156+
//
157+
// This can result in a single packet that contains the stream data along
158+
// with a STOP_SENDING frame. The other side may be unable to negotiate
159+
// multistream select since it can't write to the stream anymore and may
160+
// drop the stream.
161+
//
162+
// Note: We currently handle this case in Go(https://github.com/multiformats/go-multistream/pull/87), but rust-libp2p does not.
163+
l.rhandshakeOnce.Do(l.doReadHandshake)
150164
return l.con.Close()
151165
}
152166

multistream_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net"
1010
"sort"
1111
"strings"
12+
"sync/atomic"
1213
"testing"
1314
"time"
1415
)
@@ -801,6 +802,88 @@ func TestNegotiatePeerSendsAndCloses(t *testing.T) {
801802
}
802803
}
803804

805+
func newPair() (*chanPipe, *chanPipe) {
806+
a := make(chan []byte, 16)
807+
b := make(chan []byte, 16)
808+
aReadClosed := atomic.Bool{}
809+
bReadClosed := atomic.Bool{}
810+
return &chanPipe{r: a, w: b, myReadClosed: &aReadClosed, peerReadClosed: &bReadClosed},
811+
&chanPipe{r: b, w: a, myReadClosed: &bReadClosed, peerReadClosed: &aReadClosed}
812+
}
813+
814+
type chanPipe struct {
815+
r, w chan []byte
816+
buf bytes.Buffer
817+
818+
myReadClosed *atomic.Bool
819+
peerReadClosed *atomic.Bool
820+
}
821+
822+
func (cp *chanPipe) Read(b []byte) (int, error) {
823+
if cp.buf.Len() > 0 {
824+
return cp.buf.Read(b)
825+
}
826+
827+
buf, ok := <-cp.r
828+
if !ok {
829+
return 0, io.EOF
830+
}
831+
832+
cp.buf.Write(buf)
833+
return cp.buf.Read(b)
834+
}
835+
836+
func (cp *chanPipe) Write(b []byte) (int, error) {
837+
if cp.peerReadClosed.Load() {
838+
panic("peer's read side closed")
839+
}
840+
copied := make([]byte, len(b))
841+
copy(copied, b)
842+
cp.w <- copied
843+
return len(b), nil
844+
}
845+
846+
func (cp *chanPipe) Close() error {
847+
cp.myReadClosed.Store(true)
848+
close(cp.w)
849+
return nil
850+
}
851+
852+
func TestReadHandshakeOnClose(t *testing.T) {
853+
rw1, rw2 := newPair()
854+
855+
clientDone := make(chan struct{})
856+
go func() {
857+
l1 := NewMSSelect(rw1, "a")
858+
_, _ = l1.Write([]byte("hello"))
859+
_ = l1.Close()
860+
close(clientDone)
861+
}()
862+
863+
serverDone := make(chan error)
864+
865+
server := NewMultistreamMuxer[string]()
866+
server.AddHandler("a", func(protocol string, rwc io.ReadWriteCloser) error {
867+
_, err := io.ReadAll(rwc)
868+
rwc.Close()
869+
serverDone <- err
870+
return nil
871+
})
872+
873+
p, h, err := server.Negotiate(rw2)
874+
if err != nil {
875+
t.Fatal(err)
876+
}
877+
878+
go h(p, rw2)
879+
880+
err = <-serverDone
881+
if err != nil {
882+
t.Fatal(err)
883+
}
884+
<-clientDone
885+
}
886+
804887
type rwc struct {
805888
*strings.Reader
806889
}

0 commit comments

Comments
 (0)