2
2
3
3
import io .vertx .core .Handler ;
4
4
import io .vertx .core .buffer .Buffer ;
5
+ import io .vertx .core .net .NetSocket ;
6
+ import org .slf4j .Logger ;
7
+ import org .slf4j .LoggerFactory ;
5
8
6
9
/**
7
10
* 自定义消息结构解析器
12
15
*/
13
16
public class TunnelMessageParser implements Handler <Buffer > {
14
17
18
+ private static final Logger log = LoggerFactory .getLogger (TunnelMessageParser .class );
15
19
private Buffer buf = Buffer .buffer ();
16
20
21
+ /**
22
+ * 最大长度,单位字节。防止对方构造超长字段,占用内存。
23
+ */
24
+ private final int maxLength = 1024 * 1024 ;
25
+
17
26
/**
18
27
* 预设长度起始位置
19
28
*/
@@ -40,8 +49,12 @@ public class TunnelMessageParser implements Handler<Buffer> {
40
49
41
50
private final Handler <Buffer > outputHandler ;
42
51
43
- public TunnelMessageParser (Handler <Buffer > outputHandler ) {
52
+ private final NetSocket netSocket ;
53
+
54
+ public TunnelMessageParser (Handler <Buffer > outputHandler ,
55
+ NetSocket netSocket ) {
44
56
this .outputHandler = outputHandler ;
57
+ this .netSocket = netSocket ;
45
58
}
46
59
47
60
@ Override
@@ -51,6 +64,25 @@ public void handle(Buffer buffer) {
51
64
return ;
52
65
} else {
53
66
int totalLength = buf .getInt (lengthFieldOffset );
67
+ // 校验最大长度
68
+ if (totalLength > maxLength ) {
69
+ log .warn ("too many bytes in length field, connection {} will be closed" , netSocket .remoteAddress ());
70
+ netSocket .close ();
71
+ return ;
72
+ }
73
+ // 校验类型编码是否在预设范围内
74
+ if (totalLength >= (lengthFieldLength + typeFieldLength )) {
75
+ short code = buf .getShort (lengthFieldLength );
76
+ try {
77
+ TunnelMessageType .fromCode (code );
78
+ } catch (Exception e ) {
79
+ log .error ("invalid type, connection {} will be closed" , netSocket .remoteAddress (), e );
80
+ netSocket .close ();
81
+ return ;
82
+ }
83
+ }
84
+
85
+ // 校验是否达到预设总长度
54
86
if (buf .length () < totalLength ) {
55
87
return ;
56
88
} else {
0 commit comments