diff --git a/common/buf/copy.go b/common/buf/copy.go index 4cc3be881d88..9ec5b50e6142 100644 --- a/common/buf/copy.go +++ b/common/buf/copy.go @@ -2,10 +2,12 @@ package buf import ( "io" + "sync" "time" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/signal" + "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/stats" ) @@ -113,7 +115,12 @@ func Copy(reader Reader, writer Writer, options ...CopyOption) error { for _, option := range options { option(&handler) } - err := copyInternal(reader, writer, &handler) + var err error + if sReader, ok := reader.(*SingleReader); ok && false { + err = copyV(sReader, writer, &handler) + } else { + err = copyInternal(reader, writer, &handler) + } if err != nil && errors.Cause(err) != io.EOF { return err } @@ -133,3 +140,92 @@ func CopyOnceTimeout(reader Reader, writer Writer, timeout time.Duration) error } return writer.WriteMultiBuffer(mb) } + +func copyV(r *SingleReader, w Writer, handler *copyHandler) error { + // channel buffer size is maxBuffer/maxPerPacketLen (ignore the case of many small packets) + // default buffer size: + // 0 in ARM MIPS MIPSLE + // 4kb in ARM64 MIPS64 MIPS64LE + // 512kb in others + channelBuffer := (policy.SessionDefault().Buffer.PerConnection) / Size + if channelBuffer <= 0 { + channelBuffer = 4 + } + cache := make(chan *Buffer, channelBuffer) + stopRead := make(chan struct{}) + var rErr error + var wErr error + wg := sync.WaitGroup{} + wg.Add(2) + // downlink + go func() { + defer wg.Done() + defer close(cache) + for { + b, err := r.readBuffer() + if err == nil { + select { + case cache <- b: + // must be write error + case <-stopRead: + b.Release() + return + } + } else { + rErr = err + select { + case cache <- b: + case <-stopRead: + b.Release() + } + return + } + } + }() + // uplink + go func() { + defer wg.Done() + for { + b, ok := <-cache + if !ok { + return + } + var buffers = []*Buffer{b} + for stop := false; !stop; { + select { + case b, ok := <-cache: + if !ok { + stop = true + continue + } + buffers = append(buffers, b) + default: + stop = true + } + } + mb := MultiBuffer(buffers) + err := w.WriteMultiBuffer(mb) + for _, handler := range handler.onData { + handler(mb) + } + ReleaseMulti(mb) + if err != nil { + wErr = err + close(stopRead) + return + } + } + }() + wg.Wait() + // drain cache + for b := range cache { + b.Release() + } + if wErr != nil { + return writeError{wErr} + } + if rErr != nil { + return readError{rErr} + } + return nil +} diff --git a/common/buf/reader.go b/common/buf/reader.go index 33d362d427af..ca00043b4785 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -159,6 +159,11 @@ func (r *SingleReader) ReadMultiBuffer() (MultiBuffer, error) { return MultiBuffer{b}, err } +func (r *SingleReader) readBuffer() (*Buffer, error) { + b, err := ReadBuffer(r.Reader) + return b, err +} + // PacketReader is a Reader that read one Buffer every time. type PacketReader struct { io.Reader diff --git a/transport/pipe/impl.go b/transport/pipe/impl.go index e5d678272ab9..81172906c656 100644 --- a/transport/pipe/impl.go +++ b/transport/pipe/impl.go @@ -3,7 +3,6 @@ package pipe import ( "errors" "io" - "runtime" "sync" "time" @@ -136,11 +135,10 @@ func (p *pipe) writeMultiBufferInternal(mb buf.MultiBuffer) error { if p.data == nil { p.data = mb - return nil + } else { + p.data, _ = buf.MergeMulti(p.data, mb) } - - p.data, _ = buf.MergeMulti(p.data, mb) - return errSlowDown + return nil } func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error { @@ -155,30 +153,23 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error { return nil } - if err == errSlowDown { - p.readSignal.Signal() - - // Yield current goroutine. Hopefully the reading counterpart can pick up the payload. - runtime.Gosched() - return nil - } - - if err == errBufferFull && p.option.discardOverflow { - buf.ReleaseMulti(mb) - return nil + if err == errBufferFull { + if p.option.discardOverflow { + buf.ReleaseMulti(mb) + return nil + } + select { + case <-p.writeSignal.Wait(): + continue + case <-p.done.Wait(): + buf.ReleaseMulti(mb) + return io.ErrClosedPipe + } } - if err != errBufferFull { - buf.ReleaseMulti(mb) - p.readSignal.Signal() - return err - } - - select { - case <-p.writeSignal.Wait(): - case <-p.done.Wait(): - return io.ErrClosedPipe - } + buf.ReleaseMulti(mb) + p.readSignal.Signal() + return err } }