|
9 | 9 | "net" |
10 | 10 | "sort" |
11 | 11 | "strings" |
| 12 | + "sync/atomic" |
12 | 13 | "testing" |
13 | 14 | "time" |
14 | 15 | ) |
@@ -801,6 +802,88 @@ func TestNegotiatePeerSendsAndCloses(t *testing.T) { |
801 | 802 | } |
802 | 803 | } |
803 | 804 |
|
| 805 | +func newPair() (*chanPipe, *chanPipe) { |
| 806 | + a := make(chan []byte, 16) |
| 807 | + b := make(chan []byte, 16) |
| 808 | + aReadClosed := atomic.Bool{} |
| 809 | + bReadClosed := atomic.Bool{} |
| 810 | + return &chanPipe{r: a, w: b, myReadClosed: &aReadClosed, peerReadClosed: &bReadClosed}, |
| 811 | + &chanPipe{r: b, w: a, myReadClosed: &bReadClosed, peerReadClosed: &aReadClosed} |
| 812 | +} |
| 813 | + |
| 814 | +type chanPipe struct { |
| 815 | + r, w chan []byte |
| 816 | + buf bytes.Buffer |
| 817 | + |
| 818 | + myReadClosed *atomic.Bool |
| 819 | + peerReadClosed *atomic.Bool |
| 820 | +} |
| 821 | + |
| 822 | +func (cp *chanPipe) Read(b []byte) (int, error) { |
| 823 | + if cp.buf.Len() > 0 { |
| 824 | + return cp.buf.Read(b) |
| 825 | + } |
| 826 | + |
| 827 | + buf, ok := <-cp.r |
| 828 | + if !ok { |
| 829 | + return 0, io.EOF |
| 830 | + } |
| 831 | + |
| 832 | + cp.buf.Write(buf) |
| 833 | + return cp.buf.Read(b) |
| 834 | +} |
| 835 | + |
| 836 | +func (cp *chanPipe) Write(b []byte) (int, error) { |
| 837 | + if cp.peerReadClosed.Load() { |
| 838 | + panic("peer's read side closed") |
| 839 | + } |
| 840 | + copied := make([]byte, len(b)) |
| 841 | + copy(copied, b) |
| 842 | + cp.w <- copied |
| 843 | + return len(b), nil |
| 844 | +} |
| 845 | + |
| 846 | +func (cp *chanPipe) Close() error { |
| 847 | + cp.myReadClosed.Store(true) |
| 848 | + close(cp.w) |
| 849 | + return nil |
| 850 | +} |
| 851 | + |
| 852 | +func TestReadHandshakeOnClose(t *testing.T) { |
| 853 | + rw1, rw2 := newPair() |
| 854 | + |
| 855 | + clientDone := make(chan struct{}) |
| 856 | + go func() { |
| 857 | + l1 := NewMSSelect(rw1, "a") |
| 858 | + _, _ = l1.Write([]byte("hello")) |
| 859 | + _ = l1.Close() |
| 860 | + close(clientDone) |
| 861 | + }() |
| 862 | + |
| 863 | + serverDone := make(chan error) |
| 864 | + |
| 865 | + server := NewMultistreamMuxer[string]() |
| 866 | + server.AddHandler("a", func(protocol string, rwc io.ReadWriteCloser) error { |
| 867 | + _, err := io.ReadAll(rwc) |
| 868 | + rwc.Close() |
| 869 | + serverDone <- err |
| 870 | + return nil |
| 871 | + }) |
| 872 | + |
| 873 | + p, h, err := server.Negotiate(rw2) |
| 874 | + if err != nil { |
| 875 | + t.Fatal(err) |
| 876 | + } |
| 877 | + |
| 878 | + go h(p, rw2) |
| 879 | + |
| 880 | + err = <-serverDone |
| 881 | + if err != nil { |
| 882 | + t.Fatal(err) |
| 883 | + } |
| 884 | + <-clientDone |
| 885 | +} |
| 886 | + |
804 | 887 | type rwc struct { |
805 | 888 | *strings.Reader |
806 | 889 | } |
|
0 commit comments