diff --git a/checkdmarc/__init__.py b/checkdmarc/__init__.py index f12175a..e57c30c 100644 --- a/checkdmarc/__init__.py +++ b/checkdmarc/__init__.py @@ -59,6 +59,7 @@ def check_domains( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, wait: float = 0.0, ) -> Union[OrderedDict, list[OrderedDict]]: """ @@ -79,6 +80,7 @@ def check_domains( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout wait (float): number of seconds to wait between processing domains Returns: @@ -139,11 +141,16 @@ def check_domains( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) mta_sts_mx_patterns = None domain_results["mta_sts"] = check_mta_sts( - domain, nameservers=nameservers, resolver=resolver, timeout=timeout + domain, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) if domain_results["mta_sts"]["valid"]: mta_sts_mx_patterns = domain_results["mta_sts"]["policy"]["mx"] @@ -155,6 +162,7 @@ def check_domains( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) domain_results["spf"] = check_spf( @@ -163,6 +171,7 @@ def check_domains( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) domain_results["dmarc"] = check_dmarc( @@ -172,10 +181,15 @@ def check_domains( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) domain_results["smtp_tls_reporting"] = check_smtp_tls_reporting( - domain, nameservers=nameservers, resolver=resolver, timeout=timeout + domain, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) if bimi_selector is not None: domain_results["bimi"] = check_bimi( @@ -186,6 +200,7 @@ def check_domains( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) results.append(domain_results) @@ -205,6 +220,7 @@ def check_ns( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Returns a dictionary of nameservers and warnings or a dictionary with an @@ -236,6 +252,7 @@ def check_ns( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) except DNSException as error: ns_results = OrderedDict([("hostnames", []), ("error", error.__str__())]) diff --git a/checkdmarc/_cli.py b/checkdmarc/_cli.py index c28865c..b844222 100644 --- a/checkdmarc/_cli.py +++ b/checkdmarc/_cli.py @@ -80,6 +80,13 @@ def _main(): type=float, default=2.0, ) + arg_parser.add_argument( + "--timeout-retries", + help="number of times to reattempt a query after a timeout (default 2)", + type=int, + default=2, + ) + arg_parser.add_argument( "-b", "--bimi-selector", default="default", help="the BIMI selector to use" ) @@ -137,6 +144,7 @@ def _main(): include_tag_descriptions=args.descriptions, nameservers=args.nameserver, timeout=args.timeout, + timeout_retries=args.timeout_retries, bimi_selector=args.bimi_selector, wait=args.wait, ) diff --git a/checkdmarc/bimi.py b/checkdmarc/bimi.py index a9d6c01..bd3336b 100644 --- a/checkdmarc/bimi.py +++ b/checkdmarc/bimi.py @@ -650,6 +650,7 @@ def _query_bimi_record( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ): """ Queries DNS for a BIMI record @@ -661,6 +662,7 @@ def _query_bimi_record( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for a record from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: str: A record string or None @@ -674,7 +676,12 @@ def _query_bimi_record( try: records = query_dns( - target, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + target, + "TXT", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) for record in records: if record.startswith(txt_prefix): @@ -730,6 +737,7 @@ def query_bimi_record( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Queries DNS for a BIMI record @@ -741,6 +749,7 @@ def query_bimi_record( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for a record from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: OrderedDict: An ``OrderedDict`` with the following keys: @@ -765,10 +774,16 @@ def query_bimi_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) try: root_records = query_dns( - domain, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + domain, + "TXT", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) for root_record in root_records: if root_record.startswith("v=BIMI1"): @@ -780,7 +795,11 @@ def query_bimi_record( if record is None and domain != base_domain: record = _query_bimi_record( - base_domain, nameservers=nameservers, resolver=resolver, timeout=timeout + base_domain, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) location = base_domain if record is None: @@ -993,6 +1012,7 @@ def check_bimi( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Returns a dictionary with a parsed BIMI record or an error. @@ -1012,6 +1032,7 @@ def check_bimi( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: OrderedDict: An ``OrderedDict`` with the following keys: diff --git a/checkdmarc/dmarc.py b/checkdmarc/dmarc.py index f414dd9..dfde744 100644 --- a/checkdmarc/dmarc.py +++ b/checkdmarc/dmarc.py @@ -399,6 +399,7 @@ def _query_dmarc_record( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ignore_unrelated_records: bool = False, ) -> Union[str, None]: """ @@ -410,6 +411,8 @@ def _query_dmarc_record( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for a record from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout + ignore_unrelated_records (bool): Do not raise a warning if unrelated records are found Returns: str: A record string or None @@ -423,7 +426,12 @@ def _query_dmarc_record( try: records = query_dns( - target, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + target, + "TXT", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) for record in records: if record.startswith(txt_prefix): @@ -462,6 +470,7 @@ def _query_dmarc_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) for record in records: if record.startswith(txt_prefix): @@ -495,6 +504,7 @@ def query_dmarc_record( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ignore_unrelated_records: bool = False, ) -> OrderedDict: """ @@ -506,6 +516,7 @@ def query_dmarc_record( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for a record from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout ignore_unrelated_records (bool): Ignore unrelated TXT records Returns: @@ -532,6 +543,7 @@ def query_dmarc_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ignore_unrelated_records=ignore_unrelated_records, ) except DMARCRecordNotFound: @@ -558,6 +570,7 @@ def query_dmarc_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ignore_unrelated_records=ignore_unrelated_records, ) location = base_domain @@ -668,6 +681,7 @@ def check_wildcard_dmarc_report_authorization( ignore_unrelated_records: bool = False, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> bool: """ Checks for a wildcard DMARC report authorization record, e.g.: @@ -683,6 +697,8 @@ def check_wildcard_dmarc_report_authorization( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout + Returns: bool: An indicator of the existence of a valid wildcard DMARC report @@ -698,6 +714,7 @@ def check_wildcard_dmarc_report_authorization( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) for record in records: @@ -732,6 +749,7 @@ def verify_dmarc_report_destination( ignore_unrelated_records: bool = False, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> None: """ Checks if the report destination accepts reports for the source domain @@ -746,6 +764,8 @@ def verify_dmarc_report_destination( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout + Raises: :exc:`checkdmarc.dmarc.UnverifiedDMARCURIDestination` @@ -761,6 +781,8 @@ def verify_dmarc_report_destination( nameservers=nameservers, ignore_unrelated_records=ignore_unrelated_records, resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ): return target = f"{source_domain}._report._dmarc.{destination_domain}" @@ -780,6 +802,7 @@ def verify_dmarc_report_destination( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) for record in records: @@ -814,6 +837,7 @@ def parse_dmarc_record( ignore_unrelated_records: bool = False, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, syntax_error_marker: str = SYNTAX_ERROR_MARKER, ) -> OrderedDict: """ @@ -829,6 +853,7 @@ def parse_dmarc_record( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout syntax_error_marker (str): The maker for pointing out syntax errors Returns: @@ -984,6 +1009,7 @@ def parse_dmarc_record( ignore_unrelated_records=ignore_unrelated_records, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) try: hosts = get_mx_records( @@ -991,6 +1017,7 @@ def parse_dmarc_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) if len(hosts) == 0: raise DMARCReportEmailAddressMissingMXRecords( @@ -1045,6 +1072,7 @@ def parse_dmarc_record( ignore_unrelated_records=ignore_unrelated_records, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) try: hosts = get_mx_records( @@ -1052,6 +1080,7 @@ def parse_dmarc_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) if len(hosts) == 0: raise DMARCReportEmailAddressMissingMXRecords( @@ -1119,6 +1148,7 @@ def get_dmarc_record( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Retrieves a DMARC record for a domain and parses it @@ -1130,6 +1160,7 @@ def get_dmarc_record( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: OrderedDict: An ``OrderedDict`` with the following keys: @@ -1164,6 +1195,7 @@ def get_dmarc_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) return OrderedDict( @@ -1180,6 +1212,7 @@ def check_dmarc( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Returns a dictionary with a parsed DMARC record or an error @@ -1193,6 +1226,8 @@ def check_dmarc( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for a record from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout + Returns: OrderedDict: An ``OrderedDict`` with the following keys: @@ -1216,6 +1251,7 @@ def check_dmarc( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) dmarc_results["record"] = dmarc_query["record"] dmarc_results["location"] = dmarc_query["location"] @@ -1228,6 +1264,7 @@ def check_dmarc( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) dmarc_results["warnings"] = dmarc_query["warnings"] diff --git a/checkdmarc/mta_sts.py b/checkdmarc/mta_sts.py index aa31ca4..d037a02 100644 --- a/checkdmarc/mta_sts.py +++ b/checkdmarc/mta_sts.py @@ -145,6 +145,7 @@ def query_mta_sts_record( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Queries DNS for an MTA-STS record @@ -155,6 +156,8 @@ def query_mta_sts_record( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for a record from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout + Returns: OrderedDict: An ``OrderedDict`` with the following keys: @@ -178,7 +181,12 @@ def query_mta_sts_record( try: records = query_dns( - target, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + target, + "TXT", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) for record in records: if record.startswith(txt_prefix): @@ -206,6 +214,7 @@ def query_mta_sts_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) for record in records: if record.startswith(txt_prefix): @@ -433,6 +442,7 @@ def check_mta_sts( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Returns a dictionary with a parsed MTA-STS policy or an error. @@ -443,6 +453,8 @@ def check_mta_sts( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout + Returns: OrderedDict: An ``OrderedDict`` with the following keys: @@ -462,7 +474,11 @@ def check_mta_sts( mta_sts_results = OrderedDict([("valid", True)]) try: mta_sts_record = query_mta_sts_record( - domain, nameservers=nameservers, resolver=resolver, timeout=timeout + domain, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) warnings = mta_sts_record["warnings"] mta_sts_record = parse_mta_sts_record(mta_sts_record["record"]) diff --git a/checkdmarc/smtp.py b/checkdmarc/smtp.py index 892f4cb..fa5276b 100644 --- a/checkdmarc/smtp.py +++ b/checkdmarc/smtp.py @@ -334,6 +334,7 @@ def get_mx_hosts( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ): """ Gets MX hostname and their addresses @@ -367,7 +368,11 @@ def get_mx_hosts( dupe_hostnames = set() logging.debug(f"Getting MX records for {domain}") mx_records = get_mx_records( - domain, nameservers=nameservers, resolver=resolver, timeout=timeout + domain, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) for record in mx_records: hosts.append( @@ -407,16 +412,28 @@ def get_mx_hosts( try: dnssec = False try: - dnssec = test_dnssec(hostname, nameservers=nameservers, timeout=timeout) + dnssec = test_dnssec( + hostname, + nameservers=nameservers, + timeout=timeout, + timeout_retries=timeout_retries, + ) except Exception as e: logging.debug(e) host["dnssec"] = dnssec host["addresses"] = [] host["addresses"] = get_a_records( - hostname, nameservers=nameservers, resolver=resolver, timeout=timeout + hostname, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) tlsa_records = get_tlsa_records( - hostname, nameservers=nameservers, timeout=timeout + hostname, + nameservers=nameservers, + timeout=timeout, + timeout_retries=timeout_retries, ) if len(tlsa_records) > 0: @@ -436,7 +453,11 @@ def get_mx_hosts( for address in host["addresses"]: try: reverse_hostnames = get_reverse_dns( - address, nameservers=nameservers, resolver=resolver, timeout=timeout + address, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) except DNSException: reverse_hostnames = [] @@ -446,7 +467,12 @@ def get_mx_hosts( ) for reverse_hostname in reverse_hostnames: try: - _addresses = get_a_records(reverse_hostname, resolver=resolver) + _addresses = get_a_records( + reverse_hostname, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, + ) except DNSException as warning: warnings.append(str(warning)) _addresses = [] @@ -501,6 +527,7 @@ def check_mx( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Gets MX hostname and their addresses, or an empty list of hosts and an @@ -515,6 +542,7 @@ def check_mx( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for a record from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: OrderedDict: An ``OrderedDict`` with the following keys: @@ -541,6 +569,7 @@ def check_mx( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) except DNSException as error: mx_results = OrderedDict([("hosts", []), ("error", str(error))]) diff --git a/checkdmarc/smtp_tls_reporting.py b/checkdmarc/smtp_tls_reporting.py index 09d44f6..070d396 100644 --- a/checkdmarc/smtp_tls_reporting.py +++ b/checkdmarc/smtp_tls_reporting.py @@ -140,6 +140,7 @@ def query_smtp_tls_reporting_record( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Queries DNS for an SMTP TLS Reporting record @@ -150,6 +151,7 @@ def query_smtp_tls_reporting_record( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for a record from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: OrderedDict: An ``OrderedDict`` with the following keys: @@ -173,7 +175,12 @@ def query_smtp_tls_reporting_record( try: records = query_dns( - target, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + target, + "TXT", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) for record in records: if record.startswith(txt_prefix): @@ -203,6 +210,7 @@ def query_smtp_tls_reporting_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) for record in records: if record.startswith(txt_prefix): @@ -323,6 +331,7 @@ def check_smtp_tls_reporting( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Returns a dictionary with a parsed SMTP-TLS Reporting policy or an error. @@ -333,6 +342,7 @@ def check_smtp_tls_reporting( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: OrderedDict: An ``OrderedDict`` with the following keys: @@ -351,7 +361,11 @@ def check_smtp_tls_reporting( smtp_tls_reporting_results = OrderedDict([("valid", True)]) try: smtp_tls_reporting_record = query_smtp_tls_reporting_record( - domain, nameservers=nameservers, resolver=resolver, timeout=timeout + domain, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) warnings = smtp_tls_reporting_record["warnings"] smtp_tls_reporting_record = parse_smtp_tls_reporting_record( diff --git a/checkdmarc/soa.py b/checkdmarc/soa.py index 65e5b39..dabd2ae 100644 --- a/checkdmarc/soa.py +++ b/checkdmarc/soa.py @@ -71,6 +71,7 @@ def check_soa( nameservers: List[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Returns a dictionary of a domain's SOA record and a parsed version of the record or a dictionary with an @@ -82,6 +83,7 @@ def check_soa( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for a record from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: OrderedDict: A dictionary with the following keys: @@ -96,7 +98,11 @@ def check_soa( """ try: record = get_soa_record( - domain, nameservers=nameservers, resolver=resolver, timeout=timeout + domain, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) results = OrderedDict([("record", record)]) except Exception as e: diff --git a/checkdmarc/spf.py b/checkdmarc/spf.py index 13fa8ba..80d962e 100644 --- a/checkdmarc/spf.py +++ b/checkdmarc/spf.py @@ -140,6 +140,7 @@ def query_spf_record( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Queries DNS for an SPF record @@ -149,6 +150,7 @@ def query_spf_record( nameservers (list): A list of nameservers to query resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: OrderedDict: An ``OrderedDict`` with the following keys: @@ -166,7 +168,12 @@ def query_spf_record( spf_txt_records = [] try: spf_type_records += query_dns( - domain, "SPF", nameservers=nameservers, resolver=resolver, timeout=timeout + domain, + "SPF", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) except (dns.resolver.NoAnswer, Exception): pass @@ -181,7 +188,12 @@ def query_spf_record( warnings.append(message) try: answers = query_dns( - domain, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + domain, + "TXT", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) spf_record = None for record in answers: @@ -272,6 +284,7 @@ def parse_spf_record( resolver: dns.resolver.Resolver = None, recursion: OrderedDict = None, timeout: float = 2.0, + timeout_retries: int = 2, syntax_error_marker: str = SYNTAX_ERROR_MARKER, ) -> OrderedDict: """ @@ -287,6 +300,7 @@ def parse_spf_record( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests recursion (OrderedDict): Results from a previous call timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout syntax_error_marker (str): The maker for pointing out syntax errors Returns: @@ -403,6 +417,7 @@ def parse_spf_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) if len(a_records) == 0: mechanism_void_dns_lookups += 1 @@ -436,6 +451,7 @@ def parse_spf_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) if len(mx_hosts) == 0: @@ -465,6 +481,7 @@ def parse_spf_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) host_ips[hostname] = _addresses @@ -542,6 +559,7 @@ def parse_spf_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) redirect_record = redirect_record["record"] redirect = parse_spf_record( @@ -552,6 +570,7 @@ def parse_spf_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) parsed["all"] = redirect["parsed"]["all"] mechanism_dns_lookups += redirect["dns_lookups"] @@ -599,6 +618,7 @@ def parse_spf_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) parsed["exp"] = txts[0] if txts else None @@ -633,6 +653,7 @@ def parse_spf_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) include_record = include_record["record"] include = parse_spf_record( @@ -643,6 +664,7 @@ def parse_spf_record( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) total_dns_lookups += include["dns_lookups"] total_void_dns_lookups += include["void_dns_lookups"] @@ -773,6 +795,7 @@ def get_spf_record( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Retrieves and parses an SPF record @@ -782,6 +805,7 @@ def get_spf_record( nameservers (list): A list of nameservers to query resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): Number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: OrderedDict: An SPF record parsed by result @@ -812,6 +836,7 @@ def check_spf( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> OrderedDict: """ Returns a dictionary with a parsed SPF record or an error. @@ -822,6 +847,7 @@ def check_spf( nameservers (list): A list of nameservers to query resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: OrderedDict: An ``OrderedDict`` with the following keys: @@ -847,7 +873,11 @@ def check_spf( ) try: spf_query = query_spf_record( - domain, nameservers=nameservers, resolver=resolver, timeout=timeout + domain, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) spf_results["record"] = spf_query["record"] spf_results["warnings"] = spf_query["warnings"] @@ -860,6 +890,7 @@ def check_spf( nameservers=nameservers, resolver=resolver, timeout=timeout, + timeout_retries=timeout_retries, ) spf_results["dns_lookups"] = parsed_spf["dns_lookups"] diff --git a/checkdmarc/utils.py b/checkdmarc/utils.py index 595f1c2..9095cb1 100644 --- a/checkdmarc/utils.py +++ b/checkdmarc/utils.py @@ -100,6 +100,8 @@ def query_dns( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, + _attempt: int = 0, cache: ExpiringDict = None, ) -> list[str]: """ @@ -112,6 +114,7 @@ def query_dns( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): Sets the DNS timeout in seconds + timeout_retries (int): The number of times to reattempt a query after a timeout cache (ExpiringDict): Cache storage Returns: @@ -134,10 +137,25 @@ def query_dns( resolver.timeout = timeout resolver.lifetime = timeout if record_type == "TXT": + try: + answers = resolver.resolve(domain, record_type, lifetime=timeout) + except dns.resolver.LifetimeTimeout as e: + _attempt += 1 + if _attempt > timeout_retries: + raise e + return query_dns( + domain, + record_type, + nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, + _attempt=_attempt, + ) resource_records = list( map( lambda r: r.strings, - resolver.resolve(domain, record_type, lifetime=timeout), + answers, ) ) _resource_record = [ @@ -153,10 +171,25 @@ def query_dns( r = "Undecodable characters" records.append(r) else: + try: + answers = resolver.resolve(domain, record_type, lifetime=timeout) + except dns.resolver.LifetimeTimeout as e: + _attempt += 1 + if _attempt > timeout_retries: + raise e + return query_dns( + domain, + record_type, + nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, + _attempt=_attempt, + ) records = list( map( lambda r: r.to_text().replace('"', "").rstrip("."), - resolver.resolve(domain, record_type, lifetime=timeout), + answers, ) ) if type(cache) is ExpiringDict: @@ -171,6 +204,7 @@ def get_a_records( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> list[str]: """ Queries DNS for A and AAAA records @@ -194,7 +228,12 @@ def get_a_records( try: logging.debug(f"Getting {qt} records for {domain}") addresses += query_dns( - domain, qt, nameservers=nameservers, resolver=resolver, timeout=timeout + domain, + qt, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) except dns.resolver.NXDOMAIN: raise DNSExceptionNXDOMAIN("The domain does not exist.") @@ -214,6 +253,7 @@ def get_reverse_dns( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> list[str]: """ Queries for an IP addresses reverse DNS hostname(s) @@ -224,6 +264,7 @@ def get_reverse_dns( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: list: A list of reverse DNS hostnames @@ -236,7 +277,12 @@ def get_reverse_dns( name = str(dns.reversename.from_address(ip_address)) logging.debug(f"Getting PTR records for {ip_address}") hostnames = query_dns( - name, "PTR", nameservers=nameservers, resolver=resolver, timeout=timeout + name, + "PTR", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) except dns.resolver.NXDOMAIN: return [] @@ -252,6 +298,7 @@ def get_txt_records( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> list[str]: """ Queries DNS for TXT records @@ -262,6 +309,7 @@ def get_txt_records( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: list: A list of TXT records @@ -290,6 +338,7 @@ def get_soa_record( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> list[str]: """ Queries DNS for an SOA record @@ -300,6 +349,7 @@ def get_soa_record( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: str: An SOA record @@ -311,7 +361,12 @@ def get_soa_record( domain = get_base_domain(domain) try: record = query_dns( - domain, "SOA", nameservers=nameservers, resolver=resolver, timeout=timeout + domain, + "SOA", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, )[0] except dns.resolver.NXDOMAIN: raise DNSExceptionNXDOMAIN("The domain does not exist.") @@ -330,6 +385,7 @@ def get_nameservers( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> dict: """ Gets a list of nameservers for a given domain @@ -341,6 +397,7 @@ def get_nameservers( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for a record from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: OrderedDict: A dictionary with the following keys: @@ -353,7 +410,12 @@ def get_nameservers( ns_records = [] try: ns_records = query_dns( - domain, "NS", nameservers=nameservers, resolver=resolver, timeout=timeout + domain, + "NS", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + timeout_retries=timeout_retries, ) except dns.resolver.NXDOMAIN: raise DNSExceptionNXDOMAIN("The domain does not exist.") @@ -383,6 +445,7 @@ def get_mx_records( nameservers: list[str] = None, resolver: dns.resolver.Resolver = None, timeout: float = 2.0, + timeout_retries: int = 2, ) -> list[OrderedDict]: """ Queries DNS for a list of Mail Exchange hosts @@ -393,6 +456,7 @@ def get_mx_records( resolver (dns.resolver.Resolver): A resolver object to use for DNS requests timeout (float): number of seconds to wait for an answer from DNS + timeout_retries (int): The number of times to reattempt a query after a timeout Returns: list: A list of ``OrderedDicts``; each containing a ``preference``