Skip to content

Commit b5604ec

Browse files
committed
feat: 1. 添加tcp load balancer轮询策略 2. 所有generateName修饰符由protected调整为public 3. 添加lb相关单元测试
1 parent 246571c commit b5604ec

File tree

10 files changed

+233
-16
lines changed

10 files changed

+233
-16
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package top.meethigher.proxy;
2+
3+
/**
4+
* 负载均衡策略
5+
*
6+
* @author <a href="https://meethigher.top">chenchuancheng</a>
7+
* @since 2025/07/26 13:05
8+
*/
9+
public interface LoadBalancer<T> {
10+
T next();
11+
12+
String name();
13+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package top.meethigher.proxy;
2+
3+
import java.util.Objects;
4+
5+
public class NetAddress {
6+
private final String host;
7+
private final int port;
8+
9+
public NetAddress(String host, int port) {
10+
this.host = host;
11+
this.port = port;
12+
}
13+
14+
public String getHost() {
15+
return host;
16+
}
17+
18+
public int getPort() {
19+
return port;
20+
}
21+
22+
@Override
23+
public String toString() {
24+
return host + ":" + port;
25+
}
26+
27+
@Override
28+
public boolean equals(Object o) {
29+
if (o == null || getClass() != o.getClass()) {
30+
return false;
31+
}
32+
33+
NetAddress that = (NetAddress) o;
34+
return this.host.equals(that.getHost()) && this.port == that.getPort();
35+
}
36+
37+
@Override
38+
public int hashCode() {
39+
return Objects.hashCode(this.toString());
40+
}
41+
}

src/main/java/top/meethigher/proxy/http/ReverseHttpProxy.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ public static ReverseHttpProxy create(Router router, HttpServer httpServer, Http
243243
return new ReverseHttpProxy(httpServer, httpClient, router, generateName());
244244
}
245245

246-
protected static String generateName() {
246+
public static String generateName() {
247247
final String prefix = ReverseHttpProxy.class.getSimpleName() + "-";
248248
try {
249249
// 池号对于虚拟机来说是全局的,以避免在类加载器范围的环境中池号重叠

src/main/java/top/meethigher/proxy/tcp/ReverseTcpProxy.java

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
import io.vertx.core.net.SocketAddress;
1111
import org.slf4j.Logger;
1212
import org.slf4j.LoggerFactory;
13+
import top.meethigher.proxy.LoadBalancer;
14+
import top.meethigher.proxy.NetAddress;
1315

16+
import java.util.ArrayList;
17+
import java.util.List;
1418
import java.util.concurrent.ThreadLocalRandom;
1519

1620
/**
@@ -32,24 +36,32 @@ public class ReverseTcpProxy {
3236
protected final Handler<NetSocket> connectHandler;
3337
protected final NetServer netServer;
3438
protected final NetClient netClient;
35-
protected final String targetHost;
36-
protected final int targetPort;
39+
protected final LoadBalancer<NetAddress> lb;
40+
protected final List<NetAddress> netAddresses;
3741
protected final String name;
3842

3943
protected ReverseTcpProxy(NetServer netServer, NetClient netClient,
40-
String targetHost, int targetPort, String name) {
44+
LoadBalancer<NetAddress> loadBalancer,
45+
List<NetAddress> netAddresses,
46+
String name) {
4147
this.name = name;
42-
this.targetHost = targetHost;
43-
this.targetPort = targetPort;
48+
this.lb = loadBalancer;
49+
this.netAddresses = netAddresses;
4450
this.netServer = netServer;
4551
this.netClient = netClient;
4652
this.connectHandler = sourceSocket -> {
4753
// 暂停流读取
4854
sourceSocket.pause();
4955
SocketAddress sourceRemote = sourceSocket.remoteAddress();
5056
SocketAddress sourceLocal = sourceSocket.localAddress();
51-
log.debug("source {} -- {} connected", sourceLocal, sourceRemote);
5257
sourceSocket.closeHandler(v -> log.debug("source {} -- {} closed", sourceLocal, sourceRemote));
58+
NetAddress next = lb.next();
59+
String targetHost = next.getHost();
60+
int targetPort = next.getPort();
61+
log.debug("source {} -- {} connected. lb [{}] next target {}", sourceLocal, sourceRemote,
62+
lb.name(),
63+
next
64+
);
5365
netClient.connect(targetPort, targetHost)
5466
.onFailure(e -> {
5567
log.error("failed to connect to {}:{}", targetHost, targetPort, e);
@@ -61,7 +73,7 @@ protected ReverseTcpProxy(NetServer netServer, NetClient netClient,
6173
SocketAddress targetRemote = targetSocket.remoteAddress();
6274
SocketAddress targetLocal = targetSocket.localAddress();
6375
log.debug("target {} -- {} connected", targetLocal, targetRemote);
64-
76+
6577
// feat: v1.0.5以前的版本,在closeHandler里面,将对端连接也关闭。比如targetSocket关闭时,则将sourceSocket也关闭。
6678
// 结果导致在转发短连接时,出现了bug。参考https://github.com/meethigher/tcp-reverse-proxy/issues/6
6779
targetSocket.closeHandler(v -> log.debug("target {} -- {} closed", targetLocal, targetRemote));
@@ -88,20 +100,56 @@ protected ReverseTcpProxy(NetServer netServer, NetClient netClient,
88100

89101
public static ReverseTcpProxy create(Vertx vertx,
90102
String targetHost, int targetPort, String name) {
91-
return new ReverseTcpProxy(vertx.createNetServer(), vertx.createNetClient(), targetHost, targetPort, name);
103+
List<NetAddress> list = new ArrayList<>();
104+
TcpRoundRobinLoadBalancer lb = TcpRoundRobinLoadBalancer.create(list);
105+
return new ReverseTcpProxy(
106+
vertx.createNetServer(),
107+
vertx.createNetClient(),
108+
lb,
109+
list,
110+
name
111+
).addNode(new NetAddress(targetHost, targetPort));
92112
}
93113

94114
public static ReverseTcpProxy create(Vertx vertx,
95115
String targetHost, int targetPort) {
96-
return new ReverseTcpProxy(vertx.createNetServer(), vertx.createNetClient(), targetHost, targetPort, generateName());
116+
List<NetAddress> list = new ArrayList<>();
117+
return new ReverseTcpProxy(
118+
vertx.createNetServer(),
119+
vertx.createNetClient(),
120+
TcpRoundRobinLoadBalancer.create(list),
121+
list,
122+
generateName()
123+
).addNode(new NetAddress(targetHost, targetPort));
97124
}
98125

99126
public static ReverseTcpProxy create(NetServer netServer, NetClient netClient, String targetHost, int targetPort) {
100-
return new ReverseTcpProxy(netServer, netClient, targetHost, targetPort, generateName());
127+
List<NetAddress> list = new ArrayList<>();
128+
return new ReverseTcpProxy(
129+
netServer,
130+
netClient,
131+
TcpRoundRobinLoadBalancer.create(list),
132+
list,
133+
generateName()
134+
).addNode(new NetAddress(targetHost, targetPort));
101135
}
102136

103137
public static ReverseTcpProxy create(NetServer netServer, NetClient netClient, String targetHost, int targetPort, String name) {
104-
return new ReverseTcpProxy(netServer, netClient, targetHost, targetPort, name);
138+
List<NetAddress> list = new ArrayList<>();
139+
return new ReverseTcpProxy(
140+
netServer,
141+
netClient,
142+
TcpRoundRobinLoadBalancer.create(list),
143+
list,
144+
name
145+
).addNode(new NetAddress(targetHost, targetPort));
146+
}
147+
148+
public static ReverseTcpProxy create(NetServer netServer, NetClient netClient,
149+
LoadBalancer<NetAddress> loadBalancer,
150+
List<NetAddress> netAddresses,
151+
String name) {
152+
return new ReverseTcpProxy(netServer, netClient, loadBalancer, netAddresses, name);
105153
}
106154

107155
public ReverseTcpProxy port(int port) {
@@ -114,8 +162,15 @@ public ReverseTcpProxy host(String host) {
114162
return this;
115163
}
116164

165+
public ReverseTcpProxy addNode(NetAddress netAddress) {
166+
if (!netAddresses.contains(netAddress)) {
167+
netAddresses.add(netAddress);
168+
}
169+
return this;
170+
}
171+
117172

118-
protected static String generateName() {
173+
public static String generateName() {
119174
final String prefix = ReverseTcpProxy.class.getSimpleName() + "-";
120175
try {
121176
// 池号对于虚拟机来说是全局的,以避免在类加载器范围的环境中池号重叠
@@ -135,12 +190,15 @@ protected static String generateName() {
135190
}
136191

137192
public void start() {
193+
if (netAddresses.size() <= 0) {
194+
throw new IllegalStateException("netAddresses size must be greater than 0");
195+
}
138196
netServer.connectHandler(connectHandler).exceptionHandler(e -> log.error("connect failed", e));
139197
Future<NetServer> listenFuture = netServer.listen(sourcePort, sourceHost);
140198

141199
Handler<AsyncResult<NetServer>> asyncResultHandler = ar -> {
142200
if (ar.succeeded()) {
143-
log.info("{} started on {}:{}", name, sourceHost, sourcePort);
201+
log.info("{} started on {}:{}\nLB-Mode: {}\n {}", name, sourceHost, sourcePort, lb.name(), netAddresses);
144202
} else {
145203
Throwable e = ar.cause();
146204
log.error("{} start failed", name, e);
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package top.meethigher.proxy.tcp;
2+
3+
import top.meethigher.proxy.LoadBalancer;
4+
import top.meethigher.proxy.NetAddress;
5+
6+
import java.util.List;
7+
import java.util.concurrent.atomic.AtomicInteger;
8+
9+
/**
10+
* 轮询策略实现
11+
*
12+
* @author <a href="https://meethigher.top">chenchuancheng</a>
13+
* @since 2025/07/26 13:41
14+
*/
15+
public class TcpRoundRobinLoadBalancer implements LoadBalancer<NetAddress> {
16+
17+
private final List<NetAddress> nodes;
18+
19+
private final AtomicInteger idx = new AtomicInteger(0);
20+
21+
private final String name = "TcpRoundRobinLoadBalancer";
22+
23+
private TcpRoundRobinLoadBalancer(List<NetAddress> nodes) {
24+
this.nodes = nodes;
25+
}
26+
27+
28+
public NetAddress next() {
29+
if (nodes == null) {
30+
return null;
31+
}
32+
int index = idx.getAndUpdate(v -> (v + 1) % nodes.size());
33+
return nodes.get(index);
34+
}
35+
36+
@Override
37+
public String name() {
38+
return name;
39+
}
40+
41+
public static TcpRoundRobinLoadBalancer create(List<NetAddress> nodes) {
42+
return new TcpRoundRobinLoadBalancer(nodes);
43+
}
44+
}

src/main/java/top/meethigher/proxy/tcp/tunnel/ReverseTcpProxyTunnelClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public class ReverseTcpProxyTunnelClient extends TunnelClient {
5353
protected String dataProxyName = "ssh-proxy";
5454

5555

56-
protected static String generateName() {
56+
public static String generateName() {
5757
final String prefix = ReverseTcpProxyTunnelClient.class.getSimpleName() + "-";
5858
try {
5959
// 池号对于虚拟机来说是全局的,以避免在类加载器范围的环境中池号重叠

src/main/java/top/meethigher/proxy/tcp/tunnel/ReverseTcpProxyTunnelServer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ public static ReverseTcpProxyTunnelServer create(Vertx vertx) {
125125
}
126126

127127

128-
protected static String generateName() {
128+
public static String generateName() {
129129
final String prefix = ReverseTcpProxyTunnelServer.class.getSimpleName() + "-";
130130
try {
131131
// 池号对于虚拟机来说是全局的,以避免在类加载器范围的环境中池号重叠
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package top.meethigher.proxy;
2+
3+
import org.junit.Test;
4+
5+
public class NetAddressTest {
6+
@Test
7+
public void name() {
8+
NetAddress netAddress1 = new NetAddress("127.0.0.1", 6666);
9+
NetAddress netAddress2 = new NetAddress("127.0.0.1", 6666);
10+
NetAddress netAddress3 = new NetAddress("127.0.0.1", 6667);
11+
System.out.println(netAddress2.equals(netAddress1));
12+
System.out.println(netAddress3.equals(netAddress1));
13+
System.out.println(netAddress2 == netAddress1);
14+
System.out.println(netAddress3 == netAddress1);
15+
}
16+
}

src/test/java/top/meethigher/proxy/tcp/ReverseTcpProxyTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
package top.meethigher.proxy.tcp;
22

33
import io.vertx.core.Vertx;
4+
import io.vertx.core.net.NetClient;
5+
import io.vertx.core.net.NetServer;
46
import org.junit.Test;
7+
import top.meethigher.proxy.NetAddress;
58

9+
import java.util.ArrayList;
10+
import java.util.List;
611
import java.util.concurrent.TimeUnit;
12+
import java.util.concurrent.locks.LockSupport;
713

814
public class ReverseTcpProxyTest {
915

@@ -15,4 +21,18 @@ public void testVertxTCPReverseProxy() throws Exception {
1521
proxy.stop();
1622
}
1723

24+
25+
@Test
26+
public void testLb() {
27+
Vertx vertx = Vertx.vertx();
28+
NetServer netServer = vertx.createNetServer();
29+
NetClient netClient = vertx.createNetClient();
30+
List<NetAddress> list = new ArrayList<>();
31+
ReverseTcpProxy.create(netServer, netClient, TcpRoundRobinLoadBalancer.create(list), list, ReverseTcpProxy.generateName())
32+
.addNode(new NetAddress("10.0.0.20", 22))
33+
.addNode(new NetAddress("10.0.0.30", 22))
34+
.start();
35+
36+
LockSupport.park();
37+
}
1838
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package top.meethigher.proxy.tcp;
2+
3+
import org.junit.Test;
4+
import top.meethigher.proxy.NetAddress;
5+
6+
import java.util.ArrayList;
7+
import java.util.List;
8+
9+
public class TcpRoundRobinLoadBalancerTest {
10+
11+
@Test
12+
public void next() {
13+
List<NetAddress> nodes = new ArrayList<>();
14+
nodes.add(new NetAddress("127.0.0.1", 6666));
15+
nodes.add(new NetAddress("127.0.0.1", 6667));
16+
TcpRoundRobinLoadBalancer balancer = TcpRoundRobinLoadBalancer.create(nodes);
17+
System.out.println(balancer.next());
18+
System.out.println(balancer.next());
19+
System.out.println(balancer.next());
20+
nodes.add(new NetAddress("127.0.0.1", 6668));
21+
System.out.println(balancer.next());
22+
System.out.println(balancer.next());
23+
System.out.println(balancer.next());
24+
}
25+
}

0 commit comments

Comments
 (0)