Skip to content

Commit 4d74900

Browse files
committed
feat: Improve domain name decoding
Use an iterative instead of a recursive approach in the decoding of domain names.
1 parent df822dc commit 4d74900

File tree

1 file changed

+32
-17
lines changed

1 file changed

+32
-17
lines changed

src/decode/domain_name.rs

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::{
22
decode::Decoder, domain_name::DOMAIN_NAME_MAX_RECURSION, DecodeError, DecodeResult, DomainName,
33
};
4-
use std::{collections::HashSet, str::from_utf8};
4+
use std::{collections::HashSet, str::from_utf8, usize};
55

66
const COMPRESSION_BITS: u8 = 0b1100_0000;
77
const COMPRESSION_BITS_REV: u8 = 0b0011_1111;
@@ -19,38 +19,53 @@ const fn get_offset(length_1: u8, length_2: u8) -> usize {
1919
impl<'a, 'b: 'a> Decoder<'a, 'b> {
2020
pub(super) fn domain_name(&mut self) -> DecodeResult<DomainName> {
2121
let mut domain_name = DomainName::default();
22-
let mut recursions = HashSet::new();
23-
self.domain_name_recursion(&mut domain_name, &mut recursions)?;
22+
23+
let mut length = self.u8()?;
24+
while length != 0 {
25+
if is_compressed(length) {
26+
let mut recursions = HashSet::new();
27+
self.domain_name_recursion(&mut domain_name, &mut recursions, length)?;
28+
return Ok(domain_name);
29+
} else {
30+
length = self.domain_name_label(&mut domain_name, length)?;
31+
}
32+
}
2433
Ok(domain_name)
2534
}
2635

36+
fn domain_name_label(&mut self, domain_name: &mut DomainName, length: u8) -> DecodeResult<u8> {
37+
let buffer = self.read(length as usize)?;
38+
let label = from_utf8(buffer.as_ref())?;
39+
let label = label.parse()?;
40+
domain_name.append_label(label)?;
41+
self.u8()
42+
}
43+
2744
fn domain_name_recursion(
2845
&mut self,
2946
domain_name: &mut DomainName,
3047
recursions: &mut HashSet<usize>,
48+
mut length: u8,
3149
) -> DecodeResult<()> {
32-
let recursions_len = recursions.len();
33-
if recursions_len > DOMAIN_NAME_MAX_RECURSION {
34-
return Err(DecodeError::MaxRecursion(recursions_len));
35-
}
50+
let mut buffer = self.u8()?;
51+
let mut offset = get_offset(length, buffer);
52+
let mut decoder = self.new_main_offset(offset);
3653

37-
let mut length = self.u8()?;
3854
while length != 0 {
3955
if is_compressed(length) {
40-
let buffer = self.u8()?;
41-
let offset = get_offset(length, buffer);
56+
buffer = decoder.u8()?;
57+
offset = get_offset(length, buffer);
4258
if recursions.insert(offset) {
43-
let mut decoder = self.new_main_offset(offset);
44-
return decoder.domain_name_recursion(domain_name, recursions);
59+
let recursions_len = recursions.len();
60+
if recursions_len > DOMAIN_NAME_MAX_RECURSION {
61+
return Err(DecodeError::MaxRecursion(recursions_len));
62+
}
4563
} else {
4664
return Err(DecodeError::EndlessRecursion(offset));
4765
}
66+
decoder.offset = offset as usize;
4867
} else {
49-
let buffer = self.read(length as usize)?;
50-
let label = from_utf8(buffer.as_ref())?;
51-
let label = label.parse()?;
52-
domain_name.append_label(label)?;
53-
length = self.u8()?;
68+
length = decoder.domain_name_label(domain_name, length)?;
5469
}
5570
}
5671

0 commit comments

Comments
 (0)