Skip to content

Commit 3ba7de3

Browse files
committed
Bugfix: The server need to keep running after the connection closed by client
1 parent db8ecdd commit 3ba7de3

File tree

3 files changed

+61
-36
lines changed

3 files changed

+61
-36
lines changed

client/util/client.go

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package client
22

33
import (
4-
"sync/atomic"
54
"encoding/binary"
65
"log"
76
"net"
87
"strconv"
98
"strings"
9+
"sync/atomic"
1010
"time"
1111

1212
"github.com/leviathan1995/spleen/service"
@@ -37,7 +37,7 @@ func NewClient(clientID int, serverIP string, serverPort int, limitRate []string
3737
}
3838
}
3939

40-
var connections int64 = 0;
40+
var connections uint64 = 0
4141

4242
func (c *client) Run() {
4343
log.Printf("Begin to running the client[%d]", c.clientID)
@@ -46,20 +46,21 @@ func (c *client) Run() {
4646
}
4747

4848
for {
49-
if atomic.LoadInt64(&connections) < 10 {
50-
for i := atomic.LoadInt64(&connections); i < 10; i++ {
49+
if atomic.LoadUint64(&connections) < 10 {
50+
for i := atomic.LoadUint64(&connections); i < 10; i++ {
5151
srvConn, err := c.DialSrv()
5252
if err != nil {
53-
log.Printf("Connect to the server %s:%d failed: %s. \n", c.srvAddr.IP.String(), c.srvAddr.Port, err)
53+
log.Printf("Connect to the proxy %s:%d failed: %s. \n", c.srvAddr.IP.String(), c.srvAddr.Port, err)
5454
continue
5555
}
56-
log.Printf("Connect to the server %s:%d successful.\n", c.srvAddr.IP.String(), c.srvAddr.Port)
56+
log.Printf("Connect to the proxy %s:%d successful.\n", c.srvAddr.IP.String(), c.srvAddr.Port)
5757
srvConn.SetKeepAlive(true)
5858
srvConn.SetLinger(0)
59-
connections = atomic.AddInt64(&connections, 1)
59+
atomic.StoreUint64(&connections, atomic.AddUint64(&connections, 1))
6060
go c.handleConn(srvConn)
6161
}
6262
} else {
63+
log.Printf("Currently, We still have %d active connections.", atomic.LoadUint64(&connections))
6364
time.Sleep(1 * time.Second)
6465
}
6566
}
@@ -70,55 +71,57 @@ func (c *client) DialSrv() (*net.TCPConn, error) {
7071
}
7172

7273
func (c *client) handleConn(srvConn *net.TCPConn) {
73-
defer srvConn.Close()
74-
75-
transBuf := make([]byte, 8)
74+
transBuf := make([]byte, service.IDBuf)
7675
/* Send the ID of client to proxy. */
7776
binary.LittleEndian.PutUint64(transBuf, uint64(c.clientID))
7877
err := c.TCPWrite(srvConn, transBuf)
7978
if err != nil {
80-
connections = atomic.AddInt64(&connections, -1)
79+
atomic.StoreUint64(&connections, atomic.AddUint64(&connections, ^uint64(1-1)))
80+
_ = srvConn.Close()
8181
log.Println("Try to send the ID of client to the proxy failed.")
8282
return
8383
}
8484

8585
/* Waiting for the transfer port from proxy. */
86-
nRead, err := srvConn.Read(transBuf)
87-
connections = atomic.AddInt64(&connections, -1)
86+
err = c.TCPRead(srvConn, transBuf, service.PortBuf)
87+
atomic.StoreUint64(&connections, atomic.AddUint64(&connections, ^uint64(1-1)))
8888
if err != nil {
89-
log.Println("Try to read the destination port failed.")
89+
_ = srvConn.Close()
90+
log.Println("Try to read destination port from the proxy failed.")
9091
return
9192
}
92-
port := int64(binary.LittleEndian.Uint64(transBuf[:nRead]))
93+
port := int64(binary.LittleEndian.Uint64(transBuf))
9394

9495
/* Try to direct connect to the destination sever. */
9596
dstAddr, err := net.ResolveTCPAddr("tcp", ":"+strconv.Itoa(int(port)))
9697
if err != nil {
97-
log.Printf("Try to resolve TCPAddr %s failed: %s.\n", "localhost"+":"+string(port), err.Error())
98+
_ = srvConn.Close()
99+
log.Printf("Try to resolve TCPAddr %s failed: %s.\n", "localhost"+":"+strconv.FormatInt(port, 10), err.Error())
98100
return
99101
}
100102

101103
dstConn, err := net.DialTCP("tcp", nil, dstAddr)
102104
if err != nil {
105+
_ = srvConn.Close()
103106
log.Printf("Connect to localhost:%d failed.", dstAddr.Port)
104107
return
105108
} else {
106109
log.Printf("Connect to the destination address localhost:%d successful.", dstAddr.Port)
107110
}
108-
defer dstConn.Close()
109111

110-
dstConn.SetKeepAlive(true)
112+
_ = dstConn.SetKeepAlive(true)
111113
_ = dstConn.SetLinger(0)
112114

113115
var limitRate int64
114-
115116
if rate, found := c.limitRate[port]; found {
116117
limitRate = rate * 1024 /* bytes */
117118
}
118119

119120
go func() {
120121
errTransfer := c.TransferToTCP(dstConn, srvConn, limitRate)
121122
if errTransfer != nil {
123+
_ = srvConn.Close()
124+
_ = dstConn.Close()
122125
return
123126
}
124127
}()

server/util/server.go

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,23 @@ func (s *server) ListenForIntranet(tcpAddr *net.TCPAddr) {
7070
if err != nil {
7171
continue
7272
}
73-
/* The proxy should get the ID of the client first. */
74-
transBuf := make([]byte, 8)
75-
nRead, err := conn.Read(transBuf)
73+
74+
/* The proxy should get the magic number and ID of the client first. */
75+
transBuf := make([]byte, service.IDBuf)
76+
err = s.TCPRead(conn, transBuf, service.IDBuf)
7677
if err != nil {
78+
_ = conn.Close()
7779
log.Println("Try to read the destination port failed.")
78-
return
80+
continue
81+
}
82+
83+
id := int64(binary.LittleEndian.Uint64(transBuf))
84+
if ConnectionPool.Has(s.connectionPool, id) == false {
85+
_ = conn.Close()
86+
continue
87+
} else {
88+
s.connectionPool[id] <- conn
7989
}
80-
id := int64(binary.LittleEndian.Uint64(transBuf[:nRead]))
81-
s.connectionPool[id] <- conn
8290
}
8391
}
8492

@@ -100,41 +108,41 @@ func (s *server) Listen() {
100108
}
101109

102110
func (s *server) handleConn(cliConn *net.TCPConn, clientID int64, transferPort uint64) {
103-
defer cliConn.Close()
104-
105111
select {
106112
case intranetConn := <-s.connectionPool[clientID]:
107113
_ = intranetConn.SetLinger(0)
108114

109115
/* Send the transfer port to intranet server . */
110-
portBuf := make([]byte, 8)
116+
portBuf := make([]byte, service.PortBuf)
111117
binary.LittleEndian.PutUint64(portBuf, transferPort)
112118
err := s.TCPWrite(intranetConn, portBuf)
113119
if err != nil {
114-
intranetConn.Close()
120+
_ = cliConn.Close()
121+
_ = intranetConn.Close()
115122
for {
116123
/* Close all connections from this client. */
117124
select {
118125
case intranetConn = <-s.connectionPool[clientID]:
119-
intranetConn.Close()
126+
_ = intranetConn.SetLinger(0)
127+
_ = intranetConn.Close()
120128
default:
121129
return
122130
}
123131
}
124132
}
125133

126-
log.Printf("Make a successful connection between the user and the intranet server[Client ID: %d - Port: %d].", clientID, transferPort)
134+
log.Printf("Make a successful connection between the user [%s] and the intranet server[Client ID: %d - Port: %d].",
135+
cliConn.RemoteAddr().String(), clientID, transferPort)
127136
/* Transfer network packets. */
128137
go func() {
129138
errTransfer := s.TransferToTCP(cliConn, intranetConn, 0)
130139
if errTransfer != nil {
131-
intranetConn.Close()
140+
_ = cliConn.Close()
141+
_ = intranetConn.Close()
132142
return
133143
}
134144
}()
135145
err = s.TransferToTCP(intranetConn, cliConn, 0)
136146
return
137-
default:
138-
log.Printf("Currently, Do not have any active connection from the intranet server[Client ID: %d - Port: %d].", clientID, transferPort)
139147
}
140148
}

service/service.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ import (
55
"time"
66
)
77

8-
const BUFFERSIZE = 1024 * 16
8+
const TransferBuf = 1024 * 16
9+
const PortBuf = 8
10+
const IDBuf = 8
911

1012
type Service struct {
1113
IP string
@@ -25,9 +27,21 @@ func (s *Service) TCPWrite(conn *net.TCPConn, buf []byte) error {
2527
return nil
2628
}
2729

30+
func (s *Service) TCPRead(conn *net.TCPConn, buf []byte, len int) error {
31+
nRead := 0
32+
for nRead < len {
33+
n, errRead := conn.Read(buf[nRead:])
34+
if errRead != nil {
35+
return errRead
36+
}
37+
nRead += n
38+
}
39+
return nil
40+
}
41+
2842
func (s *Service) TransferToTCP(cliConn net.Conn, dstConn *net.TCPConn, limitRate int64) error {
2943
var totalRead, lastTime int64
30-
buf := make([]byte, BUFFERSIZE)
44+
buf := make([]byte, TransferBuf)
3145

3246
for {
3347
nRead, errRead := cliConn.Read(buf)

0 commit comments

Comments
 (0)