Skip to content

Commit 0d94d64

Browse files
authored
Merge pull request #22 from meethigher/issue-15
支持tcpmux
2 parents 246571c + 1b0e033 commit 0d94d64

24 files changed

+1170
-28
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package top.meethigher.proxy;
2+
3+
import javax.crypto.Cipher;
4+
import javax.crypto.KeyGenerator;
5+
import javax.crypto.SecretKey;
6+
import javax.crypto.spec.GCMParameterSpec;
7+
import javax.crypto.spec.SecretKeySpec;
8+
import java.nio.ByteBuffer;
9+
import java.security.SecureRandom;
10+
import java.util.Base64;
11+
12+
/**
13+
* 高性能 AES/GCM 对称加解密工具类(仅依赖 JDK 标准库)
14+
*/
15+
public final class FastAes {
16+
17+
private static final String TRANSFORMATION = "AES/GCM/NoPadding";
18+
private static final int AES_KEY_LEN = 128; // bit
19+
private static final int GCM_TAG_LEN = 128; // bit
20+
private static final int GCM_IV_LEN = 12; // byte
21+
22+
private static final SecureRandom RAND = new SecureRandom();
23+
24+
/* 每个线程复用自己的 Cipher 实例 */
25+
private static final ThreadLocal<Cipher> CIPHER_HOLDER = ThreadLocal.withInitial(() -> {
26+
try {
27+
return Cipher.getInstance(TRANSFORMATION);
28+
} catch (Exception e) {
29+
throw new RuntimeException("AES/GCM not available", e);
30+
}
31+
});
32+
33+
private FastAes() {} // utility class
34+
35+
/* ---------------------------------- 对外 API ---------------------------------- */
36+
37+
/**
38+
* 随机生成 AES-128 密钥
39+
*/
40+
public static SecretKey generateKey() {
41+
try {
42+
KeyGenerator kg = KeyGenerator.getInstance("AES");
43+
kg.init(AES_KEY_LEN);
44+
return kg.generateKey();
45+
} catch (Exception e) {
46+
throw new RuntimeException(e);
47+
}
48+
}
49+
50+
/**
51+
* 将原始密钥字节数组包装成 SecretKey
52+
*/
53+
public static SecretKey restoreKey(byte[] rawKey) {
54+
if (rawKey.length != AES_KEY_LEN / 8) {
55+
throw new IllegalArgumentException("Key length != 16 byte");
56+
}
57+
return new SecretKeySpec(rawKey, "AES");
58+
}
59+
60+
/**
61+
* 加密:返回 byte[],格式为 IV(12B) + CipherText + Tag(16B)
62+
*/
63+
public static byte[] encrypt(byte[] plain, SecretKey key) {
64+
try {
65+
byte[] iv = new byte[GCM_IV_LEN];
66+
RAND.nextBytes(iv);
67+
68+
Cipher cipher = CIPHER_HOLDER.get();
69+
cipher.init(Cipher.ENCRYPT_MODE, key, new GCMParameterSpec(GCM_TAG_LEN, iv));
70+
71+
byte[] cipherText = cipher.doFinal(plain);
72+
73+
return ByteBuffer.allocate(iv.length + cipherText.length)
74+
.put(iv)
75+
.put(cipherText)
76+
.array();
77+
} catch (Exception e) {
78+
throw new RuntimeException("Encrypt error", e);
79+
}
80+
}
81+
82+
/**
83+
* 解密:输入格式须为 IV(12B) + CipherText + Tag(16B)
84+
*/
85+
public static byte[] decrypt(byte[] ivPlusCipherText, SecretKey key) {
86+
try {
87+
if (ivPlusCipherText.length < GCM_IV_LEN) {
88+
throw new IllegalArgumentException("Bad input length");
89+
}
90+
ByteBuffer buf = ByteBuffer.wrap(ivPlusCipherText);
91+
92+
byte[] iv = new byte[GCM_IV_LEN];
93+
buf.get(iv);
94+
95+
byte[] cipherAndTag = new byte[buf.remaining()];
96+
buf.get(cipherAndTag);
97+
98+
Cipher cipher = CIPHER_HOLDER.get();
99+
cipher.init(Cipher.DECRYPT_MODE, key, new GCMParameterSpec(GCM_TAG_LEN, iv));
100+
101+
return cipher.doFinal(cipherAndTag);
102+
} catch (Exception e) {
103+
throw new RuntimeException("Decrypt error", e);
104+
}
105+
}
106+
107+
/* ----------------------------- 简易 Base64 封装 ----------------------------- */
108+
109+
public static String encryptToBase64(byte[] plain, SecretKey key) {
110+
return Base64.getEncoder().encodeToString(encrypt(plain, key));
111+
}
112+
113+
public static byte[] decryptFromBase64(String base64, SecretKey key) {
114+
return decrypt(Base64.getDecoder().decode(base64), key);
115+
}
116+
117+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package top.meethigher.proxy;
2+
3+
/**
4+
* 负载均衡策略
5+
*
6+
* @author <a href="https://meethigher.top">chenchuancheng</a>
7+
* @since 2025/07/26 13:05
8+
*/
9+
public interface LoadBalancer<T> {
10+
T next();
11+
12+
String name();
13+
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package top.meethigher.proxy;
2+
3+
import java.util.Objects;
4+
5+
public class NetAddress {
6+
private final String host;
7+
private final int port;
8+
9+
public NetAddress(String host, int port) {
10+
this.host = host;
11+
this.port = port;
12+
}
13+
14+
public String getHost() {
15+
return host;
16+
}
17+
18+
public int getPort() {
19+
return port;
20+
}
21+
22+
@Override
23+
public String toString() {
24+
return host + ":" + port;
25+
}
26+
27+
@Override
28+
public boolean equals(Object o) {
29+
if (o == null || getClass() != o.getClass()) {
30+
return false;
31+
}
32+
33+
NetAddress that = (NetAddress) o;
34+
return this.host.equals(that.getHost()) && this.port == that.getPort();
35+
}
36+
37+
@Override
38+
public int hashCode() {
39+
return Objects.hashCode(this.toString());
40+
}
41+
42+
public static NetAddress parse(String addr) {
43+
try {
44+
String[] addrArr = addr.split(":");
45+
return new NetAddress(addrArr[0], Integer.parseInt(addrArr[1]));
46+
} catch (Exception e) {
47+
return null;
48+
}
49+
}
50+
}

src/main/java/top/meethigher/proxy/http/ReverseHttpProxy.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ public static ReverseHttpProxy create(Router router, HttpServer httpServer, Http
243243
return new ReverseHttpProxy(httpServer, httpClient, router, generateName());
244244
}
245245

246-
protected static String generateName() {
246+
public static String generateName() {
247247
final String prefix = ReverseHttpProxy.class.getSimpleName() + "-";
248248
try {
249249
// 池号对于虚拟机来说是全局的,以避免在类加载器范围的环境中池号重叠

src/main/java/top/meethigher/proxy/tcp/ReverseTcpProxy.java

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
import io.vertx.core.net.SocketAddress;
1111
import org.slf4j.Logger;
1212
import org.slf4j.LoggerFactory;
13+
import top.meethigher.proxy.LoadBalancer;
14+
import top.meethigher.proxy.NetAddress;
1315

16+
import java.util.ArrayList;
17+
import java.util.List;
1418
import java.util.concurrent.ThreadLocalRandom;
1519

1620
/**
@@ -32,24 +36,32 @@ public class ReverseTcpProxy {
3236
protected final Handler<NetSocket> connectHandler;
3337
protected final NetServer netServer;
3438
protected final NetClient netClient;
35-
protected final String targetHost;
36-
protected final int targetPort;
39+
protected final LoadBalancer<NetAddress> lb;
40+
protected final List<NetAddress> netAddresses;
3741
protected final String name;
3842

3943
protected ReverseTcpProxy(NetServer netServer, NetClient netClient,
40-
String targetHost, int targetPort, String name) {
44+
LoadBalancer<NetAddress> loadBalancer,
45+
List<NetAddress> netAddresses,
46+
String name) {
4147
this.name = name;
42-
this.targetHost = targetHost;
43-
this.targetPort = targetPort;
48+
this.lb = loadBalancer;
49+
this.netAddresses = netAddresses;
4450
this.netServer = netServer;
4551
this.netClient = netClient;
4652
this.connectHandler = sourceSocket -> {
4753
// 暂停流读取
4854
sourceSocket.pause();
4955
SocketAddress sourceRemote = sourceSocket.remoteAddress();
5056
SocketAddress sourceLocal = sourceSocket.localAddress();
51-
log.debug("source {} -- {} connected", sourceLocal, sourceRemote);
5257
sourceSocket.closeHandler(v -> log.debug("source {} -- {} closed", sourceLocal, sourceRemote));
58+
NetAddress next = lb.next();
59+
String targetHost = next.getHost();
60+
int targetPort = next.getPort();
61+
log.debug("source {} -- {} connected. lb [{}] next target {}", sourceLocal, sourceRemote,
62+
lb.name(),
63+
next
64+
);
5365
netClient.connect(targetPort, targetHost)
5466
.onFailure(e -> {
5567
log.error("failed to connect to {}:{}", targetHost, targetPort, e);
@@ -61,7 +73,7 @@ protected ReverseTcpProxy(NetServer netServer, NetClient netClient,
6173
SocketAddress targetRemote = targetSocket.remoteAddress();
6274
SocketAddress targetLocal = targetSocket.localAddress();
6375
log.debug("target {} -- {} connected", targetLocal, targetRemote);
64-
76+
6577
// feat: v1.0.5以前的版本,在closeHandler里面,将对端连接也关闭。比如targetSocket关闭时,则将sourceSocket也关闭。
6678
// 结果导致在转发短连接时,出现了bug。参考https://github.com/meethigher/tcp-reverse-proxy/issues/6
6779
targetSocket.closeHandler(v -> log.debug("target {} -- {} closed", targetLocal, targetRemote));
@@ -88,20 +100,56 @@ protected ReverseTcpProxy(NetServer netServer, NetClient netClient,
88100

89101
public static ReverseTcpProxy create(Vertx vertx,
90102
String targetHost, int targetPort, String name) {
91-
return new ReverseTcpProxy(vertx.createNetServer(), vertx.createNetClient(), targetHost, targetPort, name);
103+
List<NetAddress> list = new ArrayList<>();
104+
TcpRoundRobinLoadBalancer lb = TcpRoundRobinLoadBalancer.create(list);
105+
return new ReverseTcpProxy(
106+
vertx.createNetServer(),
107+
vertx.createNetClient(),
108+
lb,
109+
list,
110+
name
111+
).addNode(new NetAddress(targetHost, targetPort));
92112
}
93113

94114
public static ReverseTcpProxy create(Vertx vertx,
95115
String targetHost, int targetPort) {
96-
return new ReverseTcpProxy(vertx.createNetServer(), vertx.createNetClient(), targetHost, targetPort, generateName());
116+
List<NetAddress> list = new ArrayList<>();
117+
return new ReverseTcpProxy(
118+
vertx.createNetServer(),
119+
vertx.createNetClient(),
120+
TcpRoundRobinLoadBalancer.create(list),
121+
list,
122+
generateName()
123+
).addNode(new NetAddress(targetHost, targetPort));
97124
}
98125

99126
public static ReverseTcpProxy create(NetServer netServer, NetClient netClient, String targetHost, int targetPort) {
100-
return new ReverseTcpProxy(netServer, netClient, targetHost, targetPort, generateName());
127+
List<NetAddress> list = new ArrayList<>();
128+
return new ReverseTcpProxy(
129+
netServer,
130+
netClient,
131+
TcpRoundRobinLoadBalancer.create(list),
132+
list,
133+
generateName()
134+
).addNode(new NetAddress(targetHost, targetPort));
101135
}
102136

103137
public static ReverseTcpProxy create(NetServer netServer, NetClient netClient, String targetHost, int targetPort, String name) {
104-
return new ReverseTcpProxy(netServer, netClient, targetHost, targetPort, name);
138+
List<NetAddress> list = new ArrayList<>();
139+
return new ReverseTcpProxy(
140+
netServer,
141+
netClient,
142+
TcpRoundRobinLoadBalancer.create(list),
143+
list,
144+
name
145+
).addNode(new NetAddress(targetHost, targetPort));
146+
}
147+
148+
public static ReverseTcpProxy create(NetServer netServer, NetClient netClient,
149+
LoadBalancer<NetAddress> loadBalancer,
150+
List<NetAddress> netAddresses,
151+
String name) {
152+
return new ReverseTcpProxy(netServer, netClient, loadBalancer, netAddresses, name);
105153
}
106154

107155
public ReverseTcpProxy port(int port) {
@@ -114,8 +162,15 @@ public ReverseTcpProxy host(String host) {
114162
return this;
115163
}
116164

165+
public ReverseTcpProxy addNode(NetAddress netAddress) {
166+
if (!netAddresses.contains(netAddress)) {
167+
netAddresses.add(netAddress);
168+
}
169+
return this;
170+
}
171+
117172

118-
protected static String generateName() {
173+
public static String generateName() {
119174
final String prefix = ReverseTcpProxy.class.getSimpleName() + "-";
120175
try {
121176
// 池号对于虚拟机来说是全局的,以避免在类加载器范围的环境中池号重叠
@@ -135,12 +190,15 @@ protected static String generateName() {
135190
}
136191

137192
public void start() {
193+
if (netAddresses.size() <= 0) {
194+
throw new IllegalStateException("netAddresses size must be greater than 0");
195+
}
138196
netServer.connectHandler(connectHandler).exceptionHandler(e -> log.error("connect failed", e));
139197
Future<NetServer> listenFuture = netServer.listen(sourcePort, sourceHost);
140198

141199
Handler<AsyncResult<NetServer>> asyncResultHandler = ar -> {
142200
if (ar.succeeded()) {
143-
log.info("{} started on {}:{}", name, sourceHost, sourcePort);
201+
log.info("{} started on {}:{}\nLB-Mode: {}\n {}", name, sourceHost, sourcePort, lb.name(), netAddresses);
144202
} else {
145203
Throwable e = ar.cause();
146204
log.error("{} start failed", name, e);
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package top.meethigher.proxy.tcp;
2+
3+
import top.meethigher.proxy.LoadBalancer;
4+
import top.meethigher.proxy.NetAddress;
5+
6+
import java.util.List;
7+
import java.util.concurrent.atomic.AtomicInteger;
8+
9+
/**
10+
* 轮询策略实现
11+
*
12+
* @author <a href="https://meethigher.top">chenchuancheng</a>
13+
* @since 2025/07/26 13:41
14+
*/
15+
public class TcpRoundRobinLoadBalancer implements LoadBalancer<NetAddress> {
16+
17+
private final List<NetAddress> nodes;
18+
19+
private final AtomicInteger idx = new AtomicInteger(0);
20+
21+
private final String name = "TcpRoundRobinLoadBalancer";
22+
23+
private TcpRoundRobinLoadBalancer(List<NetAddress> nodes) {
24+
this.nodes = nodes;
25+
}
26+
27+
28+
public NetAddress next() {
29+
if (nodes == null) {
30+
return null;
31+
}
32+
int index = idx.getAndUpdate(v -> (v + 1) % nodes.size());
33+
return nodes.get(index);
34+
}
35+
36+
@Override
37+
public String name() {
38+
return name;
39+
}
40+
41+
public static TcpRoundRobinLoadBalancer create(List<NetAddress> nodes) {
42+
return new TcpRoundRobinLoadBalancer(nodes);
43+
}
44+
}

0 commit comments

Comments
 (0)