66 "context"
77 "sync"
88 "sync/atomic"
9+ "unsafe"
910
1011 "github.com/sagernet/sing/common/wepoll"
1112
@@ -36,7 +37,7 @@ type FDDemultiplexer struct {
3637 mutex sync.Mutex
3738 entries map [int ]* fdDemuxEntry
3839 registrationCounter uint64
39- registrationToFD map [uint64 ]int
40+ iosbToFD map [uintptr ]int
4041 running bool
4142 closed atomic.Bool
4243 wg sync.WaitGroup
@@ -56,12 +57,12 @@ func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) {
5657
5758 ctx , cancel := context .WithCancel (ctx )
5859 demux := & FDDemultiplexer {
59- ctx : ctx ,
60- cancel : cancel ,
61- iocp : iocp ,
62- afd : afd ,
63- entries : make (map [int ]* fdDemuxEntry ),
64- registrationToFD : make (map [uint64 ]int ),
60+ ctx : ctx ,
61+ cancel : cancel ,
62+ iocp : iocp ,
63+ afd : afd ,
64+ entries : make (map [int ]* fdDemuxEntry ),
65+ iosbToFD : make (map [uintptr ]int ),
6566 }
6667 return demux , nil
6768}
@@ -94,14 +95,14 @@ func (p *FDDemultiplexer) Add(stream *reactorStream, fd int) error {
9495 entry .pinner .Pin (& entry .state )
9596
9697 events := uint32 (wepoll .AFD_POLL_RECEIVE | wepoll .AFD_POLL_DISCONNECT | wepoll .AFD_POLL_ABORT | wepoll .AFD_POLL_LOCAL_CLOSE )
97- err = p .afd .Poll (baseHandle , events , & entry .state .iosb , & entry .state .pollInfo , uintptr ( regID ) )
98+ err = p .afd .Poll (baseHandle , events , & entry .state .iosb , & entry .state .pollInfo )
9899 if err != nil {
99100 entry .pinner .Unpin ()
100101 return err
101102 }
102103
103104 p .entries [fd ] = entry
104- p .registrationToFD [ regID ] = fd
105+ p .iosbToFD [ uintptr ( unsafe . Pointer ( & entry . state . iosb )) ] = fd
105106
106107 if ! p .running {
107108 p .running = true
@@ -122,7 +123,9 @@ func (p *FDDemultiplexer) Remove(fd int) {
122123 }
123124
124125 entry .cancelled = true
125- p .afd .Cancel (& entry .state .iosb )
126+ if p .afd != nil {
127+ p .afd .Cancel (& entry .state .iosb )
128+ }
126129}
127130
128131func (p * FDDemultiplexer ) wakeup () {
@@ -178,27 +181,28 @@ func (p *FDDemultiplexer) run() {
178181
179182 for i := uint32 (0 ); i < numRemoved ; i ++ {
180183 ev := entries [i ]
181- regID := uint64 (ev .CompletionKey )
182184
183- if regID == 0 {
185+ if ev . Overlapped == nil {
184186 continue
185187 }
186188
189+ iosbPtr := uintptr (unsafe .Pointer (ev .Overlapped ))
190+
187191 p .mutex .Lock ()
188- fd , ok := p .registrationToFD [ regID ]
192+ fd , ok := p .iosbToFD [ iosbPtr ]
189193 if ! ok {
190194 p .mutex .Unlock ()
191195 continue
192196 }
193197
194198 entry := p .entries [fd ]
195- if entry == nil || entry . registrationID != regID {
199+ if entry == nil {
196200 p .mutex .Unlock ()
197201 continue
198202 }
199203
200204 entry .pinner .Unpin ()
201- delete (p .registrationToFD , regID )
205+ delete (p .iosbToFD , iosbPtr )
202206 delete (p .entries , fd )
203207
204208 if entry .cancelled {
0 commit comments