diff --git a/common/bufio/channel_demux.go b/common/bufio/channel_demux.go
new file mode 100644
index 00000000..2876083c
--- /dev/null
+++ b/common/bufio/channel_demux.go
@@ -0,0 +1,152 @@
+package bufio
+
+import (
+	"context"
+	"reflect"
+	"sync"
+	"sync/atomic"
+
+	N "github.com/sagernet/sing/common/network"
+)
+
+type channelDemuxEntry struct {
+	channel <-chan *N.PacketBuffer
+	stream  *reactorStream
+}
+
+// ChannelDemultiplexer parks idle channel-backed packet streams and wakes
+// them again from a single goroutine that multiplexes all registered
+// channels through reflect.Select.
+type ChannelDemultiplexer struct {
+	ctx        context.Context
+	cancel     context.CancelFunc
+	mutex      sync.Mutex
+	entries    map[<-chan *N.PacketBuffer]*channelDemuxEntry
+	updateChan chan struct{}
+	running    bool
+	closed     atomic.Bool
+	wg         sync.WaitGroup
+}
+
+func NewChannelDemultiplexer(ctx context.Context) *ChannelDemultiplexer {
+	ctx, cancel := context.WithCancel(ctx)
+	demux := &ChannelDemultiplexer{
+		ctx:        ctx,
+		cancel:     cancel,
+		entries:    make(map[<-chan *N.PacketBuffer]*channelDemuxEntry),
+		updateChan: make(chan struct{}, 1),
+	}
+	return demux
+}
+
+func (d *ChannelDemultiplexer) Add(stream *reactorStream, channel <-chan *N.PacketBuffer) {
+	d.mutex.Lock()
+
+	if d.closed.Load() {
+		d.mutex.Unlock()
+		return
+	}
+
+	entry := &channelDemuxEntry{
+		channel: channel,
+		stream:  stream,
+	}
+	d.entries[channel] = entry
+	if !d.running {
+		d.running = true
+		d.wg.Add(1)
+		go d.run()
+	}
+	d.mutex.Unlock()
+	d.signalUpdate()
+}
+
+func (d *ChannelDemultiplexer) Remove(channel <-chan *N.PacketBuffer) {
+	d.mutex.Lock()
+	delete(d.entries, channel)
+	d.mutex.Unlock()
+	d.signalUpdate()
+}
+
+func (d *ChannelDemultiplexer) signalUpdate() {
+	select {
+	case d.updateChan <- struct{}{}:
+	default:
+	}
+}
+
+func (d *ChannelDemultiplexer) Close() error {
+	d.mutex.Lock()
+	d.closed.Store(true)
+	d.mutex.Unlock()
+
+	d.cancel()
+	d.signalUpdate()
+	d.wg.Wait()
+	return nil
+}
+
+func (d *ChannelDemultiplexer) run() {
+	defer d.wg.Done()
+
+	for {
+		d.mutex.Lock()
+		if len(d.entries) == 0 {
+			d.running = false
+			d.mutex.Unlock()
+			return
+		}
+
+		// Slot 0 is reserved for context cancellation and slot 1 for
+		// membership updates; registered channels start at index 2.
+		cases := make([]reflect.SelectCase, 0, len(d.entries)+2)
+
+		cases = append(cases, reflect.SelectCase{
+			Dir:  reflect.SelectRecv,
+			Chan: reflect.ValueOf(d.ctx.Done()),
+		})
+
+		cases = append(cases, reflect.SelectCase{
+			Dir:  reflect.SelectRecv,
+			Chan: reflect.ValueOf(d.updateChan),
+		})
+
+		entryList := make([]*channelDemuxEntry, 0, len(d.entries))
+		for _, entry := range d.entries {
+			cases = append(cases, reflect.SelectCase{
+				Dir:  reflect.SelectRecv,
+				Chan: reflect.ValueOf(entry.channel),
+			})
+			entryList = append(entryList, entry)
+		}
+		d.mutex.Unlock()
+
+		chosen, recv, recvOK := reflect.Select(cases)
+
+		switch chosen {
+		case 0:
+			d.mutex.Lock()
+			d.running = false
+			d.mutex.Unlock()
+			return
+		case 1:
+			continue
+		default:
+			// Hand the stream off to its own goroutine: remove it from the
+			// demultiplexer first so that runActiveLoop can re-register it
+			// via returnToPool once it goes idle again.
+			entry := entryList[chosen-2]
+			d.mutex.Lock()
+			delete(d.entries, entry.channel)
+			d.mutex.Unlock()
+
+			if recvOK {
+				packet := recv.Interface().(*N.PacketBuffer)
+				go entry.stream.runActiveLoop(packet)
+			} else {
+				// Channel closed: treat it as end of stream.
+				go entry.stream.closeWithError(nil)
+			}
+		}
+	}
+}
diff --git a/common/bufio/fd_demux_darwin.go b/common/bufio/fd_demux_darwin.go
new file mode 100644
index 00000000..3e1c7876
--- /dev/null
+++ b/common/bufio/fd_demux_darwin.go
@@ -0,0 +1,225 @@
+//go:build darwin
+
+package bufio
+
+import (
+	"context"
+	"sync"
+	"sync/atomic"
+
+	"golang.org/x/sys/unix"
+)
+
+type fdDemuxEntry struct {
+	fd     int
+	stream *reactorStream
+}
+
+type FDDemultiplexer struct {
+	ctx      context.Context
+	cancel   context.CancelFunc
+	kqueueFD int
+	mutex    sync.Mutex
+	entries  map[int]*fdDemuxEntry
+	running  bool
+	closed   atomic.Bool
+	wg
sync.WaitGroup + pipeFDs [2]int +} + +func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { + kqueueFD, err := unix.Kqueue() + if err != nil { + return nil, err + } + + var pipeFDs [2]int + err = unix.Pipe(pipeFDs[:]) + if err != nil { + unix.Close(kqueueFD) + return nil, err + } + + err = unix.SetNonblock(pipeFDs[0], true) + if err != nil { + unix.Close(pipeFDs[0]) + unix.Close(pipeFDs[1]) + unix.Close(kqueueFD) + return nil, err + } + err = unix.SetNonblock(pipeFDs[1], true) + if err != nil { + unix.Close(pipeFDs[0]) + unix.Close(pipeFDs[1]) + unix.Close(kqueueFD) + return nil, err + } + + _, err = unix.Kevent(kqueueFD, []unix.Kevent_t{{ + Ident: uint64(pipeFDs[0]), + Filter: unix.EVFILT_READ, + Flags: unix.EV_ADD, + }}, nil, nil) + if err != nil { + unix.Close(pipeFDs[0]) + unix.Close(pipeFDs[1]) + unix.Close(kqueueFD) + return nil, err + } + + ctx, cancel := context.WithCancel(ctx) + demux := &FDDemultiplexer{ + ctx: ctx, + cancel: cancel, + kqueueFD: kqueueFD, + entries: make(map[int]*fdDemuxEntry), + pipeFDs: pipeFDs, + } + return demux, nil +} + +func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { + d.mutex.Lock() + defer d.mutex.Unlock() + + if d.closed.Load() { + return unix.EINVAL + } + + _, err := unix.Kevent(d.kqueueFD, []unix.Kevent_t{{ + Ident: uint64(fd), + Filter: unix.EVFILT_READ, + Flags: unix.EV_ADD, + }}, nil, nil) + if err != nil { + return err + } + + entry := &fdDemuxEntry{ + fd: fd, + stream: stream, + } + d.entries[fd] = entry + + if !d.running { + d.running = true + d.wg.Add(1) + go d.run() + } + + return nil +} + +func (d *FDDemultiplexer) Remove(fd int) { + d.mutex.Lock() + defer d.mutex.Unlock() + + _, ok := d.entries[fd] + if !ok { + return + } + + unix.Kevent(d.kqueueFD, []unix.Kevent_t{{ + Ident: uint64(fd), + Filter: unix.EVFILT_READ, + Flags: unix.EV_DELETE, + }}, nil, nil) + delete(d.entries, fd) +} + +func (d *FDDemultiplexer) wakeup() { + unix.Write(d.pipeFDs[1], []byte{0}) +} + +func (d *FDDemultiplexer) Close() error { + d.mutex.Lock() + d.closed.Store(true) + d.mutex.Unlock() + + d.cancel() + d.wakeup() + d.wg.Wait() + + d.mutex.Lock() + defer d.mutex.Unlock() + + if d.kqueueFD != -1 { + unix.Close(d.kqueueFD) + d.kqueueFD = -1 + } + if d.pipeFDs[0] != -1 { + unix.Close(d.pipeFDs[0]) + unix.Close(d.pipeFDs[1]) + d.pipeFDs[0] = -1 + d.pipeFDs[1] = -1 + } + return nil +} + +func (d *FDDemultiplexer) run() { + defer d.wg.Done() + + events := make([]unix.Kevent_t, 64) + var buffer [1]byte + + for { + select { + case <-d.ctx.Done(): + d.mutex.Lock() + d.running = false + d.mutex.Unlock() + return + default: + } + + n, err := unix.Kevent(d.kqueueFD, nil, events, nil) + if err != nil { + if err == unix.EINTR { + continue + } + d.mutex.Lock() + d.running = false + d.mutex.Unlock() + return + } + + for i := 0; i < n; i++ { + event := events[i] + fd := int(event.Ident) + + if fd == d.pipeFDs[0] { + unix.Read(d.pipeFDs[0], buffer[:]) + continue + } + + if event.Flags&unix.EV_ERROR != 0 { + continue + } + + d.mutex.Lock() + entry, ok := d.entries[fd] + if !ok { + d.mutex.Unlock() + continue + } + + unix.Kevent(d.kqueueFD, []unix.Kevent_t{{ + Ident: uint64(fd), + Filter: unix.EVFILT_READ, + Flags: unix.EV_DELETE, + }}, nil, nil) + delete(d.entries, fd) + d.mutex.Unlock() + + go entry.stream.runActiveLoop(nil) + } + + d.mutex.Lock() + if len(d.entries) == 0 { + d.running = false + d.mutex.Unlock() + return + } + d.mutex.Unlock() + } +} diff --git a/common/bufio/fd_demux_linux.go b/common/bufio/fd_demux_linux.go new file mode 
100644 index 00000000..2c5e0afa --- /dev/null +++ b/common/bufio/fd_demux_linux.go @@ -0,0 +1,217 @@ +//go:build linux + +package bufio + +import ( + "context" + "sync" + "sync/atomic" + "unsafe" + + "golang.org/x/sys/unix" +) + +type fdDemuxEntry struct { + fd int + registrationID uint64 + stream *reactorStream +} + +type FDDemultiplexer struct { + ctx context.Context + cancel context.CancelFunc + epollFD int + mutex sync.Mutex + entries map[int]*fdDemuxEntry + registrationCounter uint64 + registrationToFD map[uint64]int + running bool + closed atomic.Bool + wg sync.WaitGroup + pipeFDs [2]int +} + +func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { + epollFD, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC) + if err != nil { + return nil, err + } + + var pipeFDs [2]int + err = unix.Pipe2(pipeFDs[:], unix.O_NONBLOCK|unix.O_CLOEXEC) + if err != nil { + unix.Close(epollFD) + return nil, err + } + + pipeEvent := &unix.EpollEvent{Events: unix.EPOLLIN} + *(*uint64)(unsafe.Pointer(&pipeEvent.Fd)) = 0 + err = unix.EpollCtl(epollFD, unix.EPOLL_CTL_ADD, pipeFDs[0], pipeEvent) + if err != nil { + unix.Close(pipeFDs[0]) + unix.Close(pipeFDs[1]) + unix.Close(epollFD) + return nil, err + } + + ctx, cancel := context.WithCancel(ctx) + demux := &FDDemultiplexer{ + ctx: ctx, + cancel: cancel, + epollFD: epollFD, + entries: make(map[int]*fdDemuxEntry), + registrationToFD: make(map[uint64]int), + pipeFDs: pipeFDs, + } + return demux, nil +} + +func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { + d.mutex.Lock() + defer d.mutex.Unlock() + + if d.closed.Load() { + return unix.EINVAL + } + + d.registrationCounter++ + registrationID := d.registrationCounter + + event := &unix.EpollEvent{Events: unix.EPOLLIN | unix.EPOLLRDHUP} + *(*uint64)(unsafe.Pointer(&event.Fd)) = registrationID + + err := unix.EpollCtl(d.epollFD, unix.EPOLL_CTL_ADD, fd, event) + if err != nil { + return err + } + + entry := &fdDemuxEntry{ + fd: fd, + registrationID: registrationID, + stream: stream, + } + d.entries[fd] = entry + d.registrationToFD[registrationID] = fd + + if !d.running { + d.running = true + d.wg.Add(1) + go d.run() + } + + return nil +} + +func (d *FDDemultiplexer) Remove(fd int) { + d.mutex.Lock() + defer d.mutex.Unlock() + + entry, ok := d.entries[fd] + if !ok { + return + } + + unix.EpollCtl(d.epollFD, unix.EPOLL_CTL_DEL, fd, nil) + delete(d.registrationToFD, entry.registrationID) + delete(d.entries, fd) +} + +func (d *FDDemultiplexer) wakeup() { + unix.Write(d.pipeFDs[1], []byte{0}) +} + +func (d *FDDemultiplexer) Close() error { + d.mutex.Lock() + d.closed.Store(true) + d.mutex.Unlock() + + d.cancel() + d.wakeup() + d.wg.Wait() + + d.mutex.Lock() + defer d.mutex.Unlock() + + if d.epollFD != -1 { + unix.Close(d.epollFD) + d.epollFD = -1 + } + if d.pipeFDs[0] != -1 { + unix.Close(d.pipeFDs[0]) + unix.Close(d.pipeFDs[1]) + d.pipeFDs[0] = -1 + d.pipeFDs[1] = -1 + } + return nil +} + +func (d *FDDemultiplexer) run() { + defer d.wg.Done() + + events := make([]unix.EpollEvent, 64) + var buffer [1]byte + + for { + select { + case <-d.ctx.Done(): + d.mutex.Lock() + d.running = false + d.mutex.Unlock() + return + default: + } + + n, err := unix.EpollWait(d.epollFD, events, -1) + if err != nil { + if err == unix.EINTR { + continue + } + d.mutex.Lock() + d.running = false + d.mutex.Unlock() + return + } + + for i := 0; i < n; i++ { + event := events[i] + registrationID := *(*uint64)(unsafe.Pointer(&event.Fd)) + + if registrationID == 0 { + unix.Read(d.pipeFDs[0], buffer[:]) + continue + } + + if 
event.Events&(unix.EPOLLIN|unix.EPOLLRDHUP|unix.EPOLLHUP|unix.EPOLLERR) == 0 { + continue + } + + d.mutex.Lock() + fd, ok := d.registrationToFD[registrationID] + if !ok { + d.mutex.Unlock() + continue + } + + entry := d.entries[fd] + if entry == nil || entry.registrationID != registrationID { + d.mutex.Unlock() + continue + } + + unix.EpollCtl(d.epollFD, unix.EPOLL_CTL_DEL, fd, nil) + delete(d.registrationToFD, registrationID) + delete(d.entries, fd) + d.mutex.Unlock() + + go entry.stream.runActiveLoop(nil) + } + + d.mutex.Lock() + if len(d.entries) == 0 { + d.running = false + d.mutex.Unlock() + return + } + d.mutex.Unlock() + } +} diff --git a/common/bufio/fd_demux_stub.go b/common/bufio/fd_demux_stub.go new file mode 100644 index 00000000..d2248cf5 --- /dev/null +++ b/common/bufio/fd_demux_stub.go @@ -0,0 +1,25 @@ +//go:build !linux && !darwin && !windows + +package bufio + +import ( + "context" + + E "github.com/sagernet/sing/common/exceptions" +) + +type FDDemultiplexer struct{} + +func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { + return nil, E.New("FDDemultiplexer not supported on this platform") +} + +func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { + return E.New("FDDemultiplexer not supported on this platform") +} + +func (d *FDDemultiplexer) Remove(fd int) {} + +func (d *FDDemultiplexer) Close() error { + return nil +} diff --git a/common/bufio/fd_demux_windows.go b/common/bufio/fd_demux_windows.go new file mode 100644 index 00000000..06795ebe --- /dev/null +++ b/common/bufio/fd_demux_windows.go @@ -0,0 +1,227 @@ +//go:build windows + +package bufio + +import ( + "context" + "sync" + "sync/atomic" + "unsafe" + + "github.com/sagernet/sing/common/wepoll" + + "golang.org/x/sys/windows" +) + +type fdDemuxEntry struct { + ioStatusBlock windows.IO_STATUS_BLOCK + pollInfo wepoll.AFDPollInfo + stream *reactorStream + fd int + handle windows.Handle + baseHandle windows.Handle + registrationID uint64 + cancelled bool + pinner wepoll.Pinner +} + +type FDDemultiplexer struct { + ctx context.Context + cancel context.CancelFunc + iocp windows.Handle + afd *wepoll.AFD + mutex sync.Mutex + entries map[int]*fdDemuxEntry + registrationCounter uint64 + running bool + closed atomic.Bool + wg sync.WaitGroup +} + +func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { + iocp, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil { + return nil, err + } + + afd, err := wepoll.NewAFD(iocp, "Go") + if err != nil { + windows.CloseHandle(iocp) + return nil, err + } + + ctx, cancel := context.WithCancel(ctx) + demux := &FDDemultiplexer{ + ctx: ctx, + cancel: cancel, + iocp: iocp, + afd: afd, + entries: make(map[int]*fdDemuxEntry), + } + return demux, nil +} + +func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { + d.mutex.Lock() + defer d.mutex.Unlock() + + if d.closed.Load() { + return windows.ERROR_INVALID_HANDLE + } + + handle := windows.Handle(fd) + baseHandle, err := wepoll.GetBaseSocket(handle) + if err != nil { + return err + } + + d.registrationCounter++ + registrationID := d.registrationCounter + + entry := &fdDemuxEntry{ + stream: stream, + fd: fd, + handle: handle, + baseHandle: baseHandle, + registrationID: registrationID, + } + + entry.pinner.Pin(entry) + + events := uint32(wepoll.AFD_POLL_RECEIVE | wepoll.AFD_POLL_DISCONNECT | wepoll.AFD_POLL_ABORT | wepoll.AFD_POLL_LOCAL_CLOSE) + err = d.afd.Poll(baseHandle, events, &entry.ioStatusBlock, &entry.pollInfo) + if err != nil { + 
entry.pinner.Unpin() + return err + } + + d.entries[fd] = entry + + if !d.running { + d.running = true + d.wg.Add(1) + go d.run() + } + + return nil +} + +func (d *FDDemultiplexer) Remove(fd int) { + d.mutex.Lock() + defer d.mutex.Unlock() + + entry, ok := d.entries[fd] + if !ok { + return + } + + entry.cancelled = true + if d.afd != nil { + d.afd.Cancel(&entry.ioStatusBlock) + } +} + +func (d *FDDemultiplexer) wakeup() { + windows.PostQueuedCompletionStatus(d.iocp, 0, 0, nil) +} + +func (d *FDDemultiplexer) Close() error { + d.mutex.Lock() + d.closed.Store(true) + d.mutex.Unlock() + + d.cancel() + d.wakeup() + d.wg.Wait() + + d.mutex.Lock() + defer d.mutex.Unlock() + + for fd, entry := range d.entries { + entry.pinner.Unpin() + delete(d.entries, fd) + } + + if d.afd != nil { + d.afd.Close() + d.afd = nil + } + if d.iocp != 0 { + windows.CloseHandle(d.iocp) + d.iocp = 0 + } + return nil +} + +func (d *FDDemultiplexer) run() { + defer d.wg.Done() + + completions := make([]wepoll.OverlappedEntry, 64) + + for { + select { + case <-d.ctx.Done(): + d.mutex.Lock() + d.running = false + d.mutex.Unlock() + return + default: + } + + var numRemoved uint32 + err := wepoll.GetQueuedCompletionStatusEx(d.iocp, &completions[0], 64, &numRemoved, windows.INFINITE, false) + if err != nil { + d.mutex.Lock() + d.running = false + d.mutex.Unlock() + return + } + + for i := uint32(0); i < numRemoved; i++ { + event := completions[i] + + if event.Overlapped == nil { + continue + } + + entry := (*fdDemuxEntry)(unsafe.Pointer(event.Overlapped)) + + d.mutex.Lock() + + if d.entries[entry.fd] != entry { + d.mutex.Unlock() + continue + } + + entry.pinner.Unpin() + delete(d.entries, entry.fd) + + if entry.cancelled { + d.mutex.Unlock() + continue + } + + if uint32(entry.ioStatusBlock.Status) == wepoll.STATUS_CANCELLED { + d.mutex.Unlock() + continue + } + + events := entry.pollInfo.Handles[0].Events + if events&(wepoll.AFD_POLL_RECEIVE|wepoll.AFD_POLL_DISCONNECT|wepoll.AFD_POLL_ABORT|wepoll.AFD_POLL_LOCAL_CLOSE) == 0 { + d.mutex.Unlock() + continue + } + + d.mutex.Unlock() + go entry.stream.runActiveLoop(nil) + } + + d.mutex.Lock() + if len(d.entries) == 0 { + d.running = false + d.mutex.Unlock() + return + } + d.mutex.Unlock() + } +} diff --git a/common/bufio/fd_demux_windows_test.go b/common/bufio/fd_demux_windows_test.go new file mode 100644 index 00000000..030a3dfb --- /dev/null +++ b/common/bufio/fd_demux_windows_test.go @@ -0,0 +1,305 @@ +//go:build windows + +package bufio + +import ( + "context" + "net" + "sync" + "sync/atomic" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func getSocketFD(t *testing.T, conn net.PacketConn) int { + syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok) + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd int + err = rawConn.Control(func(f uintptr) { fd = int(f) }) + require.NoError(t, err) + return fd +} + +func TestFDDemultiplexer_Create(t *testing.T) { + t.Parallel() + + demux, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + + err = demux.Close() + require.NoError(t, err) +} + +func TestFDDemultiplexer_CreateMultiple(t *testing.T) { + t.Parallel() + + demux1, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + defer demux1.Close() + + demux2, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + defer demux2.Close() +} + +func TestFDDemultiplexer_AddRemove(t *testing.T) { + t.Parallel() + + demux, err := 
NewFDDemultiplexer(context.Background())
+	require.NoError(t, err)
+	defer demux.Close()
+
+	conn, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer conn.Close()
+
+	fd := getSocketFD(t, conn)
+
+	stream := &reactorStream{}
+
+	err = demux.Add(stream, fd)
+	require.NoError(t, err)
+
+	demux.Remove(fd)
+}
+
+func TestFDDemultiplexer_RapidAddRemove(t *testing.T) {
+	t.Parallel()
+
+	demux, err := NewFDDemultiplexer(context.Background())
+	require.NoError(t, err)
+	defer demux.Close()
+
+	const iterations = 50
+
+	for i := 0; i < iterations; i++ {
+		conn, err := net.ListenPacket("udp", "127.0.0.1:0")
+		require.NoError(t, err)
+
+		fd := getSocketFD(t, conn)
+		stream := &reactorStream{}
+
+		err = demux.Add(stream, fd)
+		require.NoError(t, err)
+
+		demux.Remove(fd)
+		conn.Close()
+	}
+}
+
+func TestFDDemultiplexer_ConcurrentAccess(t *testing.T) {
+	t.Parallel()
+
+	demux, err := NewFDDemultiplexer(context.Background())
+	require.NoError(t, err)
+	defer demux.Close()
+
+	const numGoroutines = 10
+	const iterations = 20
+
+	var wg sync.WaitGroup
+	wg.Add(numGoroutines)
+
+	for g := 0; g < numGoroutines; g++ {
+		go func() {
+			defer wg.Done()
+
+			for i := 0; i < iterations; i++ {
+				conn, err := net.ListenPacket("udp", "127.0.0.1:0")
+				if err != nil {
+					continue
+				}
+
+				// Extract the fd inline rather than via getSocketFD: require
+				// calls t.FailNow, which must not be used outside the test
+				// goroutine, so failures here just skip the iteration.
+				syscallConn, ok := conn.(syscall.Conn)
+				if !ok {
+					conn.Close()
+					continue
+				}
+				rawConn, err := syscallConn.SyscallConn()
+				if err != nil {
+					conn.Close()
+					continue
+				}
+				var fd int
+				if err = rawConn.Control(func(f uintptr) { fd = int(f) }); err != nil {
+					conn.Close()
+					continue
+				}
+
+				stream := &reactorStream{}
+				if err = demux.Add(stream, fd); err == nil {
+					demux.Remove(fd)
+				}
+				conn.Close()
+			}
+		}()
+	}
+
+	wg.Wait()
+}
+
+func TestFDDemultiplexer_ReceiveEvent(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	demux, err := NewFDDemultiplexer(ctx)
+	require.NoError(t, err)
+	defer demux.Close()
+
+	conn, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer conn.Close()
+
+	fd := getSocketFD(t, conn)
+
+	// Smoke test: the stream has no source configured, so the dispatched
+	// runActiveLoop returns immediately. This only verifies that a readable
+	// socket is dispatched without panicking or deadlocking.
+	stream := &reactorStream{
+		state: atomic.Int32{},
+	}
+	stream.connection = &reactorConnection{
+		upload:   stream,
+		download: stream,
+		done:     make(chan struct{}),
+	}
+
+	err = demux.Add(stream, fd)
+	require.NoError(t, err)
+
+	sender, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer sender.Close()
+
+	_, err = sender.WriteTo([]byte("test data"), conn.LocalAddr())
+	require.NoError(t, err)
+
+	time.Sleep(200 * time.Millisecond)
+
+	demux.Remove(fd)
+}
+
+func TestFDDemultiplexer_CloseWhilePolling(t *testing.T) {
+	t.Parallel()
+
+	demux, err := NewFDDemultiplexer(context.Background())
+	require.NoError(t, err)
+
+	conn, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer conn.Close()
+
+	fd := getSocketFD(t, conn)
+	stream := &reactorStream{}
+
+	err = demux.Add(stream, fd)
+	require.NoError(t, err)
+
+	time.Sleep(50 * time.Millisecond)
+
+	done := make(chan struct{})
+	go func() {
+		demux.Close()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Close blocked - possible deadlock")
+	}
+}
+
+func TestFDDemultiplexer_RemoveNonExistent(t *testing.T) {
+	t.Parallel()
+
+	demux, err := NewFDDemultiplexer(context.Background())
+	require.NoError(t, err)
+	defer demux.Close()
+
+	demux.Remove(99999)
+}
+
+func TestFDDemultiplexer_AddAfterClose(t *testing.T) {
+	t.Parallel()
+
+	demux, err := NewFDDemultiplexer(context.Background())
+	require.NoError(t, err)
+
+	err = demux.Close()
+	require.NoError(t,
err) + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + fd := getSocketFD(t, conn) + stream := &reactorStream{} + + err = demux.Add(stream, fd) + require.Error(t, err) +} + +func TestFDDemultiplexer_MultipleSocketsSimultaneous(t *testing.T) { + t.Parallel() + + demux, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + defer demux.Close() + + const numSockets = 5 + conns := make([]net.PacketConn, numSockets) + fds := make([]int, numSockets) + + for i := 0; i < numSockets; i++ { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + conns[i] = conn + + fd := getSocketFD(t, conn) + fds[i] = fd + + stream := &reactorStream{} + err = demux.Add(stream, fd) + require.NoError(t, err) + } + + for i := 0; i < numSockets; i++ { + demux.Remove(fds[i]) + } +} + +func TestFDDemultiplexer_ContextCancellation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + demux, err := NewFDDemultiplexer(ctx) + require.NoError(t, err) + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + fd := getSocketFD(t, conn) + stream := &reactorStream{} + + err = demux.Add(stream, fd) + require.NoError(t, err) + + cancel() + + time.Sleep(100 * time.Millisecond) + + done := make(chan struct{}) + go func() { + demux.Close() + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Close blocked after context cancellation") + } +} diff --git a/common/bufio/packet_reactor.go b/common/bufio/packet_reactor.go new file mode 100644 index 00000000..4ded4ed6 --- /dev/null +++ b/common/bufio/packet_reactor.go @@ -0,0 +1,390 @@ +package bufio + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +const ( + batchReadTimeout = 250 * time.Millisecond +) + +const ( + stateIdle int32 = 0 + stateActive int32 = 1 + stateClosed int32 = 2 +) + +type PacketReactor struct { + ctx context.Context + cancel context.CancelFunc + channelDemux *ChannelDemultiplexer + fdDemux *FDDemultiplexer + fdDemuxOnce sync.Once + fdDemuxErr error +} + +func NewPacketReactor(ctx context.Context) *PacketReactor { + ctx, cancel := context.WithCancel(ctx) + return &PacketReactor{ + ctx: ctx, + cancel: cancel, + channelDemux: NewChannelDemultiplexer(ctx), + } +} + +func (r *PacketReactor) getFDDemultiplexer() (*FDDemultiplexer, error) { + r.fdDemuxOnce.Do(func() { + r.fdDemux, r.fdDemuxErr = NewFDDemultiplexer(r.ctx) + }) + return r.fdDemux, r.fdDemuxErr +} + +func (r *PacketReactor) Close() error { + r.cancel() + var errs []error + if r.channelDemux != nil { + errs = append(errs, r.channelDemux.Close()) + } + if r.fdDemux != nil { + errs = append(errs, r.fdDemux.Close()) + } + return E.Errors(errs...) 
+} + +type reactorConnection struct { + ctx context.Context + cancel context.CancelFunc + reactor *PacketReactor + onClose N.CloseHandlerFunc + upload *reactorStream + download *reactorStream + + closeOnce sync.Once + done chan struct{} + err error +} + +type reactorStream struct { + connection *reactorConnection + + source N.PacketReader + destination N.PacketWriter + originSource N.PacketReader + + notifier N.ReadNotifier + options N.ReadWaitOptions + readWaiter N.PacketReadWaiter + readCounters []N.CountFunc + writeCounters []N.CountFunc + + state atomic.Int32 +} + +func (r *PacketReactor) Copy(ctx context.Context, source N.PacketConn, destination N.PacketConn, onClose N.CloseHandlerFunc) { + ctx, cancel := context.WithCancel(ctx) + conn := &reactorConnection{ + ctx: ctx, + cancel: cancel, + reactor: r, + onClose: onClose, + done: make(chan struct{}), + } + + conn.upload = r.prepareStream(conn, source, destination) + select { + case <-conn.done: + return + default: + } + + conn.download = r.prepareStream(conn, destination, source) + select { + case <-conn.done: + return + default: + } + + r.registerStream(conn.upload) + r.registerStream(conn.download) +} + +func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketReader, destination N.PacketWriter) *reactorStream { + stream := &reactorStream{ + connection: conn, + source: source, + destination: destination, + originSource: source, + } + + for { + source, stream.readCounters = N.UnwrapCountPacketReader(source, stream.readCounters) + destination, stream.writeCounters = N.UnwrapCountPacketWriter(destination, stream.writeCounters) + if cachedReader, isCached := source.(N.CachedPacketReader); isCached { + packet := cachedReader.ReadCachedPacket() + if packet != nil { + dataLen := packet.Buffer.Len() + err := destination.WritePacket(packet.Buffer, packet.Destination) + N.PutPacketBuffer(packet) + if err != nil { + conn.closeWithError(err) + return stream + } + for _, counter := range stream.readCounters { + counter(int64(dataLen)) + } + for _, counter := range stream.writeCounters { + counter(int64(dataLen)) + } + continue + } + } + break + } + stream.source = source + stream.destination = destination + + stream.options = N.NewReadWaitOptions(source, destination) + + stream.readWaiter, _ = CreatePacketReadWaiter(source) + if stream.readWaiter != nil { + stream.readWaiter.InitializeReadWaiter(stream.options) + } + + if notifierSource, ok := source.(N.ReadNotifierSource); ok { + stream.notifier = notifierSource.CreateReadNotifier() + } + + return stream +} + +func (r *PacketReactor) registerStream(stream *reactorStream) { + if stream.notifier == nil { + go stream.runLegacyCopy() + return + } + + switch notifier := stream.notifier.(type) { + case *N.ChannelNotifier: + r.channelDemux.Add(stream, notifier.Channel) + case *N.FileDescriptorNotifier: + fdDemux, err := r.getFDDemultiplexer() + if err != nil { + go stream.runLegacyCopy() + return + } + err = fdDemux.Add(stream, notifier.FD) + if err != nil { + go stream.runLegacyCopy() + } + default: + go stream.runLegacyCopy() + } +} + +func (s *reactorStream) runActiveLoop(firstPacket *N.PacketBuffer) { + if s.source == nil { + if firstPacket != nil { + firstPacket.Buffer.Release() + N.PutPacketBuffer(firstPacket) + } + return + } + if !s.state.CompareAndSwap(stateIdle, stateActive) { + if firstPacket != nil { + firstPacket.Buffer.Release() + N.PutPacketBuffer(firstPacket) + } + return + } + + notFirstTime := false + + if firstPacket != nil { + err := 
s.writePacketWithCounters(firstPacket)
+		if err != nil {
+			s.closeWithError(err)
+			return
+		}
+		notFirstTime = true
+	}
+
+	// Drain the source until it goes idle. Each blocking read is bounded by
+	// batchReadTimeout so the goroutine can park the stream back on the
+	// demultiplexer instead of staying blocked forever.
+	for {
+		if s.state.Load() == stateClosed {
+			return
+		}
+
+		if setter, ok := s.source.(interface{ SetReadDeadline(time.Time) error }); ok {
+			setter.SetReadDeadline(time.Now().Add(batchReadTimeout))
+		}
+
+		var (
+			buffer      *N.PacketBuffer
+			destination M.Socksaddr
+			err         error
+		)
+
+		if s.readWaiter != nil {
+			var readBuffer *buf.Buffer
+			readBuffer, destination, err = s.readWaiter.WaitReadPacket()
+			if readBuffer != nil {
+				buffer = N.NewPacketBuffer()
+				buffer.Buffer = readBuffer
+				buffer.Destination = destination
+			}
+		} else {
+			readBuffer := s.options.NewPacketBuffer()
+			destination, err = s.source.ReadPacket(readBuffer)
+			if err != nil {
+				readBuffer.Release()
+			} else {
+				buffer = N.NewPacketBuffer()
+				buffer.Buffer = readBuffer
+				buffer.Destination = destination
+			}
+		}
+
+		if err != nil {
+			if E.IsTimeout(err) {
+				if setter, ok := s.source.(interface{ SetReadDeadline(time.Time) error }); ok {
+					setter.SetReadDeadline(time.Time{})
+				}
+				// No data within batchReadTimeout: hand the stream back to
+				// the demultiplexer. A failed CAS means the stream was
+				// closed concurrently and there is nothing left to do.
+				if s.state.CompareAndSwap(stateActive, stateIdle) {
+					s.returnToPool()
+				}
+				return
+			}
+			if !notFirstTime {
+				err = N.ReportHandshakeFailure(s.originSource, err)
+			}
+			s.closeWithError(err)
+			return
+		}
+
+		err = s.writePacketWithCounters(buffer)
+		if err != nil {
+			if !notFirstTime {
+				err = N.ReportHandshakeFailure(s.originSource, err)
+			}
+			s.closeWithError(err)
+			return
+		}
+		notFirstTime = true
+	}
+}
+
+func (s *reactorStream) writePacketWithCounters(packet *N.PacketBuffer) error {
+	buffer := packet.Buffer
+	destination := packet.Destination
+	dataLen := buffer.Len()
+
+	s.options.PostReturn(buffer)
+	// WritePacket takes ownership of buffer; the counters are updated only
+	// after a successful write.
+	err := s.destination.WritePacket(buffer, destination)
+	N.PutPacketBuffer(packet)
+	if err != nil {
+		buffer.Leak()
+		return err
+	}
+
+	for _, counter := range s.readCounters {
+		counter(int64(dataLen))
+	}
+	for _, counter := range s.writeCounters {
+		counter(int64(dataLen))
+	}
+	return nil
+}
+
+// returnToPool re-registers an idle stream with its demultiplexer. A close
+// may race with the hand-off, so the state is re-checked after Add and the
+// registration is withdrawn if the stream is no longer idle.
+func (s *reactorStream) returnToPool() {
+	if s.state.Load() != stateIdle {
+		return
+	}
+
+	switch notifier := s.notifier.(type) {
+	case *N.ChannelNotifier:
+		s.connection.reactor.channelDemux.Add(s, notifier.Channel)
+		if s.state.Load() != stateIdle {
+			s.connection.reactor.channelDemux.Remove(notifier.Channel)
+		}
+	case *N.FileDescriptorNotifier:
+		if s.connection.reactor.fdDemux != nil {
+			err := s.connection.reactor.fdDemux.Add(s, notifier.FD)
+			if err != nil {
+				s.closeWithError(err)
+				return
+			}
+			if s.state.Load() != stateIdle {
+				s.connection.reactor.fdDemux.Remove(notifier.FD)
+			}
+		}
+	}
+}
+
+func (s *reactorStream) runLegacyCopy() {
+	_, err := CopyPacket(s.destination, s.source)
+	s.closeWithError(err)
+}
+
+func (s *reactorStream) closeWithError(err error) {
+	s.connection.closeWithError(err)
+}
+
+func (c *reactorConnection) closeWithError(err error) {
+	c.closeOnce.Do(func() {
+		c.err = err
+		c.cancel()
+
+		if c.upload != nil {
+			c.upload.state.Store(stateClosed)
+		}
+		if c.download != nil {
+			c.download.state.Store(stateClosed)
+		}
+
+		c.removeFromDemultiplexers()
+
+		if c.upload != nil {
+			common.Close(c.upload.originSource)
+		}
+		if c.download != nil {
+			common.Close(c.download.originSource)
+		}
+
+		if c.onClose != nil {
+			c.onClose(c.err)
+		}
+
+		close(c.done)
+	})
+}
+
+func (c *reactorConnection) removeFromDemultiplexers() {
+	if c.upload != nil && c.upload.notifier != nil {
+		switch notifier := c.upload.notifier.(type) {
+		case *N.ChannelNotifier:
+			c.reactor.channelDemux.Remove(notifier.Channel)
+		case
*N.FileDescriptorNotifier: + if c.reactor.fdDemux != nil { + c.reactor.fdDemux.Remove(notifier.FD) + } + } + } + if c.download != nil && c.download.notifier != nil { + switch notifier := c.download.notifier.(type) { + case *N.ChannelNotifier: + c.reactor.channelDemux.Remove(notifier.Channel) + case *N.FileDescriptorNotifier: + if c.reactor.fdDemux != nil { + c.reactor.fdDemux.Remove(notifier.FD) + } + } + } +} diff --git a/common/bufio/packet_reactor_test.go b/common/bufio/packet_reactor_test.go new file mode 100644 index 00000000..ae155130 --- /dev/null +++ b/common/bufio/packet_reactor_test.go @@ -0,0 +1,1485 @@ +//go:build darwin || linux || windows + +package bufio + +import ( + "context" + "crypto/md5" + "crypto/rand" + "errors" + "io" + "net" + "os" + "sync" + "sync/atomic" + "syscall" + "testing" + "time" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testPacketPipe struct { + inChan chan *N.PacketBuffer + outChan chan *N.PacketBuffer + localAddr M.Socksaddr + closed atomic.Bool + closeOnce sync.Once + done chan struct{} +} + +func newTestPacketPipe(localAddr M.Socksaddr) *testPacketPipe { + return &testPacketPipe{ + inChan: make(chan *N.PacketBuffer, 256), + outChan: make(chan *N.PacketBuffer, 256), + localAddr: localAddr, + done: make(chan struct{}), + } +} + +func (p *testPacketPipe) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + select { + case packet, ok := <-p.inChan: + if !ok { + return M.Socksaddr{}, io.EOF + } + _, err = buffer.ReadOnceFrom(packet.Buffer) + destination = packet.Destination + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return destination, err + case <-p.done: + return M.Socksaddr{}, net.ErrClosed + } +} + +func (p *testPacketPipe) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if p.closed.Load() { + buffer.Release() + return net.ErrClosed + } + packet := N.NewPacketBuffer() + newBuf := buf.NewSize(buffer.Len()) + newBuf.Write(buffer.Bytes()) + packet.Buffer = newBuf + packet.Destination = destination + buffer.Release() + select { + case p.outChan <- packet: + return nil + case <-p.done: + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return net.ErrClosed + } +} + +func (p *testPacketPipe) Close() error { + p.closeOnce.Do(func() { + p.closed.Store(true) + close(p.done) + }) + return nil +} + +func (p *testPacketPipe) LocalAddr() net.Addr { + return p.localAddr.UDPAddr() +} + +func (p *testPacketPipe) SetDeadline(t time.Time) error { + return nil +} + +func (p *testPacketPipe) SetReadDeadline(t time.Time) error { + return nil +} + +func (p *testPacketPipe) SetWriteDeadline(t time.Time) error { + return nil +} + +func (p *testPacketPipe) CreateReadNotifier() N.ReadNotifier { + return &N.ChannelNotifier{Channel: p.inChan} +} + +func (p *testPacketPipe) send(data []byte, destination M.Socksaddr) { + packet := N.NewPacketBuffer() + newBuf := buf.NewSize(len(data)) + newBuf.Write(data) + packet.Buffer = newBuf + packet.Destination = destination + p.inChan <- packet +} + +func (p *testPacketPipe) receive() (*N.PacketBuffer, bool) { + select { + case packet, ok := <-p.outChan: + return packet, ok + case <-p.done: + return nil, false + } +} + +type fdPacketConn struct { + N.NetPacketConn + fd int + targetAddr M.Socksaddr +} + +func newFDPacketConn(t *testing.T, conn net.PacketConn, targetAddr M.Socksaddr) *fdPacketConn { + 
syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok, "connection must implement syscall.Conn") + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd int + err = rawConn.Control(func(f uintptr) { fd = int(f) }) + require.NoError(t, err) + return &fdPacketConn{ + NetPacketConn: NewPacketConn(conn), + fd: fd, + targetAddr: targetAddr, + } +} + +func (c *fdPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + _, err = c.NetPacketConn.ReadPacket(buffer) + if err != nil { + return M.Socksaddr{}, err + } + return c.targetAddr, nil +} + +func (c *fdPacketConn) CreateReadNotifier() N.ReadNotifier { + return &N.FileDescriptorNotifier{FD: c.fd} +} + +type channelPacketConn struct { + N.NetPacketConn + packetChan chan *N.PacketBuffer + done chan struct{} + closeOnce sync.Once + targetAddr M.Socksaddr + deadlineLock sync.Mutex + deadline time.Time + deadlineChan chan struct{} +} + +func newChannelPacketConn(conn net.PacketConn, targetAddr M.Socksaddr) *channelPacketConn { + c := &channelPacketConn{ + NetPacketConn: NewPacketConn(conn), + packetChan: make(chan *N.PacketBuffer, 256), + done: make(chan struct{}), + targetAddr: targetAddr, + deadlineChan: make(chan struct{}), + } + go c.readLoop() + return c +} + +func (c *channelPacketConn) readLoop() { + for { + select { + case <-c.done: + return + default: + } + buffer := buf.NewPacket() + _, err := c.NetPacketConn.ReadPacket(buffer) + if err != nil { + buffer.Release() + close(c.packetChan) + return + } + packet := N.NewPacketBuffer() + packet.Buffer = buffer + packet.Destination = c.targetAddr + select { + case c.packetChan <- packet: + case <-c.done: + buffer.Release() + N.PutPacketBuffer(packet) + return + } + } +} + +func (c *channelPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + c.deadlineLock.Lock() + deadline := c.deadline + deadlineChan := c.deadlineChan + c.deadlineLock.Unlock() + + var timer <-chan time.Time + if !deadline.IsZero() { + d := time.Until(deadline) + if d <= 0 { + return M.Socksaddr{}, os.ErrDeadlineExceeded + } + t := time.NewTimer(d) + defer t.Stop() + timer = t.C + } + + select { + case packet, ok := <-c.packetChan: + if !ok { + return M.Socksaddr{}, net.ErrClosed + } + _, err = buffer.ReadOnceFrom(packet.Buffer) + destination = packet.Destination + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return + case <-c.done: + return M.Socksaddr{}, net.ErrClosed + case <-deadlineChan: + return M.Socksaddr{}, os.ErrDeadlineExceeded + case <-timer: + return M.Socksaddr{}, os.ErrDeadlineExceeded + } +} + +func (c *channelPacketConn) SetReadDeadline(t time.Time) error { + c.deadlineLock.Lock() + c.deadline = t + if c.deadlineChan != nil { + close(c.deadlineChan) + } + c.deadlineChan = make(chan struct{}) + c.deadlineLock.Unlock() + return nil +} + +func (c *channelPacketConn) CreateReadNotifier() N.ReadNotifier { + return &N.ChannelNotifier{Channel: c.packetChan} +} + +func (c *channelPacketConn) Close() error { + c.closeOnce.Do(func() { + close(c.done) + }) + return c.NetPacketConn.Close() +} + +type batchHashPair struct { + sendHash map[int][]byte + recvHash map[int][]byte +} + +func TestBatchCopy_Pipe_DataIntegrity(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 10001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 10002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + 
defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + t.Logf("recv channel closed at %d", i) + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for receive") + } + + assert.Equal(t, sendHash, recvHash, "data mismatch") +} + +func TestBatchCopy_Pipe_Bidirectional(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 10001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 10002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + for i := 0; i < times; i++ { + packet, ok := pipeA.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeB.send(data, addr1) + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var aPair, bPair batchHashPair + select { + case aPair = <-pingCh: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for A") + } + select { + case bPair = <-pongCh: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for B") + } + + assert.Equal(t, aPair.sendHash, bPair.recvHash, "A->B mismatch") + assert.Equal(t, bPair.sendHash, aPair.recvHash, "B->A mismatch") +} + +func TestBatchCopy_FDPoller_DataIntegrity(t *testing.T) { + t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + 
proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + proxyBAddr := M.SocksaddrFromNet(proxyBConn.LocalAddr()) + + proxyA := newFDPacketConn(t, proxyAConn, serverAddr) + proxyB := newFDPacketConn(t, proxyBConn, clientAddr) + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + errCh := make(chan error, 2) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + for i := 0; i < times; i++ { + n, _, err := clientConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + n, _, err := serverConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := serverConn.WriteTo(data, proxyBAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var clientPair, serverPair batchHashPair + for i := 0; i < 2; i++ { + select { + case clientPair = <-pingCh: + case serverPair = <-pongCh: + case err := <-errCh: + t.Fatal(err) + case <-time.After(15 * time.Second): + t.Fatal("timeout") + } + } + + assert.Equal(t, clientPair.sendHash, serverPair.recvHash, "client->server mismatch") + assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch") +} + +func TestBatchCopy_ChannelPoller_DataIntegrity(t *testing.T) { + t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + proxyBAddr := M.SocksaddrFromNet(proxyBConn.LocalAddr()) + + proxyA := newChannelPacketConn(proxyAConn, serverAddr) + proxyB := newChannelPacketConn(proxyBConn, clientAddr) + + 
copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + errCh := make(chan error, 2) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + for i := 0; i < times; i++ { + n, _, err := clientConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + n, _, err := serverConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := serverConn.WriteTo(data, proxyBAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var clientPair, serverPair batchHashPair + for i := 0; i < 2; i++ { + select { + case clientPair = <-pingCh: + case serverPair = <-pongCh: + case err := <-errCh: + t.Fatal(err) + case <-time.After(15 * time.Second): + t.Fatal("timeout") + } + } + + assert.Equal(t, clientPair.sendHash, serverPair.recvHash, "client->server mismatch") + assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch") +} + +func TestBatchCopy_MixedMode_DataIntegrity(t *testing.T) { + t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + proxyBAddr := M.SocksaddrFromNet(proxyBConn.LocalAddr()) + + proxyA := newFDPacketConn(t, proxyAConn, serverAddr) + proxyB := newChannelPacketConn(proxyBConn, clientAddr) + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + errCh := make(chan error, 2) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + 
data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + for i := 0; i < times; i++ { + n, _, err := clientConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + n, _, err := serverConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := serverConn.WriteTo(data, proxyBAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var clientPair, serverPair batchHashPair + for i := 0; i < 2; i++ { + select { + case clientPair = <-pingCh: + case serverPair = <-pongCh: + case err := <-errCh: + t.Fatal(err) + case <-time.After(15 * time.Second): + t.Fatal("timeout") + } + } + + assert.Equal(t, clientPair.sendHash, serverPair.recvHash, "client->server mismatch") + assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch") +} + +func TestBatchCopy_MultipleConnections_DataIntegrity(t *testing.T) { + t.Parallel() + + const numConnections = 5 + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + var wg sync.WaitGroup + errCh := make(chan error, numConnections) + + for i := 0; i < numConnections; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", uint16(20000+idx*2)) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", uint16(20001+idx*2)) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 20 + const chunkSize = 1000 + + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + errCh <- errors.New("timeout") + return + } + + for k, v := range sendHash { + if rv, ok := recvHash[k]; !ok || string(v) != string(rv) { + errCh <- errors.New("data mismatch") + return + } + } + }(i) + } + + wg.Wait() + close(errCh) + + for err := range errCh { + require.NoError(t, err) + } +} + +func TestBatchCopy_TimeoutAndResume_DataIntegrity(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 30001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 30002) + + 
pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + sendAndVerify := func(batchID int, count int) { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < count; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(1))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < count; i++ { + data := make([]byte, 1000) + rand.Read(data[2:]) + data[0] = byte(batchID) + data[1] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(5 * time.Second): + t.Fatalf("batch %d timeout", batchID) + } + + assert.Equal(t, sendHash, recvHash, "batch %d mismatch", batchID) + } + + sendAndVerify(1, 10) + + time.Sleep(350 * time.Millisecond) + + sendAndVerify(2, 10) +} + +func TestBatchCopy_CloseWhileTransferring(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 40001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 40002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + + copier := NewPacketReactor(context.Background()) + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + stopSend := make(chan struct{}) + go func() { + for { + select { + case <-stopSend: + return + default: + data := make([]byte, 1000) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(1 * time.Millisecond) + } + } + }() + + time.Sleep(100 * time.Millisecond) + + pipeA.Close() + pipeB.Close() + copier.Close() + close(stopSend) + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("copier did not close - possible deadlock") + } +} + +func TestBatchCopy_HighThroughput(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 50001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 50002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 500 + const chunkSize = 8000 + + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + var mu sync.Mutex + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + t.Logf("recv channel closed at %d", i) + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + idx := int(packet.Buffer.Byte(0))<<8 | int(packet.Buffer.Byte(1)) + mu.Lock() + recvHash[idx] = hash[:] + mu.Unlock() + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[2:]) + data[0] = byte(i >> 8) + data[1] = byte(i & 0xff) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(1 * time.Millisecond) + } + + select { + case 
<-recvDone: + case <-time.After(30 * time.Second): + t.Fatal("high throughput test timeout") + } + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, len(sendHash), len(recvHash), "packet count mismatch") + for k, v := range sendHash { + assert.Equal(t, v, recvHash[k], "packet %d mismatch", k) + } +} + +func TestBatchCopy_LegacyFallback_DataIntegrity(t *testing.T) { + t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + proxyBAddr := M.SocksaddrFromNet(proxyBConn.LocalAddr()) + + proxyA := &legacyPacketConn{NetPacketConn: NewPacketConn(proxyAConn), targetAddr: serverAddr} + proxyB := &legacyPacketConn{NetPacketConn: NewPacketConn(proxyBConn), targetAddr: clientAddr} + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + errCh := make(chan error, 2) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + clientConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + for i := 0; i < times; i++ { + n, _, err := clientConn.ReadFrom(recvBuf) + if err != nil { + if os.IsTimeout(err) { + t.Logf("client read timeout after %d packets", i) + } + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + serverConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + for i := 0; i < times; i++ { + n, _, err := serverConn.ReadFrom(recvBuf) + if err != nil { + if os.IsTimeout(err) { + t.Logf("server read timeout after %d packets", i) + } + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := serverConn.WriteTo(data, proxyBAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var clientPair, serverPair batchHashPair + for i := 0; i < 2; i++ { + select { + case clientPair = <-pingCh: + case serverPair = <-pongCh: + case err := <-errCh: + t.Fatal(err) + case <-time.After(20 * time.Second): + t.Fatal("timeout") + } + } + + assert.Equal(t, clientPair.sendHash, serverPair.recvHash, 
"client->server mismatch") + assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch") +} + +func TestBatchCopy_ReactorClose(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + go func() { + for { + select { + case <-copyDone: + return + default: + data := make([]byte, 100) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(10 * time.Millisecond) + } + } + }() + + time.Sleep(100 * time.Millisecond) + + pipeA.Close() + pipeB.Close() + copier.Close() + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("Copy did not return after reactor close") + } +} + +func TestBatchCopy_SmallPackets(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60011) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60012) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const totalPackets = 20 + receivedCount := 0 + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < totalPackets; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + receivedCount++ + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < totalPackets; i++ { + size := (i % 10) + 1 + data := make([]byte, size) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for packets") + } + + assert.Equal(t, totalPackets, receivedCount) +} + +func TestBatchCopy_VaryingPacketSizes(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60041) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60042) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + sizes := []int{10, 100, 500, 1000, 2000, 4000, 8000} + const times = 3 + + totalPackets := len(sizes) * times + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < totalPackets; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + idx := int(packet.Buffer.Byte(0))<<8 | int(packet.Buffer.Byte(1)) + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[idx] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + packetIdx := 0 + for _, size := range sizes { + for j := 0; j < times; j++ { + data := make([]byte, size) + rand.Read(data[2:]) + data[0] = byte(packetIdx >> 8) + data[1] = byte(packetIdx & 0xff) + hash := md5.Sum(data) + sendHash[packetIdx] = hash[:] + pipeA.send(data, addr2) + packetIdx++ + time.Sleep(5 * 
time.Millisecond) + } + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } + + assert.Equal(t, len(sendHash), len(recvHash)) + for k, v := range sendHash { + assert.Equal(t, v, recvHash[k], "packet %d mismatch", k) + } +} + +func TestBatchCopy_OnCloseCallback(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60021) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60022) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + callbackCalled := make(chan error, 1) + onClose := func(err error) { + select { + case callbackCalled <- err: + default: + } + } + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, onClose) + }() + + time.Sleep(50 * time.Millisecond) + + for i := 0; i < 5; i++ { + data := make([]byte, 100) + rand.Read(data) + pipeA.send(data, addr2) + } + + time.Sleep(50 * time.Millisecond) + + pipeA.Close() + pipeB.Close() + + select { + case <-callbackCalled: + case <-time.After(5 * time.Second): + t.Fatal("onClose callback was not called") + } +} + +func TestBatchCopy_SourceClose(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60031) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60032) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + var capturedErr error + var errMu sync.Mutex + callbackCalled := make(chan struct{}) + onClose := func(err error) { + errMu.Lock() + capturedErr = err + errMu.Unlock() + close(callbackCalled) + } + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, onClose) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + for i := 0; i < 5; i++ { + data := make([]byte, 100) + rand.Read(data) + pipeA.send(data, addr2) + } + + time.Sleep(50 * time.Millisecond) + + pipeA.Close() + close(pipeA.inChan) + + select { + case <-callbackCalled: + case <-time.After(5 * time.Second): + pipeB.Close() + t.Fatal("onClose callback was not called after source close") + } + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("Copy did not return after source close") + } + + pipeB.Close() + + errMu.Lock() + err := capturedErr + errMu.Unlock() + + require.NotNil(t, err) +} + +type legacyPacketConn struct { + N.NetPacketConn + targetAddr M.Socksaddr +} + +func (c *legacyPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + _, err = c.NetPacketConn.ReadPacket(buffer) + if err != nil { + return M.Socksaddr{}, err + } + return c.targetAddr, nil +} diff --git a/common/exceptions/error.go b/common/exceptions/error.go index 06624bb6..6fcdf1fd 100644 --- a/common/exceptions/error.go +++ b/common/exceptions/error.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "os" + "strings" "syscall" _ "unsafe" @@ -65,5 +66,13 @@ func IsClosed(err error) bool { } func IsCanceled(err error) bool { - return IsMulti(err, context.Canceled, context.DeadlineExceeded) + return IsMulti(err, context.Canceled, context.DeadlineExceeded) || isCanceledQuicLike(err) +} + +func isCanceledQuicLike(err error) bool { + if err == nil { + return false + } + s := err.Error() + return strings.Contains(s, "canceled by remote with error code 0") } diff --git a/common/network/read_notifier.go b/common/network/read_notifier.go new file mode 100644 index 00000000..3a693b15 --- /dev/null +++ 
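The IsCanceled change above additionally treats a QUIC-style close message as a benign cancellation, by substring match rather than by error type. A usage sketch, self-contained under the assumption that the exceptions package is imported as E:

```go
package main

import (
	"errors"
	"fmt"

	E "github.com/sagernet/sing/common/exceptions"
)

func main() {
	// The match survives wrapping because isCanceledQuicLike inspects the
	// final Error() string rather than unwrapping to a concrete type.
	err := fmt.Errorf("read stream: %w",
		errors.New("stream canceled by remote with error code 0"))
	fmt.Println(E.IsCanceled(err))                              // true
	fmt.Println(E.IsCanceled(errors.New("connection refused"))) // false
}
```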
b/common/network/read_notifier.go @@ -0,0 +1,21 @@ +package network + +type ReadNotifier interface { + isReadNotifier() +} + +type ChannelNotifier struct { + Channel <-chan *PacketBuffer +} + +func (*ChannelNotifier) isReadNotifier() {} + +type FileDescriptorNotifier struct { + FD int +} + +func (*FileDescriptorNotifier) isReadNotifier() {} + +type ReadNotifierSource interface { + CreateReadNotifier() ReadNotifier +} diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 3c0cda38..81710bfc 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -140,3 +140,7 @@ func (c *natConn) SetWriteDeadline(t time.Time) error { func (c *natConn) Upstream() any { return c.writer } + +func (c *natConn) CreateReadNotifier() N.ReadNotifier { + return &N.ChannelNotifier{Channel: c.packetChan} +} diff --git a/common/wepoll/afd_windows.go b/common/wepoll/afd_windows.go new file mode 100644 index 00000000..aad3391a --- /dev/null +++ b/common/wepoll/afd_windows.go @@ -0,0 +1,122 @@ +//go:build windows + +package wepoll + +import ( + "math" + "unsafe" + + "golang.org/x/sys/windows" +) + +type AFD struct { + handle windows.Handle +} + +func NewAFD(iocp windows.Handle, name string) (*AFD, error) { + deviceName := `\Device\Afd\` + name + deviceNameUTF16, err := windows.UTF16FromString(deviceName) + if err != nil { + return nil, err + } + + unicodeString := UnicodeString{ + Length: uint16(len(deviceName) * 2), + MaximumLength: uint16(len(deviceName) * 2), + Buffer: &deviceNameUTF16[0], + } + + objectAttributes := ObjectAttributes{ + Length: uint32(unsafe.Sizeof(ObjectAttributes{})), + ObjectName: &unicodeString, + Attributes: OBJ_CASE_INSENSITIVE, + } + + var handle windows.Handle + var ioStatusBlock windows.IO_STATUS_BLOCK + + err = NtCreateFile( + &handle, + windows.SYNCHRONIZE, + &objectAttributes, + &ioStatusBlock, + nil, + 0, + windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, + FILE_OPEN, + 0, + 0, + 0, + ) + if err != nil { + return nil, err + } + + _, err = windows.CreateIoCompletionPort(handle, iocp, 0, 0) + if err != nil { + windows.CloseHandle(handle) + return nil, err + } + + err = windows.SetFileCompletionNotificationModes(handle, windows.FILE_SKIP_SET_EVENT_ON_HANDLE) + if err != nil { + windows.CloseHandle(handle) + return nil, err + } + + return &AFD{handle: handle}, nil +} + +func (a *AFD) Poll(baseSocket windows.Handle, events uint32, iosb *windows.IO_STATUS_BLOCK, pollInfo *AFDPollInfo) error { + pollInfo.Timeout = math.MaxInt64 + pollInfo.NumberOfHandles = 1 + pollInfo.Exclusive = 0 + pollInfo.Handles[0].Handle = baseSocket + pollInfo.Handles[0].Events = events + pollInfo.Handles[0].Status = 0 + + size := uint32(unsafe.Sizeof(*pollInfo)) + + err := NtDeviceIoControlFile( + a.handle, + 0, + 0, + uintptr(unsafe.Pointer(iosb)), + iosb, + IOCTL_AFD_POLL, + unsafe.Pointer(pollInfo), + size, + unsafe.Pointer(pollInfo), + size, + ) + if err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + if uint32(ntstatus) == STATUS_PENDING { + return nil + } + } + return err + } + return nil +} + +func (a *AFD) Cancel(ioStatusBlock *windows.IO_STATUS_BLOCK) error { + if uint32(ioStatusBlock.Status) != STATUS_PENDING { + return nil + } + var cancelIOStatusBlock windows.IO_STATUS_BLOCK + err := NtCancelIoFileEx(a.handle, ioStatusBlock, &cancelIOStatusBlock) + if err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + if uint32(ntstatus) == STATUS_CANCELLED || uint32(ntstatus) == STATUS_NOT_FOUND { + return nil + } + } + return err + } + return nil +} + +func (a *AFD) 
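read_notifier.go defines the capability negotiation behind the reactor: channel-backed connections (such as natConn above) hand out a ChannelNotifier, socket-backed ones a FileDescriptorNotifier, and anything else stays on the legacy goroutine-per-connection copy. A sketch of the dispatch a caller might perform; the function name, demultiplexer variables, and fallback are assumptions about the surrounding reactor, not code from this diff:

```go
// Assumed reactor-level demultiplexers (package bufio scope).
var (
	channelDemux *ChannelDemultiplexer
	fdDemux      *FDDemultiplexer
)

// registerReader picks a demultiplexer by type-switching over the
// connection's N.ReadNotifier.
func registerReader(stream *reactorStream, reader N.PacketReader) {
	source, ok := reader.(N.ReadNotifierSource)
	if !ok {
		// No notifier available: fall back to a dedicated read goroutine.
		return
	}
	switch notifier := source.CreateReadNotifier().(type) {
	case *N.ChannelNotifier:
		channelDemux.Add(stream, notifier.Channel)
	case *N.FileDescriptorNotifier:
		if err := fdDemux.Add(stream, notifier.FD); err != nil {
			// Registration failed; a real reactor would fall back to the
			// legacy read-goroutine path here.
		}
	}
}
```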
Close() error { + return windows.CloseHandle(a.handle) +} diff --git a/common/wepoll/pinner.go b/common/wepoll/pinner.go new file mode 100644 index 00000000..58b76686 --- /dev/null +++ b/common/wepoll/pinner.go @@ -0,0 +1,7 @@ +//go:build go1.21 + +package wepoll + +import "runtime" + +type Pinner = runtime.Pinner diff --git a/common/wepoll/pinner_compat.go b/common/wepoll/pinner_compat.go new file mode 100644 index 00000000..a51a9fa6 --- /dev/null +++ b/common/wepoll/pinner_compat.go @@ -0,0 +1,9 @@ +//go:build !go1.21 + +package wepoll + +type Pinner struct{} + +func (p *Pinner) Pin(pointer any) {} + +func (p *Pinner) Unpin() {} diff --git a/common/wepoll/socket_windows.go b/common/wepoll/socket_windows.go new file mode 100644 index 00000000..e5655990 --- /dev/null +++ b/common/wepoll/socket_windows.go @@ -0,0 +1,49 @@ +//go:build windows + +package wepoll + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +func GetBaseSocket(socket windows.Handle) (windows.Handle, error) { + var baseSocket windows.Handle + var bytesReturned uint32 + + for { + err := windows.WSAIoctl( + socket, + SIO_BASE_HANDLE, + nil, + 0, + (*byte)(unsafe.Pointer(&baseSocket)), + uint32(unsafe.Sizeof(baseSocket)), + &bytesReturned, + nil, + 0, + ) + if err != nil { + err = windows.WSAIoctl( + socket, + SIO_BSP_HANDLE_POLL, + nil, + 0, + (*byte)(unsafe.Pointer(&baseSocket)), + uint32(unsafe.Sizeof(baseSocket)), + &bytesReturned, + nil, + 0, + ) + if err != nil { + return socket, nil + } + } + + if baseSocket == socket { + return baseSocket, nil + } + socket = baseSocket + } +} diff --git a/common/wepoll/syscall_windows.go b/common/wepoll/syscall_windows.go new file mode 100644 index 00000000..948dd00e --- /dev/null +++ b/common/wepoll/syscall_windows.go @@ -0,0 +1,8 @@ +package wepoll + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go + +//sys NtCreateFile(handle *windows.Handle, access uint32, oa *ObjectAttributes, iosb *windows.IO_STATUS_BLOCK, allocationSize *int64, attributes uint32, share uint32, disposition uint32, options uint32, eaBuffer uintptr, eaLength uint32) (ntstatus error) = ntdll.NtCreateFile +//sys NtDeviceIoControlFile(handle windows.Handle, event windows.Handle, apcRoutine uintptr, apcContext uintptr, ioStatusBlock *windows.IO_STATUS_BLOCK, ioControlCode uint32, inputBuffer unsafe.Pointer, inputBufferLength uint32, outputBuffer unsafe.Pointer, outputBufferLength uint32) (ntstatus error) = ntdll.NtDeviceIoControlFile +//sys NtCancelIoFileEx(handle windows.Handle, ioRequestToCancel *windows.IO_STATUS_BLOCK, ioStatusBlock *windows.IO_STATUS_BLOCK) (ntstatus error) = ntdll.NtCancelIoFileEx +//sys GetQueuedCompletionStatusEx(cphandle windows.Handle, entries *OverlappedEntry, count uint32, numRemoved *uint32, timeout uint32, alertable bool) (err error) = kernel32.GetQueuedCompletionStatusEx diff --git a/common/wepoll/types_windows.go b/common/wepoll/types_windows.go new file mode 100644 index 00000000..aad2d79a --- /dev/null +++ b/common/wepoll/types_windows.go @@ -0,0 +1,64 @@ +//go:build windows + +package wepoll + +import "golang.org/x/sys/windows" + +const ( + IOCTL_AFD_POLL = 0x00012024 + + AFD_POLL_RECEIVE = 0x0001 + AFD_POLL_RECEIVE_EXPEDITED = 0x0002 + AFD_POLL_SEND = 0x0004 + AFD_POLL_DISCONNECT = 0x0008 + AFD_POLL_ABORT = 0x0010 + AFD_POLL_LOCAL_CLOSE = 0x0020 + AFD_POLL_ACCEPT = 0x0080 + AFD_POLL_CONNECT_FAIL = 0x0100 + + SIO_BASE_HANDLE = 0x48000022 + SIO_BSP_HANDLE_POLL = 0x4800001D + + STATUS_PENDING = 0x00000103 + 
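The Pinner indirection exists because AFD completes polls asynchronously: once Poll returns with STATUS_PENDING, the kernel still owns the IO_STATUS_BLOCK and AFDPollInfo and will write into them later, so both must stay at a stable address until the completion is reaped. This is the usage pattern from wepoll_test.go, excerpted as it appears inside a test body:

```go
// Keep the kernel-owned buffers in one struct and pin it for the
// lifetime of the outstanding poll. On Go < 1.21 the compat Pinner is a
// no-op, which is only safe because the current Go GC does not move
// heap objects.
var state struct {
	iosb     windows.IO_STATUS_BLOCK
	pollInfo AFDPollInfo
}
var pinner Pinner
pinner.Pin(&state)
defer pinner.Unpin()

if err := afd.Poll(baseSocket, uint32(AFD_POLL_RECEIVE), &state.iosb, &state.pollInfo); err != nil {
	panic(err) // illustrative
}
// Reap the completion via GetQueuedCompletionStatusEx (or Cancel the
// poll) before the deferred Unpin runs.
```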
STATUS_CANCELLED = 0xC0000120 + STATUS_NOT_FOUND = 0xC0000225 + + FILE_OPEN = 0x00000001 + + OBJ_CASE_INSENSITIVE = 0x00000040 +) + +type AFDPollHandleInfo struct { + Handle windows.Handle + Events uint32 + Status uint32 +} + +type AFDPollInfo struct { + Timeout int64 + NumberOfHandles uint32 + Exclusive uint32 + Handles [1]AFDPollHandleInfo +} + +type OverlappedEntry struct { + CompletionKey uintptr + Overlapped *windows.Overlapped + Internal uintptr + NumberOfBytesTransferred uint32 +} + +type UnicodeString struct { + Length uint16 + MaximumLength uint16 + Buffer *uint16 +} + +type ObjectAttributes struct { + Length uint32 + RootDirectory windows.Handle + ObjectName *UnicodeString + Attributes uint32 + SecurityDescriptor uintptr + SecurityQualityOfService uintptr +} diff --git a/common/wepoll/wepoll_test.go b/common/wepoll/wepoll_test.go new file mode 100644 index 00000000..0a0ec023 --- /dev/null +++ b/common/wepoll/wepoll_test.go @@ -0,0 +1,335 @@ +//go:build windows + +package wepoll + +import ( + "net" + "syscall" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows" +) + +func createTestIOCP(t *testing.T) windows.Handle { + iocp, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + require.NoError(t, err) + t.Cleanup(func() { + windows.CloseHandle(iocp) + }) + return iocp +} + +func getSocketHandle(t *testing.T, conn net.PacketConn) windows.Handle { + syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok) + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd uintptr + err = rawConn.Control(func(f uintptr) { fd = f }) + require.NoError(t, err) + return windows.Handle(fd) +} + +func getTCPSocketHandle(t *testing.T, conn net.Conn) windows.Handle { + syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok) + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd uintptr + err = rawConn.Control(func(f uintptr) { fd = f }) + require.NoError(t, err) + return windows.Handle(fd) +} + +func TestNewAFD(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "test") + require.NoError(t, err) + require.NotNil(t, afd) + + err = afd.Close() + require.NoError(t, err) +} + +func TestNewAFD_MultipleTimes(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd1, err := NewAFD(iocp, "test1") + require.NoError(t, err) + defer afd1.Close() + + afd2, err := NewAFD(iocp, "test2") + require.NoError(t, err) + defer afd2.Close() + + afd3, err := NewAFD(iocp, "test3") + require.NoError(t, err) + defer afd3.Close() +} + +func TestGetBaseSocket_UDP(t *testing.T) { + t.Parallel() + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + handle := getSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + require.NotEqual(t, windows.InvalidHandle, baseHandle) +} + +func TestGetBaseSocket_TCP(t *testing.T) { + t.Parallel() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + go func() { + conn, err := listener.Accept() + if err == nil { + conn.Close() + } + }() + + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + + handle := getTCPSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + require.NotEqual(t, windows.InvalidHandle, baseHandle) +} + +func TestAFD_Poll_ReceiveEvent(t *testing.T) { + t.Parallel() + + iocp := 
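getSocketHandle and getTCPSocketHandle above differ only in their parameter type; both just unwrap syscall.Conn and capture the raw descriptor. A single helper would cover both, shown here as an illustrative refactor rather than a change the diff makes:

```go
// rawHandle extracts the underlying socket handle from any net
// connection that exposes syscall.Conn, covering both the UDP and TCP
// cases above.
func rawHandle(t *testing.T, conn any) windows.Handle {
	syscallConn, ok := conn.(syscall.Conn)
	require.True(t, ok)
	rawConn, err := syscallConn.SyscallConn()
	require.NoError(t, err)
	var fd uintptr
	err = rawConn.Control(func(f uintptr) { fd = f })
	require.NoError(t, err)
	return windows.Handle(fd)
}
```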
createTestIOCP(t) + + afd, err := NewAFD(iocp, "poll_test") + require.NoError(t, err) + defer afd.Close() + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + handle := getSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + + var state struct { + iosb windows.IO_STATUS_BLOCK + pollInfo AFDPollInfo + } + + var pinner Pinner + pinner.Pin(&state) + defer pinner.Unpin() + + events := uint32(AFD_POLL_RECEIVE | AFD_POLL_DISCONNECT | AFD_POLL_ABORT) + err = afd.Poll(baseHandle, events, &state.iosb, &state.pollInfo) + require.NoError(t, err) + + sender, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer sender.Close() + + _, err = sender.WriteTo([]byte("test data"), conn.LocalAddr()) + require.NoError(t, err) + + entries := make([]OverlappedEntry, 1) + var numRemoved uint32 + err = GetQueuedCompletionStatusEx(iocp, &entries[0], 1, &numRemoved, 5000, false) + require.NoError(t, err) + require.Equal(t, uint32(1), numRemoved) + require.Equal(t, uintptr(unsafe.Pointer(&state.iosb)), uintptr(unsafe.Pointer(entries[0].Overlapped))) +} + +func TestAFD_Cancel(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "cancel_test") + require.NoError(t, err) + defer afd.Close() + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + handle := getSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + + var state struct { + iosb windows.IO_STATUS_BLOCK + pollInfo AFDPollInfo + } + + var pinner Pinner + pinner.Pin(&state) + defer pinner.Unpin() + + events := uint32(AFD_POLL_RECEIVE) + err = afd.Poll(baseHandle, events, &state.iosb, &state.pollInfo) + require.NoError(t, err) + + err = afd.Cancel(&state.iosb) + require.NoError(t, err) + + entries := make([]OverlappedEntry, 1) + var numRemoved uint32 + err = GetQueuedCompletionStatusEx(iocp, &entries[0], 1, &numRemoved, 1000, false) + require.NoError(t, err) + require.Equal(t, uint32(1), numRemoved) +} + +func TestAFD_Close(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "close_test") + require.NoError(t, err) + + err = afd.Close() + require.NoError(t, err) +} + +func TestGetQueuedCompletionStatusEx_Timeout(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + entries := make([]OverlappedEntry, 1) + var numRemoved uint32 + + start := time.Now() + err := GetQueuedCompletionStatusEx(iocp, &entries[0], 1, &numRemoved, 100, false) + elapsed := time.Since(start) + + require.Error(t, err) + require.GreaterOrEqual(t, elapsed, 50*time.Millisecond) +} + +func TestGetQueuedCompletionStatusEx_MultipleEntries(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "multi_test") + require.NoError(t, err) + defer afd.Close() + + const numConns = 3 + conns := make([]net.PacketConn, numConns) + states := make([]struct { + iosb windows.IO_STATUS_BLOCK + pollInfo AFDPollInfo + }, numConns) + pinners := make([]Pinner, numConns) + + for i := 0; i < numConns; i++ { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + conns[i] = conn + + handle := getSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + + pinners[i].Pin(&states[i]) + defer pinners[i].Unpin() + + events := uint32(AFD_POLL_RECEIVE) + err = afd.Poll(baseHandle, events, &states[i].iosb, &states[i].pollInfo) + 
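TestAFD_Poll_ReceiveEvent verifies that the completion's Overlapped pointer equals the submitted IO_STATUS_BLOCK: AFD reuses the status block as the OVERLAPPED. That pointer identity is the natural dispatch key for an event loop, and if the status block is the first field of a per-socket struct, each completion casts straight back to its owner. A sketch with hypothetical pollState and onReadable names:

```go
// pollState must begin with the IO_STATUS_BLOCK so that the pointer AFD
// returns via OverlappedEntry.Overlapped is also a pointer to the whole
// struct.
type pollState struct {
	iosb       windows.IO_STATUS_BLOCK
	pollInfo   AFDPollInfo
	onReadable func()
}

// dispatch maps each reaped completion back to its poll state.
func dispatch(entries []OverlappedEntry, numRemoved uint32) {
	for i := uint32(0); i < numRemoved; i++ {
		state := (*pollState)(unsafe.Pointer(entries[i].Overlapped))
		state.onReadable()
	}
}
```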
require.NoError(t, err) + } + + sender, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer sender.Close() + + for i := 0; i < numConns; i++ { + _, err = sender.WriteTo([]byte("test"), conns[i].LocalAddr()) + require.NoError(t, err) + } + + entries := make([]OverlappedEntry, 8) + var numRemoved uint32 + received := 0 + for received < numConns { + err = GetQueuedCompletionStatusEx(iocp, &entries[0], 8, &numRemoved, 5000, false) + require.NoError(t, err) + received += int(numRemoved) + } + require.Equal(t, numConns, received) +} + +func TestAFD_Poll_DisconnectEvent(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "disconnect_test") + require.NoError(t, err) + defer afd.Close() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + conn, err := listener.Accept() + if err != nil { + return + } + time.Sleep(100 * time.Millisecond) + conn.Close() + }() + + client, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer client.Close() + + handle := getTCPSocketHandle(t, client) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + + var state struct { + iosb windows.IO_STATUS_BLOCK + pollInfo AFDPollInfo + } + + var pinner Pinner + pinner.Pin(&state) + defer pinner.Unpin() + + events := uint32(AFD_POLL_RECEIVE | AFD_POLL_DISCONNECT | AFD_POLL_ABORT) + err = afd.Poll(baseHandle, events, &state.iosb, &state.pollInfo) + require.NoError(t, err) + + entries := make([]OverlappedEntry, 1) + var numRemoved uint32 + err = GetQueuedCompletionStatusEx(iocp, &entries[0], 1, &numRemoved, 5000, false) + require.NoError(t, err) + require.Equal(t, uint32(1), numRemoved) + + <-serverDone +} diff --git a/common/wepoll/zsyscall_windows.go b/common/wepoll/zsyscall_windows.go new file mode 100644 index 00000000..dac75d17 --- /dev/null +++ b/common/wepoll/zsyscall_windows.go @@ -0,0 +1,84 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package wepoll + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) 
+ return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + modntdll = windows.NewLazySystemDLL("ntdll.dll") + + procGetQueuedCompletionStatusEx = modkernel32.NewProc("GetQueuedCompletionStatusEx") + procNtCancelIoFileEx = modntdll.NewProc("NtCancelIoFileEx") + procNtCreateFile = modntdll.NewProc("NtCreateFile") + procNtDeviceIoControlFile = modntdll.NewProc("NtDeviceIoControlFile") +) + +func GetQueuedCompletionStatusEx(cphandle windows.Handle, entries *OverlappedEntry, count uint32, numRemoved *uint32, timeout uint32, alertable bool) (err error) { + var _p0 uint32 + if alertable { + _p0 = 1 + } + r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatusEx.Addr(), 6, uintptr(cphandle), uintptr(unsafe.Pointer(entries)), uintptr(count), uintptr(unsafe.Pointer(numRemoved)), uintptr(timeout), uintptr(_p0)) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func NtCancelIoFileEx(handle windows.Handle, ioRequestToCancel *windows.IO_STATUS_BLOCK, ioStatusBlock *windows.IO_STATUS_BLOCK) (ntstatus error) { + r0, _, _ := syscall.Syscall(procNtCancelIoFileEx.Addr(), 3, uintptr(handle), uintptr(unsafe.Pointer(ioRequestToCancel)), uintptr(unsafe.Pointer(ioStatusBlock))) + if r0 != 0 { + ntstatus = windows.NTStatus(r0) + } + return +} + +func NtCreateFile(handle *windows.Handle, access uint32, oa *ObjectAttributes, iosb *windows.IO_STATUS_BLOCK, allocationSize *int64, attributes uint32, share uint32, disposition uint32, options uint32, eaBuffer uintptr, eaLength uint32) (ntstatus error) { + r0, _, _ := syscall.Syscall12(procNtCreateFile.Addr(), 11, uintptr(unsafe.Pointer(handle)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(unsafe.Pointer(allocationSize)), uintptr(attributes), uintptr(share), uintptr(disposition), uintptr(options), uintptr(eaBuffer), uintptr(eaLength), 0) + if r0 != 0 { + ntstatus = windows.NTStatus(r0) + } + return +} + +func NtDeviceIoControlFile(handle windows.Handle, event windows.Handle, apcRoutine uintptr, apcContext uintptr, ioStatusBlock *windows.IO_STATUS_BLOCK, ioControlCode uint32, inputBuffer unsafe.Pointer, inputBufferLength uint32, outputBuffer unsafe.Pointer, outputBufferLength uint32) (ntstatus error) { + r0, _, _ := syscall.Syscall12(procNtDeviceIoControlFile.Addr(), 10, uintptr(handle), uintptr(event), uintptr(apcRoutine), uintptr(apcContext), uintptr(unsafe.Pointer(ioStatusBlock)), uintptr(ioControlCode), uintptr(inputBuffer), uintptr(inputBufferLength), uintptr(outputBuffer), uintptr(outputBufferLength), 0, 0) + if r0 != 0 { + ntstatus = windows.NTStatus(r0) + } + return +}