Skip to content

Commit a9f4d2e

Browse files
authored
Merge pull request #169 from SenseUnit/pipeconn
Pipeconn
2 parents 242380c + 4ace4a9 commit a9f4d2e

File tree

2 files changed

+137
-5
lines changed

2 files changed

+137
-5
lines changed

dialer/pipewrap.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package dialer
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net"
7+
"sync"
8+
"time"
9+
10+
"github.com/hashicorp/go-multierror"
11+
)
12+
13+
type ReadPipe interface {
14+
io.Reader
15+
io.WriterTo
16+
io.Closer
17+
Fd() uintptr
18+
SetReadDeadline(t time.Time) error
19+
}
20+
21+
type WritePipe interface {
22+
io.Writer
23+
io.ReaderFrom
24+
io.Closer
25+
Fd() uintptr
26+
SetWriteDeadline(t time.Time) error
27+
}
28+
29+
type PipeAddr struct {
30+
rfd uintptr
31+
wfd uintptr
32+
}
33+
34+
func (_ PipeAddr) Network() string {
35+
return "pipe"
36+
}
37+
38+
func (a PipeAddr) String() string {
39+
return fmt.Sprintf("<read fd: %d, write rd: %d>", a.rfd, a.wfd)
40+
}
41+
42+
type PipeConn struct {
43+
r ReadPipe
44+
w WritePipe
45+
rc sync.Once
46+
wc sync.Once
47+
}
48+
49+
func NewPipeConn(r ReadPipe, w WritePipe) *PipeConn {
50+
return &PipeConn{
51+
r: r,
52+
w: w,
53+
}
54+
}
55+
56+
func (c *PipeConn) Read(p []byte) (n int, err error) {
57+
return c.r.Read(p)
58+
}
59+
60+
func (c *PipeConn) Write(p []byte) (n int, err error) {
61+
return c.w.Write(p)
62+
}
63+
64+
func (c *PipeConn) Close() error {
65+
var err error
66+
if closeErr := c.CloseWrite(); closeErr != nil {
67+
err = multierror.Append(err, closeErr)
68+
}
69+
if closeErr := c.CloseRead(); closeErr != nil {
70+
err = multierror.Append(err, closeErr)
71+
}
72+
return err
73+
}
74+
75+
func (c *PipeConn) CloseWrite() error {
76+
var err error
77+
c.wc.Do(func() {
78+
err = c.w.Close()
79+
})
80+
return err
81+
}
82+
83+
func (c *PipeConn) CloseRead() error {
84+
var err error
85+
c.wc.Do(func() {
86+
err = c.r.Close()
87+
})
88+
return err
89+
}
90+
91+
func (c *PipeConn) LocalAddr() net.Addr {
92+
return PipeAddr{
93+
rfd: c.r.Fd(),
94+
wfd: c.w.Fd(),
95+
}
96+
}
97+
98+
func (c *PipeConn) RemoteAddr() net.Addr {
99+
return c.LocalAddr()
100+
}
101+
102+
func (c *PipeConn) SetReadDeadline(t time.Time) error {
103+
return c.r.SetReadDeadline(t)
104+
}
105+
106+
func (c *PipeConn) SetWriteDeadline(t time.Time) error {
107+
return c.w.SetWriteDeadline(t)
108+
}
109+
110+
func (c *PipeConn) SetDeadline(t time.Time) error {
111+
var err error
112+
if cErr := c.SetReadDeadline(t); err != nil {
113+
err = multierror.Append(err, cErr)
114+
}
115+
if cErr := c.SetWriteDeadline(t); err != nil {
116+
err = multierror.Append(err, cErr)
117+
}
118+
return err
119+
}
120+
121+
func (c *PipeConn) ReadFrom(r io.Reader) (n int64, err error) {
122+
return c.w.ReadFrom(r)
123+
}
124+
125+
func (c *PipeConn) WriteTo(w io.Writer) (n int64, err error) {
126+
return c.r.WriteTo(w)
127+
}
128+
129+
var _ net.Conn = new(PipeConn)
130+
var _ io.ReaderFrom = new(PipeConn)
131+
var _ io.WriterTo = new(PipeConn)

handler/stdio.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,24 @@ package handler
33
import (
44
"context"
55
"fmt"
6-
"io"
76
"net"
87
"sync"
98

109
clog "github.com/SenseUnit/dumbproxy/log"
10+
11+
"github.com/SenseUnit/dumbproxy/dialer"
1112
)
1213

13-
func StdIOHandler(dialer HandlerDialer, logger *clog.CondLogger, forward ForwardFunc) func(ctx context.Context, reader io.Reader, writer io.Writer, dstAddress string) error {
14-
return func(ctx context.Context, reader io.Reader, writer io.Writer, dstAddress string) error {
14+
func StdIOHandler(d HandlerDialer, logger *clog.CondLogger, forward ForwardFunc) func(context.Context, dialer.ReadPipe, dialer.WritePipe, string) error {
15+
return func(ctx context.Context, reader dialer.ReadPipe, writer dialer.WritePipe, dstAddress string) error {
1516
logger.Debug("Request: %v => %v %q %v %v %v", "<stdio>", "<stdio>", "", "STDIO", "CONNECT", dstAddress)
16-
target, err := dialer.DialContext(ctx, "tcp", dstAddress)
17+
target, err := d.DialContext(ctx, "tcp", dstAddress)
1718
if err != nil {
1819
return fmt.Errorf("connect to %q failed: %w", dstAddress, err)
1920
}
2021
defer target.Close()
2122

22-
return forward(ctx, "", wrapSOCKS(reader, writer), target)
23+
return forward(ctx, "", dialer.NewPipeConn(reader, writer), target)
2324
}
2425
}
2526

0 commit comments

Comments
 (0)