11use crate :: {
22 decode:: Decoder , domain_name:: DOMAIN_NAME_MAX_RECURSION , DecodeError , DecodeResult , DomainName ,
33} ;
4- use std:: str:: from_utf8;
4+ use std:: { collections :: HashSet , str:: from_utf8, usize } ;
55
66const COMPRESSION_BITS : u8 = 0b1100_0000 ;
77const COMPRESSION_BITS_REV : u8 = 0b0011_1111 ;
@@ -19,32 +19,56 @@ const fn get_offset(length_1: u8, length_2: u8) -> usize {
1919impl < ' a , ' b : ' a > Decoder < ' a , ' b > {
2020 pub ( super ) fn domain_name ( & mut self ) -> DecodeResult < DomainName > {
2121 let mut domain_name = DomainName :: default ( ) ;
22- self . domain_name_recursion ( & mut domain_name, 0 ) ?;
22+
23+ let mut length = self . u8 ( ) ?;
24+ while length != 0 {
25+ if is_compressed ( length) {
26+ let mut recursions = HashSet :: new ( ) ;
27+ self . domain_name_recursion ( & mut domain_name, & mut recursions, length) ?;
28+ return Ok ( domain_name) ;
29+ } else {
30+ length = self . domain_name_label ( & mut domain_name, length) ?;
31+ }
32+ }
2333 Ok ( domain_name)
2434 }
2535
36+ fn domain_name_label ( & mut self , domain_name : & mut DomainName , length : u8 ) -> DecodeResult < u8 > {
37+ let buffer = self . read ( length as usize ) ?;
38+ let label = from_utf8 ( buffer. as_ref ( ) ) ?;
39+ let label = label. parse ( ) ?;
40+ domain_name. append_label ( label) ?;
41+ self . u8 ( )
42+ }
43+
2644 fn domain_name_recursion (
2745 & mut self ,
2846 domain_name : & mut DomainName ,
29- recursion : usize ,
47+ recursions : & mut HashSet < usize > ,
48+ mut length : u8 ,
3049 ) -> DecodeResult < ( ) > {
31- if recursion > DOMAIN_NAME_MAX_RECURSION {
32- return Err ( DecodeError :: MaxRecursion ( recursion) ) ;
33- }
50+ let mut buffer = self . u8 ( ) ?;
51+ let mut offset = get_offset ( length, buffer) ;
52+ let mut decoder = self . new_main_offset ( offset) ;
53+
54+ length = decoder. u8 ( ) ?;
3455
35- let mut length = self . u8 ( ) ?;
3656 while length != 0 {
3757 if is_compressed ( length) {
38- let buffer = self . u8 ( ) ?;
39- let offset = get_offset ( length, buffer) ;
40- let mut decoder = self . new_main_offset ( offset) ;
41- return decoder. domain_name_recursion ( domain_name, recursion + 1 ) ;
58+ buffer = decoder. u8 ( ) ?;
59+ offset = get_offset ( length, buffer) ;
60+ if recursions. insert ( offset) {
61+ let recursions_len = recursions. len ( ) ;
62+ if recursions_len > DOMAIN_NAME_MAX_RECURSION {
63+ return Err ( DecodeError :: MaxRecursion ( recursions_len) ) ;
64+ }
65+ } else {
66+ return Err ( DecodeError :: EndlessRecursion ( offset) ) ;
67+ }
68+ decoder. offset = offset as usize ;
69+ length = decoder. u8 ( ) ?;
4270 } else {
43- let buffer = self . read ( length as usize ) ?;
44- let label = from_utf8 ( buffer. as_ref ( ) ) ?;
45- let label = label. parse ( ) ?;
46- domain_name. append_label ( label) ?;
47- length = self . u8 ( ) ?;
71+ length = decoder. domain_name_label ( domain_name, length) ?;
4872 }
4973 }
5074
0 commit comments