Skip to content

Commit c70e39c

Browse files
committed
fixed issues/17
1 parent 954807b commit c70e39c

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

forward/forward.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func Run(stats *ConnectionStats) {
125125
}
126126
innerWg.Add(1)
127127
go func() {
128-
stats.handleTCPConnection(clientConn, ctx)
128+
stats.handleTCPConnection(clientConn, ctx, cancel)
129129
innerWg.Done()
130130
}()
131131
}
@@ -134,7 +134,7 @@ func Run(stats *ConnectionStats) {
134134
}
135135

136136
// TCP转发
137-
func (cs *ConnectionStats) handleTCPConnection(clientConn net.Conn, ctx context.Context) {
137+
func (cs *ConnectionStats) handleTCPConnection(clientConn net.Conn, ctx context.Context, cancel context.CancelFunc) {
138138
defer clientConn.Close()
139139
remoteConn, err := net.Dial("tcp", cs.RemoteAddr+":"+cs.RemotePort)
140140
if err != nil {
@@ -147,11 +147,17 @@ func (cs *ConnectionStats) handleTCPConnection(clientConn net.Conn, ctx context.
147147
copyWG.Add(2)
148148
go func() {
149149
defer copyWG.Done()
150-
cs.copyBytes(clientConn, remoteConn)
150+
if err := cs.copyBytes(clientConn, remoteConn); err != nil {
151+
log.Println("复制字节时发生错误:", err)
152+
cancel() // Assuming `cancel` is the cancel function from the context
153+
}
151154
}()
152155
go func() {
153156
defer copyWG.Done()
154-
cs.copyBytes(remoteConn, clientConn)
157+
if err := cs.copyBytes(remoteConn, clientConn); err != nil {
158+
log.Println("复制字节时发生错误:", err)
159+
cancel() // Assuming `cancel` is the cancel function from the context
160+
}
155161
}()
156162
for {
157163
select {
@@ -204,7 +210,7 @@ func (cs *ConnectionStats) forwardUDPMessage(localConn *net.UDPConn, remoteAddr
204210

205211
}
206212

207-
func (cs *ConnectionStats) copyBytes(dst, src net.Conn) {
213+
func (cs *ConnectionStats) copyBytes(dst, src net.Conn) error {
208214
buf := bufPool.Get().([]byte)
209215
defer bufPool.Put(buf)
210216
for {
@@ -216,7 +222,7 @@ func (cs *ConnectionStats) copyBytes(dst, src net.Conn) {
216222
_, err := dst.Write(buf[:n])
217223
if err != nil {
218224
log.Println("【"+cs.LocalPort+"】写入目标时发生错误:", err)
219-
break
225+
return err
220226
}
221227
}
222228
if err == io.EOF {
@@ -230,6 +236,7 @@ func (cs *ConnectionStats) copyBytes(dst, src net.Conn) {
230236
// 关闭连接
231237
dst.Close()
232238
src.Close()
239+
return nil
233240
}
234241

235242
// 定时打印和处理流量变化

0 commit comments

Comments
 (0)