11use crate :: {
22 decode:: Decoder , domain_name:: DOMAIN_NAME_MAX_RECURSION , DecodeError , DecodeResult , DomainName ,
33} ;
4- use std:: { collections:: HashSet , str:: from_utf8} ;
4+ use std:: { collections:: HashSet , str:: from_utf8, usize } ;
55
66const COMPRESSION_BITS : u8 = 0b1100_0000 ;
77const COMPRESSION_BITS_REV : u8 = 0b0011_1111 ;
@@ -19,38 +19,53 @@ const fn get_offset(length_1: u8, length_2: u8) -> usize {
1919impl < ' a , ' b : ' a > Decoder < ' a , ' b > {
2020 pub ( super ) fn domain_name ( & mut self ) -> DecodeResult < DomainName > {
2121 let mut domain_name = DomainName :: default ( ) ;
22- let mut recursions = HashSet :: new ( ) ;
23- self . domain_name_recursion ( & mut domain_name, & mut recursions) ?;
22+
23+ let mut length = self . u8 ( ) ?;
24+ while length != 0 {
25+ if is_compressed ( length) {
26+ let mut recursions = HashSet :: new ( ) ;
27+ self . domain_name_recursion ( & mut domain_name, & mut recursions, length) ?;
28+ return Ok ( domain_name) ;
29+ } else {
30+ length = self . domain_name_label ( & mut domain_name, length) ?;
31+ }
32+ }
2433 Ok ( domain_name)
2534 }
2635
36+ fn domain_name_label ( & mut self , domain_name : & mut DomainName , length : u8 ) -> DecodeResult < u8 > {
37+ let buffer = self . read ( length as usize ) ?;
38+ let label = from_utf8 ( buffer. as_ref ( ) ) ?;
39+ let label = label. parse ( ) ?;
40+ domain_name. append_label ( label) ?;
41+ self . u8 ( )
42+ }
43+
2744 fn domain_name_recursion (
2845 & mut self ,
2946 domain_name : & mut DomainName ,
3047 recursions : & mut HashSet < usize > ,
48+ mut length : u8 ,
3149 ) -> DecodeResult < ( ) > {
32- let recursions_len = recursions. len ( ) ;
33- if recursions_len > DOMAIN_NAME_MAX_RECURSION {
34- return Err ( DecodeError :: MaxRecursion ( recursions_len) ) ;
35- }
50+ let mut buffer = self . u8 ( ) ?;
51+ let mut offset = get_offset ( length, buffer) ;
52+ let mut decoder = self . new_main_offset ( offset) ;
3653
37- let mut length = self . u8 ( ) ?;
3854 while length != 0 {
3955 if is_compressed ( length) {
40- let buffer = self . u8 ( ) ?;
41- let offset = get_offset ( length, buffer) ;
56+ buffer = decoder . u8 ( ) ?;
57+ offset = get_offset ( length, buffer) ;
4258 if recursions. insert ( offset) {
43- let mut decoder = self . new_main_offset ( offset) ;
44- return decoder. domain_name_recursion ( domain_name, recursions) ;
59+ let recursions_len = recursions. len ( ) ;
60+ if recursions_len > DOMAIN_NAME_MAX_RECURSION {
61+ return Err ( DecodeError :: MaxRecursion ( recursions_len) ) ;
62+ }
4563 } else {
4664 return Err ( DecodeError :: EndlessRecursion ( offset) ) ;
4765 }
66+ decoder. offset = offset as usize ;
4867 } else {
49- let buffer = self . read ( length as usize ) ?;
50- let label = from_utf8 ( buffer. as_ref ( ) ) ?;
51- let label = label. parse ( ) ?;
52- domain_name. append_label ( label) ?;
53- length = self . u8 ( ) ?;
68+ length = decoder. domain_name_label ( domain_name, length) ?;
5469 }
5570 }
5671
0 commit comments