Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 97 additions & 1 deletion common/buf/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package buf

import (
"io"
"sync"
"time"

"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/stats"
)

Expand Down Expand Up @@ -113,7 +115,12 @@ func Copy(reader Reader, writer Writer, options ...CopyOption) error {
for _, option := range options {
option(&handler)
}
err := copyInternal(reader, writer, &handler)
var err error
if sReader, ok := reader.(*SingleReader); ok && false {
err = copyV(sReader, writer, &handler)
} else {
err = copyInternal(reader, writer, &handler)
}
if err != nil && errors.Cause(err) != io.EOF {
return err
}
Expand All @@ -133,3 +140,92 @@ func CopyOnceTimeout(reader Reader, writer Writer, timeout time.Duration) error
}
return writer.WriteMultiBuffer(mb)
}

func copyV(r *SingleReader, w Writer, handler *copyHandler) error {
// channel buffer size is maxBuffer/maxPerPacketLen (ignore the case of many small packets)
// default buffer size:
// 0 in ARM MIPS MIPSLE
// 4kb in ARM64 MIPS64 MIPS64LE
// 512kb in others
channelBuffer := (policy.SessionDefault().Buffer.PerConnection) / Size
if channelBuffer <= 0 {
channelBuffer = 4
}
cache := make(chan *Buffer, channelBuffer)
stopRead := make(chan struct{})
var rErr error
var wErr error
wg := sync.WaitGroup{}
wg.Add(2)
// downlink
go func() {
defer wg.Done()
defer close(cache)
for {
b, err := r.readBuffer()
if err == nil {
select {
case cache <- b:
// must be write error
case <-stopRead:
b.Release()
return
}
} else {
rErr = err
select {
case cache <- b:
case <-stopRead:
b.Release()
}
return
}
}
}()
// uplink
go func() {
defer wg.Done()
for {
b, ok := <-cache
if !ok {
return
}
var buffers = []*Buffer{b}
for stop := false; !stop; {
select {
case b, ok := <-cache:
if !ok {
stop = true
continue
}
buffers = append(buffers, b)
default:
stop = true
}
}
mb := MultiBuffer(buffers)
err := w.WriteMultiBuffer(mb)
for _, handler := range handler.onData {
handler(mb)
}
ReleaseMulti(mb)
if err != nil {
wErr = err
close(stopRead)
return
}
}
}()
wg.Wait()
// drain cache
for b := range cache {
b.Release()
}
if wErr != nil {
return writeError{wErr}
}
if rErr != nil {
return readError{rErr}
}
return nil
}
5 changes: 5 additions & 0 deletions common/buf/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ func (r *SingleReader) ReadMultiBuffer() (MultiBuffer, error) {
return MultiBuffer{b}, err
}

func (r *SingleReader) readBuffer() (*Buffer, error) {
b, err := ReadBuffer(r.Reader)
return b, err
}

// PacketReader is a Reader that read one Buffer every time.
type PacketReader struct {
io.Reader
Expand Down
45 changes: 18 additions & 27 deletions transport/pipe/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package pipe
import (
"errors"
"io"
"runtime"
"sync"
"time"

Expand Down Expand Up @@ -136,11 +135,10 @@ func (p *pipe) writeMultiBufferInternal(mb buf.MultiBuffer) error {

if p.data == nil {
p.data = mb
return nil
} else {
p.data, _ = buf.MergeMulti(p.data, mb)
}

p.data, _ = buf.MergeMulti(p.data, mb)
return errSlowDown
return nil
}

func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
Expand All @@ -155,30 +153,23 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
return nil
}

if err == errSlowDown {
p.readSignal.Signal()

// Yield current goroutine. Hopefully the reading counterpart can pick up the payload.
runtime.Gosched()
return nil
}

if err == errBufferFull && p.option.discardOverflow {
buf.ReleaseMulti(mb)
return nil
if err == errBufferFull {
if p.option.discardOverflow {
buf.ReleaseMulti(mb)
return nil
}
select {
case <-p.writeSignal.Wait():
continue
case <-p.done.Wait():
buf.ReleaseMulti(mb)
return io.ErrClosedPipe
}
}

if err != errBufferFull {
buf.ReleaseMulti(mb)
p.readSignal.Signal()
return err
}

select {
case <-p.writeSignal.Wait():
case <-p.done.Wait():
return io.ErrClosedPipe
}
buf.ReleaseMulti(mb)
p.readSignal.Signal()
return err
}
}

Expand Down