diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 081600e..e879201 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,40 +15,29 @@ jobs: runs-on: ubuntu-latest steps: - name: Install toolchain with rustfmt - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: - toolchain: stable components: rustfmt - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Run rustfmt run: cargo fmt --all -- --check - audit: - name: Job audit - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v1 - - name: Run audit - uses: actions-rs/audit-check@v1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - clippy: name: Job clippy needs: rustfmt runs-on: ubuntu-latest steps: - name: Install toolchain with clippy - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: - toolchain: stable components: clippy - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Run clippy - uses: actions-rs/clippy-check@v1 + uses: giraffate/clippy-action@v1 with: - token: ${{ secrets.GITHUB_TOKEN }} - args: -- --deny warnings -A clippy::unknown-clippy-lints + reporter: 'github-pr-check' + github_token: ${{ secrets.GITHUB_TOKEN }} + clippy_flags: --deny warnings -A clippy::unknown-clippy-lints tests: name: Job tests @@ -60,15 +49,12 @@ jobs: runs-on: ${{ matrix.os }} steps: - name: Install toolchain ${{ matrix.rust_channel }} on ${{ matrix.os }} - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@master with: toolchain: ${{ matrix.rust_channel }} - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Run cargo test - uses: actions-rs/cargo@v1 - with: - command: test - args: --all-features + run: cargo test --no-default-features --features "${{ matrix.features }}" code-coverage: name: Job code coverage @@ -76,23 +62,16 @@ jobs: runs-on: ubuntu-latest steps: - name: Intall toolchain nightly on ubuntu-latest - uses: actions-rs/toolchain@v1 - with: - toolchain: nightly - override: true - - uses: actions/checkout@v2 - - name: Run cargo test - uses: actions-rs/cargo@v1 + uses: dtolnay/rust-toolchain@stable with: - command: test - args: --all-features - env: - CARGO_INCREMENTAL: '0' - RUSTFLAGS: '-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Clink-dead-code -Coverflow-checks=off -Cpanic=abort -Zpanic_abort_tests' - RUSTDOCFLAGS: '-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Clink-dead-code -Coverflow-checks=off -Cpanic=abort -Zpanic_abort_tests' - - name: Run grcov - uses: actions-rs/grcov@v0.1 - - name: Upload coverage - uses: codecov/codecov-action@v1 + components: llvm-tools-preview + - uses: actions/checkout@v4 + - name: cargo install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov + - name: cargo llvm-cov + run: cargo llvm-cov --all-features --lcov --output-path lcov.info + - name: Upload to codecov.io + uses: codecov/codecov-action@v4 with: - file: ${{ steps.coverage.outputs.report }} + token: ${{ secrets.CODECOV_TOKEN }} + env_vars: OS,RUST diff --git a/Cargo.toml b/Cargo.toml index cc94432..2fc7513 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dns-message-parser" -version = "0.7.0" +version = "0.8.0" authors = ["LinkTed "] edition = "2018" readme = "README.md" @@ -22,13 +22,13 @@ categories = [ ] [dependencies] -base64 = "0.13.0" -bytes = "1.2.1" -hex = "0.4.3" -thiserror = "1.0.32" +base64 = "0.22" +bytes = "1" +hex = "0.4" +thiserror = "2" [dev-dependencies] -criterion = "0.3.6" +criterion = "0.5" [[bench]] name = "message" diff --git a/src/decode/decoder.rs b/src/decode/decoder.rs index b302f39..35f30dc 100644 --- a/src/decode/decoder.rs +++ b/src/decode/decoder.rs @@ -38,7 +38,7 @@ impl<'a, 'b: 'a> Decoder<'a, 'b> { } } - pub(super) const fn get_main(&'a self) -> &Decoder<'a, 'b> { + pub(super) const fn get_main(&self) -> &Decoder<'a, 'b> { let mut root = self; loop { match root.parent { diff --git a/src/decode/domain_name.rs b/src/decode/domain_name.rs index 04bd598..33c63da 100644 --- a/src/decode/domain_name.rs +++ b/src/decode/domain_name.rs @@ -1,7 +1,7 @@ use crate::{ decode::Decoder, domain_name::DOMAIN_NAME_MAX_RECURSION, DecodeError, DecodeResult, DomainName, }; -use std::str::from_utf8; +use std::{collections::HashSet, str::from_utf8, usize}; const COMPRESSION_BITS: u8 = 0b1100_0000; const COMPRESSION_BITS_REV: u8 = 0b0011_1111; @@ -19,32 +19,56 @@ const fn get_offset(length_1: u8, length_2: u8) -> usize { impl<'a, 'b: 'a> Decoder<'a, 'b> { pub(super) fn domain_name(&mut self) -> DecodeResult { let mut domain_name = DomainName::default(); - self.domain_name_recursion(&mut domain_name, 0)?; + + let mut length = self.u8()?; + while length != 0 { + if is_compressed(length) { + let mut recursions = HashSet::new(); + self.domain_name_recursion(&mut domain_name, &mut recursions, length)?; + return Ok(domain_name); + } else { + length = self.domain_name_label(&mut domain_name, length)?; + } + } Ok(domain_name) } + fn domain_name_label(&mut self, domain_name: &mut DomainName, length: u8) -> DecodeResult { + let buffer = self.read(length as usize)?; + let label = from_utf8(buffer.as_ref())?; + let label = label.parse()?; + domain_name.append_label(label)?; + self.u8() + } + fn domain_name_recursion( &mut self, domain_name: &mut DomainName, - recursion: usize, + recursions: &mut HashSet, + mut length: u8, ) -> DecodeResult<()> { - if recursion > DOMAIN_NAME_MAX_RECURSION { - return Err(DecodeError::MaxRecursion(recursion)); - } + let mut buffer = self.u8()?; + let mut offset = get_offset(length, buffer); + let mut decoder = self.new_main_offset(offset); + + length = decoder.u8()?; - let mut length = self.u8()?; while length != 0 { if is_compressed(length) { - let buffer = self.u8()?; - let offset = get_offset(length, buffer); - let mut decoder = self.new_main_offset(offset); - return decoder.domain_name_recursion(domain_name, recursion + 1); + buffer = decoder.u8()?; + offset = get_offset(length, buffer); + if recursions.insert(offset) { + let recursions_len = recursions.len(); + if recursions_len > DOMAIN_NAME_MAX_RECURSION { + return Err(DecodeError::MaxRecursion(recursions_len)); + } + } else { + return Err(DecodeError::EndlessRecursion(offset)); + } + decoder.offset = offset as usize; + length = decoder.u8()?; } else { - let buffer = self.read(length as usize)?; - let label = from_utf8(buffer.as_ref())?; - let label = label.parse()?; - domain_name.append_label(label)?; - length = self.u8()?; + length = decoder.domain_name_label(domain_name, length)?; } } diff --git a/src/decode/error.rs b/src/decode/error.rs index 9acaf02..dbf562e 100644 --- a/src/decode/error.rs +++ b/src/decode/error.rs @@ -89,6 +89,8 @@ pub enum DecodeError { DNSKEYProtocol(u8), #[error("Could not decode the domain name, the because maximum recursion is reached: {0}")] MaxRecursion(usize), + #[error("Could not decode the domain name, because an endless recursion was detected: {0}")] + EndlessRecursion(usize), #[error("The are remaining bytes, which was not parsed")] RemainingBytes(usize, Dns), #[error("Padding is not zero: {0}")] diff --git a/src/rr/draft_ietf_dnsop_svcb_https.rs b/src/rr/draft_ietf_dnsop_svcb_https.rs index 16c0401..8031c09 100644 --- a/src/rr/draft_ietf_dnsop_svcb_https.rs +++ b/src/rr/draft_ietf_dnsop_svcb_https.rs @@ -1,12 +1,12 @@ -use std::cmp::Ordering; -use std::fmt::{Display, Formatter, Result as FmtResult}; -use std::hash::{Hash, Hasher}; -use std::net::{Ipv4Addr, Ipv6Addr}; - use crate::rr::draft_ietf_dnsop_svcb_https::ServiceBindingMode::{Alias, Service}; use crate::rr::{ToType, Type}; use crate::DomainName; +use base64::{engine::general_purpose::STANDARD as Base64Standard, Engine}; +use std::cmp::Ordering; use std::collections::BTreeSet; +use std::fmt::{Display, Formatter, Result as FmtResult}; +use std::hash::{Hash, Hasher}; +use std::net::{Ipv4Addr, Ipv6Addr}; /// A Service Binding record for locating alternative endpoints for a service. /// @@ -227,7 +227,7 @@ impl Display for ServiceParameter { ) } ServiceParameter::ECH { config_list } => { - write!(f, "ech={}", base64::encode(config_list)) + write!(f, "ech={}", Base64Standard.encode(config_list)) } ServiceParameter::IPV6_HINT { hints } => { write!( @@ -261,7 +261,7 @@ impl Display for ServiceParameter { if let Ok(value) = String::from_utf8(escaped) { write!(f, "{}=\"{}\"", key, value) } else { - write!(f, "{}=\"{}\"", key, base64::encode(wire_data)) + write!(f, "{}=\"{}\"", key, Base64Standard.encode(wire_data)) } } } diff --git a/tests/decode_error.rs b/tests/decode_error.rs index a59e23a..fec4035 100644 --- a/tests/decode_error.rs +++ b/tests/decode_error.rs @@ -2,7 +2,7 @@ use bytes::Bytes; use dns_message_parser::{ question::{QClass, QType, Question}, rr::{AddressError, Class, TagError}, - Opcode, RCode, MAXIMUM_DNS_PACKET_SIZE, {DecodeError, Dns, Flags}, + DecodeError, Dns, DomainName, Flags, Opcode, RCode, MAXIMUM_DNS_PACKET_SIZE, }; fn decode_msg_error(msg: &[u8], e: DecodeError) { @@ -23,6 +23,15 @@ fn decode_flags_error(msg: &[u8], e: DecodeError) { assert_eq!(flags, Err(e)) } +fn decode_domain_name_error(msg: &[u8], e: DecodeError) { + // Decode BytesMut to message + let bytes = Bytes::copy_from_slice(msg); + // Decode the domain name + let domain_name = DomainName::decode(bytes); + // Check the result + assert_eq!(domain_name, Err(e)) +} + #[test] fn flags_1() { let msg = b""; @@ -53,6 +62,19 @@ fn flags_5() { decode_flags_error(msg, DecodeError::RCode(15)); } +#[test] +fn domain_name_1() { + let msg = b"\xc0\x02\xc0\x00"; + decode_domain_name_error(msg, DecodeError::EndlessRecursion(0)); +} + +#[test] +fn domain_name_2() { + let msg = + b"\xc0\x02\xc0\x04\xc0\x06\xc0\x08\xc0\x0a\xc0\x0c\xc0\x0e\xc0\x10\xc0\x12\xc0\x14\xc0\x16\xc0\x18\xc0\x1a\xc0\x1c\xc0\x1e\xc0\x20\xc0\x22\xc0\x24"; + decode_domain_name_error(msg, DecodeError::MaxRecursion(17)); +} + #[test] fn dns_not_enough_bytes() { let msg = b"";