Skip to content

Commit 91d05d2

Browse files
committed
Improve packet reactor
1 parent 5820f0e commit 91d05d2

File tree

6 files changed

+150
-279
lines changed

6 files changed

+150
-279
lines changed

common/bufio/channel_poller.go

Lines changed: 0 additions & 143 deletions
This file was deleted.

common/bufio/packet_reactor.go

Lines changed: 56 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,18 @@ const (
2424
)
2525

2626
type PacketReactor struct {
27-
ctx context.Context
28-
cancel context.CancelFunc
29-
channelPoller *ChannelPoller
30-
fdPoller *FDPoller
31-
fdPollerOnce sync.Once
32-
fdPollerErr error
27+
ctx context.Context
28+
cancel context.CancelFunc
29+
fdPoller *FDPoller
30+
fdPollerOnce sync.Once
31+
fdPollerErr error
3332
}
3433

3534
func NewPacketReactor(ctx context.Context) *PacketReactor {
3635
ctx, cancel := context.WithCancel(ctx)
3736
return &PacketReactor{
38-
ctx: ctx,
39-
cancel: cancel,
40-
channelPoller: NewChannelPoller(ctx),
37+
ctx: ctx,
38+
cancel: cancel,
4139
}
4240
}
4341

@@ -50,14 +48,10 @@ func (r *PacketReactor) getFDPoller() (*FDPoller, error) {
5048

5149
func (r *PacketReactor) Close() error {
5250
r.cancel()
53-
var errs []error
54-
if r.channelPoller != nil {
55-
errs = append(errs, r.channelPoller.Close())
56-
}
5751
if r.fdPoller != nil {
58-
errs = append(errs, r.fdPoller.Close())
52+
return r.fdPoller.Close()
5953
}
60-
return E.Errors(errs...)
54+
return nil
6155
}
6256

6357
type reactorConnection struct {
@@ -80,6 +74,7 @@ type reactorStream struct {
8074
destination N.PacketWriter
8175
originSource N.PacketReader
8276

77+
pushable N.PacketPushable
8378
pollable N.PacketPollable
8479
options N.ReadWaitOptions
8580
readWaiter N.PacketReadWaiter
@@ -159,7 +154,9 @@ func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketRe
159154
stream.readWaiter.InitializeReadWaiter(stream.options)
160155
}
161156

162-
if pollable, ok := source.(N.PacketPollable); ok {
157+
if pushable, ok := source.(N.PacketPushable); ok {
158+
stream.pushable = pushable
159+
} else if pollable, ok := source.(N.PacketPollable); ok {
163160
stream.pollable = pollable
164161
} else if creator, ok := source.(N.PacketPollableCreator); ok {
165162
stream.pollable, _ = creator.CreatePacketPollable()
@@ -169,25 +166,32 @@ func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketRe
169166
}
170167

171168
func (r *PacketReactor) registerStream(stream *reactorStream) {
169+
if stream.pushable != nil {
170+
stream.pushable.SetOnDataReady(func() {
171+
if stream.state.CompareAndSwap(stateIdle, stateActive) {
172+
go stream.runActiveLoop(nil)
173+
}
174+
})
175+
if stream.pushable.HasPendingData() {
176+
if stream.state.CompareAndSwap(stateIdle, stateActive) {
177+
go stream.runActiveLoop(nil)
178+
}
179+
}
180+
return
181+
}
182+
172183
if stream.pollable == nil {
173184
go stream.runLegacyCopy()
174185
return
175186
}
176187

177-
switch stream.pollable.PollMode() {
178-
case N.PacketPollModeChannel:
179-
r.channelPoller.Add(stream, stream.pollable.PacketChannel())
180-
case N.PacketPollModeFD:
181-
fdPoller, err := r.getFDPoller()
182-
if err != nil {
183-
go stream.runLegacyCopy()
184-
return
185-
}
186-
err = fdPoller.Add(stream, stream.pollable.FD())
187-
if err != nil {
188-
go stream.runLegacyCopy()
189-
}
190-
default:
188+
fdPoller, err := r.getFDPoller()
189+
if err != nil {
190+
go stream.runLegacyCopy()
191+
return
192+
}
193+
err = fdPoller.Add(stream, stream.pollable.FD())
194+
if err != nil {
191195
go stream.runLegacyCopy()
192196
}
193197
}
@@ -259,9 +263,18 @@ func (s *reactorStream) runActiveLoop(firstPacket *N.PacketBuffer) {
259263
if setter, ok := s.source.(interface{ SetReadDeadline(time.Time) error }); ok {
260264
setter.SetReadDeadline(time.Time{})
261265
}
262-
if s.state.CompareAndSwap(stateActive, stateIdle) {
263-
s.returnToPool()
266+
if !s.state.CompareAndSwap(stateActive, stateIdle) {
267+
return
268+
}
269+
if s.pushable != nil {
270+
if s.pushable.HasPendingData() {
271+
if s.state.CompareAndSwap(stateIdle, stateActive) {
272+
continue
273+
}
274+
}
275+
return
264276
}
277+
s.returnToPool()
265278
return
266279
}
267280
if !notFirstTime {
@@ -310,30 +323,18 @@ func (s *reactorStream) returnToPool() {
310323
return
311324
}
312325

313-
if s.pollable == nil {
326+
if s.pollable == nil || s.connection.reactor.fdPoller == nil {
314327
return
315328
}
316329

317-
switch s.pollable.PollMode() {
318-
case N.PacketPollModeChannel:
319-
channel := s.pollable.PacketChannel()
320-
s.connection.reactor.channelPoller.Add(s, channel)
321-
if s.state.Load() != stateIdle {
322-
s.connection.reactor.channelPoller.Remove(channel)
323-
}
324-
case N.PacketPollModeFD:
325-
if s.connection.reactor.fdPoller == nil {
326-
return
327-
}
328-
fd := s.pollable.FD()
329-
err := s.connection.reactor.fdPoller.Add(s, fd)
330-
if err != nil {
331-
s.closeWithError(err)
332-
return
333-
}
334-
if s.state.Load() != stateIdle {
335-
s.connection.reactor.fdPoller.Remove(fd)
336-
}
330+
fd := s.pollable.FD()
331+
err := s.connection.reactor.fdPoller.Add(s, fd)
332+
if err != nil {
333+
s.closeWithError(err)
334+
return
335+
}
336+
if s.state.Load() != stateIdle {
337+
s.connection.reactor.fdPoller.Remove(fd)
337338
}
338339
}
339340

@@ -385,15 +386,8 @@ func (c *reactorConnection) removeFromPollers() {
385386
}
386387

387388
func (c *reactorConnection) removeStreamFromPoller(stream *reactorStream) {
388-
if stream == nil || stream.pollable == nil {
389+
if stream == nil || stream.pollable == nil || c.reactor.fdPoller == nil {
389390
return
390391
}
391-
switch stream.pollable.PollMode() {
392-
case N.PacketPollModeChannel:
393-
c.reactor.channelPoller.Remove(stream.pollable.PacketChannel())
394-
case N.PacketPollModeFD:
395-
if c.reactor.fdPoller != nil {
396-
c.reactor.fdPoller.Remove(stream.pollable.FD())
397-
}
398-
}
392+
c.reactor.fdPoller.Remove(stream.pollable.FD())
399393
}

common/bufio/packet_reactor_test.go

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,6 @@ func (p *testPacketPipe) SetWriteDeadline(t time.Time) error {
103103
return nil
104104
}
105105

106-
func (p *testPacketPipe) PacketChannel() <-chan *N.PacketBuffer {
107-
return p.inChan
108-
}
109-
110106
func (p *testPacketPipe) send(data []byte, destination M.Socksaddr) {
111107
packet := N.NewPacketBuffer()
112108
newBuf := buf.NewSize(len(data))
@@ -255,10 +251,6 @@ func (c *channelPacketConn) SetReadDeadline(t time.Time) error {
255251
return nil
256252
}
257253

258-
func (c *channelPacketConn) PacketChannel() <-chan *N.PacketBuffer {
259-
return c.packetChan
260-
}
261-
262254
func (c *channelPacketConn) Close() error {
263255
c.closeOnce.Do(func() {
264256
close(c.done)
@@ -551,7 +543,7 @@ func TestBatchCopy_FDPoller_DataIntegrity(t *testing.T) {
551543
assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch")
552544
}
553545

554-
func TestBatchCopy_ChannelPoller_DataIntegrity(t *testing.T) {
546+
func TestBatchCopy_LegacyChannel_DataIntegrity(t *testing.T) {
555547
t.Parallel()
556548

557549
clientConn, err := net.ListenPacket("udp", "127.0.0.1:0")

common/network/packet_pollable.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
package network
22

3-
type PacketPollMode int
4-
5-
const (
6-
PacketPollModeChannel PacketPollMode = iota
7-
PacketPollModeFD
8-
)
3+
// PacketPushable represents a packet source that receives pushed data
4+
// from external code and notifies reactor via callback.
5+
type PacketPushable interface {
6+
SetOnDataReady(callback func())
7+
HasPendingData() bool
8+
}
99

10-
// PacketPollable provides polling support for packet connections
10+
// PacketPollable provides FD-based polling for packet connections.
11+
// Mirrors StreamPollable for consistency.
1112
type PacketPollable interface {
12-
PollMode() PacketPollMode
13-
PacketChannel() <-chan *PacketBuffer
1413
FD() int
1514
}
1615

17-
// PacketPollableCreator creates a PacketPollable dynamically
16+
// PacketPollableCreator creates a PacketPollable dynamically.
1817
type PacketPollableCreator interface {
1918
CreatePacketPollable() (PacketPollable, bool)
2019
}

0 commit comments

Comments
 (0)