Skip to content
65 changes: 22 additions & 43 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,29 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Install toolchain with rustfmt
uses: actions-rs/toolchain@v1
uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
components: rustfmt
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Run rustfmt
run: cargo fmt --all -- --check

audit:
name: Job audit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: Run audit
uses: actions-rs/audit-check@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}

clippy:
name: Job clippy
needs: rustfmt
runs-on: ubuntu-latest
steps:
- name: Install toolchain with clippy
uses: actions-rs/toolchain@v1
uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
components: clippy
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Run clippy
uses: actions-rs/clippy-check@v1
uses: giraffate/clippy-action@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}
args: -- --deny warnings -A clippy::unknown-clippy-lints
reporter: 'github-pr-check'
github_token: ${{ secrets.GITHUB_TOKEN }}
clippy_flags: --deny warnings -A clippy::unknown-clippy-lints

tests:
name: Job tests
Expand All @@ -60,39 +49,29 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Install toolchain ${{ matrix.rust_channel }} on ${{ matrix.os }}
uses: actions-rs/toolchain@v1
uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust_channel }}
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Run cargo test
uses: actions-rs/cargo@v1
with:
command: test
args: --all-features
run: cargo test --no-default-features --features "${{ matrix.features }}"

code-coverage:
name: Job code coverage
needs: tests
runs-on: ubuntu-latest
steps:
- name: Intall toolchain nightly on ubuntu-latest
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
override: true
- uses: actions/checkout@v2
- name: Run cargo test
uses: actions-rs/cargo@v1
uses: dtolnay/rust-toolchain@stable
with:
command: test
args: --all-features
env:
CARGO_INCREMENTAL: '0'
RUSTFLAGS: '-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Clink-dead-code -Coverflow-checks=off -Cpanic=abort -Zpanic_abort_tests'
RUSTDOCFLAGS: '-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Clink-dead-code -Coverflow-checks=off -Cpanic=abort -Zpanic_abort_tests'
- name: Run grcov
uses: actions-rs/[email protected]
- name: Upload coverage
uses: codecov/codecov-action@v1
components: llvm-tools-preview
- uses: actions/checkout@v4
- name: cargo install cargo-llvm-cov
uses: taiki-e/install-action@cargo-llvm-cov
- name: cargo llvm-cov
run: cargo llvm-cov --all-features --lcov --output-path lcov.info
- name: Upload to codecov.io
uses: codecov/codecov-action@v4
with:
file: ${{ steps.coverage.outputs.report }}
token: ${{ secrets.CODECOV_TOKEN }}
env_vars: OS,RUST
12 changes: 6 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "dns-message-parser"
version = "0.7.0"
version = "0.8.0"
authors = ["LinkTed <[email protected]>"]
edition = "2018"
readme = "README.md"
Expand All @@ -22,13 +22,13 @@ categories = [
]

[dependencies]
base64 = "0.13.0"
bytes = "1.2.1"
hex = "0.4.3"
thiserror = "1.0.32"
base64 = "0.22"
bytes = "1"
hex = "0.4"
thiserror = "2"

[dev-dependencies]
criterion = "0.3.6"
criterion = "0.5"

[[bench]]
name = "message"
Expand Down
2 changes: 1 addition & 1 deletion src/decode/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl<'a, 'b: 'a> Decoder<'a, 'b> {
}
}

pub(super) const fn get_main(&'a self) -> &Decoder<'a, 'b> {
pub(super) const fn get_main(&self) -> &Decoder<'a, 'b> {
let mut root = self;
loop {
match root.parent {
Expand Down
56 changes: 40 additions & 16 deletions src/decode/domain_name.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
decode::Decoder, domain_name::DOMAIN_NAME_MAX_RECURSION, DecodeError, DecodeResult, DomainName,
};
use std::str::from_utf8;
use std::{collections::HashSet, str::from_utf8, usize};

const COMPRESSION_BITS: u8 = 0b1100_0000;
const COMPRESSION_BITS_REV: u8 = 0b0011_1111;
Expand All @@ -19,32 +19,56 @@ const fn get_offset(length_1: u8, length_2: u8) -> usize {
impl<'a, 'b: 'a> Decoder<'a, 'b> {
pub(super) fn domain_name(&mut self) -> DecodeResult<DomainName> {
let mut domain_name = DomainName::default();
self.domain_name_recursion(&mut domain_name, 0)?;

let mut length = self.u8()?;
while length != 0 {
if is_compressed(length) {
let mut recursions = HashSet::new();
self.domain_name_recursion(&mut domain_name, &mut recursions, length)?;
return Ok(domain_name);
} else {
length = self.domain_name_label(&mut domain_name, length)?;
}
}
Ok(domain_name)
}

fn domain_name_label(&mut self, domain_name: &mut DomainName, length: u8) -> DecodeResult<u8> {
let buffer = self.read(length as usize)?;
let label = from_utf8(buffer.as_ref())?;
let label = label.parse()?;
domain_name.append_label(label)?;
self.u8()
}

fn domain_name_recursion(
&mut self,
domain_name: &mut DomainName,
recursion: usize,
recursions: &mut HashSet<usize>,
mut length: u8,
) -> DecodeResult<()> {
if recursion > DOMAIN_NAME_MAX_RECURSION {
return Err(DecodeError::MaxRecursion(recursion));
}
let mut buffer = self.u8()?;
let mut offset = get_offset(length, buffer);
let mut decoder = self.new_main_offset(offset);

length = decoder.u8()?;

let mut length = self.u8()?;
while length != 0 {
if is_compressed(length) {
let buffer = self.u8()?;
let offset = get_offset(length, buffer);
let mut decoder = self.new_main_offset(offset);
return decoder.domain_name_recursion(domain_name, recursion + 1);
buffer = decoder.u8()?;
offset = get_offset(length, buffer);
if recursions.insert(offset) {
let recursions_len = recursions.len();
if recursions_len > DOMAIN_NAME_MAX_RECURSION {
return Err(DecodeError::MaxRecursion(recursions_len));
}
} else {
return Err(DecodeError::EndlessRecursion(offset));
}
decoder.offset = offset as usize;
length = decoder.u8()?;
} else {
let buffer = self.read(length as usize)?;
let label = from_utf8(buffer.as_ref())?;
let label = label.parse()?;
domain_name.append_label(label)?;
length = self.u8()?;
length = decoder.domain_name_label(domain_name, length)?;
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/decode/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ pub enum DecodeError {
DNSKEYProtocol(u8),
#[error("Could not decode the domain name, the because maximum recursion is reached: {0}")]
MaxRecursion(usize),
#[error("Could not decode the domain name, because an endless recursion was detected: {0}")]
EndlessRecursion(usize),
#[error("The are remaining bytes, which was not parsed")]
RemainingBytes(usize, Dns),
#[error("Padding is not zero: {0}")]
Expand Down
14 changes: 7 additions & 7 deletions src/rr/draft_ietf_dnsop_svcb_https.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::cmp::Ordering;
use std::fmt::{Display, Formatter, Result as FmtResult};
use std::hash::{Hash, Hasher};
use std::net::{Ipv4Addr, Ipv6Addr};

use crate::rr::draft_ietf_dnsop_svcb_https::ServiceBindingMode::{Alias, Service};
use crate::rr::{ToType, Type};
use crate::DomainName;
use base64::{engine::general_purpose::STANDARD as Base64Standard, Engine};
use std::cmp::Ordering;
use std::collections::BTreeSet;
use std::fmt::{Display, Formatter, Result as FmtResult};
use std::hash::{Hash, Hasher};
use std::net::{Ipv4Addr, Ipv6Addr};

/// A Service Binding record for locating alternative endpoints for a service.
///
Expand Down Expand Up @@ -227,7 +227,7 @@ impl Display for ServiceParameter {
)
}
ServiceParameter::ECH { config_list } => {
write!(f, "ech={}", base64::encode(config_list))
write!(f, "ech={}", Base64Standard.encode(config_list))
}
ServiceParameter::IPV6_HINT { hints } => {
write!(
Expand Down Expand Up @@ -261,7 +261,7 @@ impl Display for ServiceParameter {
if let Ok(value) = String::from_utf8(escaped) {
write!(f, "{}=\"{}\"", key, value)
} else {
write!(f, "{}=\"{}\"", key, base64::encode(wire_data))
write!(f, "{}=\"{}\"", key, Base64Standard.encode(wire_data))
}
}
}
Expand Down
24 changes: 23 additions & 1 deletion tests/decode_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use bytes::Bytes;
use dns_message_parser::{
question::{QClass, QType, Question},
rr::{AddressError, Class, TagError},
Opcode, RCode, MAXIMUM_DNS_PACKET_SIZE, {DecodeError, Dns, Flags},
DecodeError, Dns, DomainName, Flags, Opcode, RCode, MAXIMUM_DNS_PACKET_SIZE,
};

fn decode_msg_error(msg: &[u8], e: DecodeError) {
Expand All @@ -23,6 +23,15 @@ fn decode_flags_error(msg: &[u8], e: DecodeError) {
assert_eq!(flags, Err(e))
}

fn decode_domain_name_error(msg: &[u8], e: DecodeError) {
// Decode BytesMut to message
let bytes = Bytes::copy_from_slice(msg);
// Decode the domain name
let domain_name = DomainName::decode(bytes);
// Check the result
assert_eq!(domain_name, Err(e))
}

#[test]
fn flags_1() {
let msg = b"";
Expand Down Expand Up @@ -53,6 +62,19 @@ fn flags_5() {
decode_flags_error(msg, DecodeError::RCode(15));
}

#[test]
fn domain_name_1() {
let msg = b"\xc0\x02\xc0\x00";
decode_domain_name_error(msg, DecodeError::EndlessRecursion(0));
}

#[test]
fn domain_name_2() {
let msg =
b"\xc0\x02\xc0\x04\xc0\x06\xc0\x08\xc0\x0a\xc0\x0c\xc0\x0e\xc0\x10\xc0\x12\xc0\x14\xc0\x16\xc0\x18\xc0\x1a\xc0\x1c\xc0\x1e\xc0\x20\xc0\x22\xc0\x24";
decode_domain_name_error(msg, DecodeError::MaxRecursion(17));
}

#[test]
fn dns_not_enough_bytes() {
let msg = b"";
Expand Down
Loading