Skip to content

Commit 498654b

Browse files
committed
feat: Improve domain name decoding
Use an iterative instead of a recursive approach in the decoding of domain names.
1 parent df822dc commit 498654b

File tree

2 files changed

+37
-19
lines changed

2 files changed

+37
-19
lines changed

src/decode/domain_name.rs

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::{
22
decode::Decoder, domain_name::DOMAIN_NAME_MAX_RECURSION, DecodeError, DecodeResult, DomainName,
33
};
4-
use std::{collections::HashSet, str::from_utf8};
4+
use std::{collections::HashSet, str::from_utf8, usize};
55

66
const COMPRESSION_BITS: u8 = 0b1100_0000;
77
const COMPRESSION_BITS_REV: u8 = 0b0011_1111;
@@ -19,38 +19,56 @@ const fn get_offset(length_1: u8, length_2: u8) -> usize {
1919
impl<'a, 'b: 'a> Decoder<'a, 'b> {
2020
pub(super) fn domain_name(&mut self) -> DecodeResult<DomainName> {
2121
let mut domain_name = DomainName::default();
22-
let mut recursions = HashSet::new();
23-
self.domain_name_recursion(&mut domain_name, &mut recursions)?;
22+
23+
let mut length = self.u8()?;
24+
while length != 0 {
25+
if is_compressed(length) {
26+
let mut recursions = HashSet::new();
27+
self.domain_name_recursion(&mut domain_name, &mut recursions, length)?;
28+
return Ok(domain_name);
29+
} else {
30+
length = self.domain_name_label(&mut domain_name, length)?;
31+
}
32+
}
2433
Ok(domain_name)
2534
}
2635

36+
fn domain_name_label(&mut self, domain_name: &mut DomainName, length: u8) -> DecodeResult<u8> {
37+
let buffer = self.read(length as usize)?;
38+
let label = from_utf8(buffer.as_ref())?;
39+
let label = label.parse()?;
40+
domain_name.append_label(label)?;
41+
self.u8()
42+
}
43+
2744
fn domain_name_recursion(
2845
&mut self,
2946
domain_name: &mut DomainName,
3047
recursions: &mut HashSet<usize>,
48+
mut length: u8,
3149
) -> DecodeResult<()> {
32-
let recursions_len = recursions.len();
33-
if recursions_len > DOMAIN_NAME_MAX_RECURSION {
34-
return Err(DecodeError::MaxRecursion(recursions_len));
35-
}
50+
let mut buffer = self.u8()?;
51+
let mut offset = get_offset(length, buffer);
52+
let mut decoder = self.new_main_offset(offset);
53+
54+
length = decoder.u8()?;
3655

37-
let mut length = self.u8()?;
3856
while length != 0 {
3957
if is_compressed(length) {
40-
let buffer = self.u8()?;
41-
let offset = get_offset(length, buffer);
58+
buffer = decoder.u8()?;
59+
offset = get_offset(length, buffer);
4260
if recursions.insert(offset) {
43-
let mut decoder = self.new_main_offset(offset);
44-
return decoder.domain_name_recursion(domain_name, recursions);
61+
let recursions_len = recursions.len();
62+
if recursions_len > DOMAIN_NAME_MAX_RECURSION {
63+
return Err(DecodeError::MaxRecursion(recursions_len));
64+
}
4565
} else {
4666
return Err(DecodeError::EndlessRecursion(offset));
4767
}
68+
decoder.offset = offset as usize;
69+
length = decoder.u8()?;
4870
} else {
49-
let buffer = self.read(length as usize)?;
50-
let label = from_utf8(buffer.as_ref())?;
51-
let label = label.parse()?;
52-
domain_name.append_label(label)?;
53-
length = self.u8()?;
71+
length = decoder.domain_name_label(domain_name, length)?;
5472
}
5573
}
5674

tests/decode_error.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ fn flags_5() {
6565
#[test]
6666
fn domain_name_1() {
6767
let msg = b"\xc0\x02\xc0\x00";
68-
decode_domain_name_error(msg, DecodeError::EndlessRecursion(2));
68+
decode_domain_name_error(msg, DecodeError::EndlessRecursion(0));
6969
}
7070

7171
#[test]
7272
fn domain_name_2() {
7373
let msg =
74-
b"\xc0\x02\xc0\x04\xc0\x06\xc0\x08\xc0\x0a\xc0\x0c\xc0\x0e\xc0\x10\xc0\x12\xc0\x14\xc0\x16\xc0\x18\xc0\x1a\xc0\x1c\xc0\x1e\xc0\x20\xc0\x22";
74+
b"\xc0\x02\xc0\x04\xc0\x06\xc0\x08\xc0\x0a\xc0\x0c\xc0\x0e\xc0\x10\xc0\x12\xc0\x14\xc0\x16\xc0\x18\xc0\x1a\xc0\x1c\xc0\x1e\xc0\x20\xc0\x22\xc0\x24";
7575
decode_domain_name_error(msg, DecodeError::MaxRecursion(17));
7676
}
7777

0 commit comments

Comments
 (0)