11use crate :: {
22 decode:: Decoder , domain_name:: DOMAIN_NAME_MAX_RECURSION , DecodeError , DecodeResult , DomainName ,
33} ;
4- use std:: { collections:: HashSet , str:: from_utf8, usize } ;
4+ use std:: { collections:: HashSet , str:: from_utf8} ;
55
66const COMPRESSION_BITS : u8 = 0b1100_0000 ;
77const COMPRESSION_BITS_REV : u8 = 0b0011_1111 ;
@@ -12,67 +12,75 @@ const fn is_compressed(length: u8) -> bool {
1212}
1313
1414#[ inline]
15- const fn get_offset ( length_1 : u8 , length_2 : u8 ) -> usize {
16- ( ( ( length_1 & COMPRESSION_BITS_REV ) as usize ) << 8 ) | length_2 as usize
15+ const fn get_offset ( length_1 : u8 , length_2 : u8 ) -> u16 {
16+ ( ( ( length_1 & COMPRESSION_BITS_REV ) as u16 ) << 8 ) | length_2 as u16
17+ }
18+
19+ enum DomainNameLength {
20+ Compressed ( u16 ) ,
21+ Label ( u8 ) ,
1722}
1823
1924impl < ' a , ' b : ' a > Decoder < ' a , ' b > {
25+ fn domain_name_length ( & mut self ) -> DecodeResult < DomainNameLength > {
26+ let length = self . u8 ( ) ?;
27+ if is_compressed ( length) {
28+ let offset = self . u8 ( ) ?;
29+ Ok ( DomainNameLength :: Compressed ( get_offset ( length, offset) ) )
30+ } else {
31+ Ok ( DomainNameLength :: Label ( length) )
32+ }
33+ }
34+
2035 pub ( super ) fn domain_name ( & mut self ) -> DecodeResult < DomainName > {
2136 let mut domain_name = DomainName :: default ( ) ;
2237
23- let mut length = self . u8 ( ) ?;
24- while length != 0 {
25- if is_compressed ( length) {
26- let mut recursions = HashSet :: new ( ) ;
27- self . domain_name_recursion ( & mut domain_name, & mut recursions, length) ?;
28- return Ok ( domain_name) ;
29- } else {
30- length = self . domain_name_label ( & mut domain_name, length) ?;
38+ loop {
39+ match self . domain_name_length ( ) ? {
40+ DomainNameLength :: Compressed ( offset) => {
41+ self . domain_name_recursion ( & mut domain_name, offset) ?;
42+ return Ok ( domain_name) ;
43+ }
44+ DomainNameLength :: Label ( 0 ) => return Ok ( domain_name) ,
45+ DomainNameLength :: Label ( length) => {
46+ self . domain_name_label ( & mut domain_name, length) ?
47+ }
3148 }
3249 }
33- Ok ( domain_name)
3450 }
3551
36- fn domain_name_label ( & mut self , domain_name : & mut DomainName , length : u8 ) -> DecodeResult < u8 > {
52+ fn domain_name_label ( & mut self , domain_name : & mut DomainName , length : u8 ) -> DecodeResult < ( ) > {
3753 let buffer = self . read ( length as usize ) ?;
3854 let label = from_utf8 ( buffer. as_ref ( ) ) ?;
3955 let label = label. parse ( ) ?;
4056 domain_name. append_label ( label) ?;
41- self . u8 ( )
57+ Ok ( ( ) )
4258 }
4359
44- fn domain_name_recursion (
45- & mut self ,
46- domain_name : & mut DomainName ,
47- recursions : & mut HashSet < usize > ,
48- mut length : u8 ,
49- ) -> DecodeResult < ( ) > {
50- let mut buffer = self . u8 ( ) ?;
51- let mut offset = get_offset ( length, buffer) ;
60+ fn domain_name_recursion ( & self , domain_name : & mut DomainName , offset : u16 ) -> DecodeResult < ( ) > {
5261 let mut decoder = self . new_main_offset ( offset) ;
62+ let mut recursions = HashSet :: new ( ) ;
5363
54- length = decoder . u8 ( ) ? ;
55-
56- while length != 0 {
57- if is_compressed ( length ) {
58- buffer = decoder . u8 ( ) ? ;
59- offset = get_offset ( length , buffer ) ;
60- if recursions . insert ( offset ) {
61- let recursions_len = recursions . len ( ) ;
62- if recursions_len > DOMAIN_NAME_MAX_RECURSION {
63- return Err ( DecodeError :: MaxRecursion ( recursions_len ) ) ;
64+ loop {
65+ match decoder . domain_name_length ( ) ? {
66+ DomainNameLength :: Compressed ( offset ) => {
67+ if recursions . insert ( offset ) {
68+ let recursions_len = recursions . len ( ) ;
69+ if recursions_len > DOMAIN_NAME_MAX_RECURSION {
70+ return Err ( DecodeError :: MaxRecursion ( recursions_len ) ) ;
71+ }
72+ } else {
73+ return Err ( DecodeError :: EndlessRecursion ( offset ) ) ;
6474 }
65- } else {
66- return Err ( DecodeError :: EndlessRecursion ( offset) ) ;
75+
76+ decoder. offset = offset as usize ;
77+ }
78+ DomainNameLength :: Label ( 0 ) => return Ok ( ( ) ) ,
79+ DomainNameLength :: Label ( length) => {
80+ decoder. domain_name_label ( domain_name, length) ?;
6781 }
68- decoder. offset = offset as usize ;
69- length = decoder. u8 ( ) ?;
70- } else {
71- length = decoder. domain_name_label ( domain_name, length) ?;
7282 }
7383 }
74-
75- Ok ( ( ) )
7684 }
7785}
7886
0 commit comments