Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions domaintools/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime, timedelta, timezone
from hashlib import sha1, sha256, md5
from hmac import new as hmac

import re

from domaintools.constants import Endpoint, ENDPOINT_TO_SOURCE_MAP, FEEDS_PRODUCTS_LIST, OutputFormat
Expand All @@ -11,6 +12,7 @@
ParsedDomainRdap,
Reputation,
Results,
FeedsResults,
)
from domaintools.filters import (
filter_by_riskscore,
Expand Down Expand Up @@ -1065,7 +1067,7 @@ def iris_detect_ignored_domains(
**kwargs,
)

def nod(self, **kwargs):
def nod(self, **kwargs) -> FeedsResults:
"""Returns back list of the newly observed domains feed"""
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
Expand All @@ -1078,10 +1080,11 @@ def nod(self, **kwargs):
f"newly-observed-domains-feed-({source.value})",
f"v1/{endpoint}/nod/",
response_path=(),
cls=FeedsResults,
**kwargs,
)

def nad(self, **kwargs):
def nad(self, **kwargs) -> FeedsResults:
"""Returns back list of the newly active domains feed"""
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
Expand All @@ -1094,10 +1097,11 @@ def nad(self, **kwargs):
f"newly-active-domains-feed-({source})",
f"v1/{endpoint}/nad/",
response_path=(),
cls=FeedsResults,
**kwargs,
)

def domainrdap(self, **kwargs):
def domainrdap(self, **kwargs) -> FeedsResults:
"""Returns changes to global domain registration information, populated by the Registration Data Access Protocol (RDAP)"""
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
Expand All @@ -1107,10 +1111,11 @@ def domainrdap(self, **kwargs):
f"domain-registration-data-access-protocol-feed-({source})",
f"v1/{endpoint}/domainrdap/",
response_path=(),
cls=FeedsResults,
**kwargs,
)

def domaindiscovery(self, **kwargs):
def domaindiscovery(self, **kwargs) -> FeedsResults:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

"""Returns new domains as they are either discovered in domain registration information, observed by our global sensor network, or reported by trusted third parties"""
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
Expand All @@ -1123,5 +1128,6 @@ def domaindiscovery(self, **kwargs):
f"real-time-domain-discovery-feed-({source})",
f"v1/{endpoint}/domaindiscovery/",
response_path=(),
cls=FeedsResults,
**kwargs,
)
80 changes: 20 additions & 60 deletions domaintools/base_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ def __init__(
self._response = None
self._items_list = None
self._data = None
self._limit_exceeded = None
self._limit_exceeded_message = None

def _wait_time(self):
if not self.api.rate_limit or not self.product in self.api.limits:
Expand All @@ -77,29 +75,6 @@ def _wait_time(self):

return wait_for

def _get_feeds_results_generator(self, parameters, headers):
with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
status_code = None
while status_code != 200:
resp_data = session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
status_code = resp_data.status_code
self.setStatus(status_code, resp_data)

# Check limit exceeded here
if "response" in resp_data.text and "limit_exceeded" in resp_data.text:
self._limit_exceeded = True
self._limit_exceeded_message = "limit exceeded"

yield resp_data

if self._limit_exceeded:
raise ServiceException(503, "Limit Exceeded{}".format(self._limit_exceeded_message))

if not self.kwargs.get("sessionID"):
# we'll only do iterative request for queries that has sessionID.
# Otherwise, we will have an infinite request if sessionID was not provided but the required data asked is more than the maximum (1 hour of data)
break

def _get_session_params(self):
parameters = deepcopy(self.kwargs)
parameters.pop("output_format", None)
Expand All @@ -118,12 +93,6 @@ def _get_session_params(self):
return {"parameters": parameters, "headers": headers}

def _make_request(self):
if self.product in FEEDS_PRODUCTS_LIST:
session_params = self._get_session_params()
parameters = session_params.get("parameters")
headers = session_params.get("headers")

return self._get_feeds_results_generator(parameters=parameters, headers=headers)

with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
if self.product in [
Expand All @@ -138,15 +107,19 @@ def _make_request(self):
patch_data = self.kwargs.copy()
patch_data.update(self.api.extra_request_params)
return session.patch(url=self.url, json=patch_data)
elif self.product in FEEDS_PRODUCTS_LIST:
session_params = self._get_session_params()
parameters = session_params.get("parameters")
headers = session_params.get("headers")
return session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
else:
return session.get(url=self.url, params=self.kwargs, **self.api.extra_request_params)

def _get_results(self):
wait_for = self._wait_time()
if self.api.rate_limit and (wait_for is None or self.product == "account-information"):
data = self._make_request()
status_code = data.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
if status_code == 503: # pragma: no cover
if data.status_code == 503: # pragma: no cover
sleeptime = 60
log.info(
"503 encountered for [%s] - sleeping [%s] seconds before retrying request.",
Expand All @@ -166,40 +139,27 @@ def _get_results(self):
def data(self):
if self._data is None:
results = self._get_results()
status_code = results.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
self.setStatus(status_code, results)
if (
self.kwargs.get("format", "json") == "json"
and self.product not in FEEDS_PRODUCTS_LIST # Special handling of feeds products' data to preserve the result in jsonline format
):
self.setStatus(results.status_code, results)
if self.kwargs.get("format", "json") == "json":
self._data = results.json()
elif self.product in FEEDS_PRODUCTS_LIST:
self._data = results # Uses generator to handle large data results from feeds endpoint
else:
self._data = results.text
limit_exceeded, message = self.check_limit_exceeded()

if limit_exceeded:
self._limit_exceeded = True
self._limit_exceeded_message = message
self.check_limit_exceeded()

if self._limit_exceeded is True:
raise ServiceException(503, "Limit Exceeded{}".format(self._limit_exceeded_message))
else:
return self._data
return self._data

def check_limit_exceeded(self):
if self.product in FEEDS_PRODUCTS_LIST:
# bypass here as this is handled in generator already
return False, ""

if self.kwargs.get("format", "json") == "json" and self.product not in FEEDS_PRODUCTS_LIST:
if "response" in self._data and "limit_exceeded" in self._data["response"] and self._data["response"]["limit_exceeded"] is True:
return True, self._data["response"]["message"]
# TODO: handle html, xml response errors better.
limit_exceeded, reason = False, ""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a pedantic change:

can we put the Limit Exceeded as default reason? so we dont need to define it in elif and just concat the message in the if statement.

Overall, LGTM ✅ . Nice!!

if isinstance(self._data, dict) and (
"response" in self._data and "limit_exceeded" in self._data["response"] and self._data["response"]["limit_exceeded"] is True
):
limit_exceeded, reason = True, f"Limit Exceeded: {self._data['response']['message']}"
elif "response" in self._data and "limit_exceeded" in self._data:
return True, "limit exceeded"
return False, ""
limit_exceeded, reason = True, "Limit Exceeded"

if limit_exceeded:
raise ServiceException(503, reason)

@property
def status(self):
Expand Down Expand Up @@ -249,7 +209,7 @@ def response(self):
return self._response

def items(self):
return self.response().items() if isinstance(self.response(), dict) else self.response()
return self.response().items()

def emails(self):
"""Find and returns all emails mentioned in the response"""
Expand Down
5 changes: 2 additions & 3 deletions domaintools/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rich.progress import Progress, SpinnerColumn, TextColumn

from domaintools.api import API
from domaintools.constants import Endpoint, OutputFormat, FEEDS_PRODUCTS_LIST
from domaintools.constants import Endpoint, FEEDS_PRODUCTS_LIST, OutputFormat
from domaintools.cli.utils import get_file_extension
from domaintools.exceptions import ServiceException
from domaintools._version import current as version
Expand Down Expand Up @@ -111,8 +111,7 @@ def _get_formatted_output(cls, cmd_name: str, response, out_format: str = "json"
if cmd_name in ("available_api_calls",):
return "\n".join(response)
if response.product in FEEDS_PRODUCTS_LIST:
return "\n".join([data.text for data in response])

return "\n".join([data for data in response.response()])
return str(getattr(response, out_format) if out_format != "list" else response.as_list())

@classmethod
Expand Down
27 changes: 27 additions & 0 deletions domaintools/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from ordereddict import OrderedDict

from itertools import zip_longest
from typing import Generator

from domaintools_async import AsyncResults as Results


Expand Down Expand Up @@ -141,3 +143,28 @@ def flattened(self):
flat[f"contact_{contact_key}_{i}"] = " | ".join(contact_value) if type(contact_value) in (list, tuple) else contact_value

return flat


class FeedsResults(Results):
    """Results wrapper for feed endpoints; iterates raw response pages lazily."""

    def response(self) -> Generator:
        # Page through the feed until the API reports a complete (200)
        # response, yielding each raw text payload as it arrives.
        while True:
            page = self.data()
            # Capture completion state before handing control to the consumer.
            finished = self.status == 200
            yield page

            # Drop the cached payload so the next iteration refetches.
            self._data = None
            if finished or not self.kwargs.get("sessionID"):
                # Iterative paging is only done for queries carrying a
                # sessionID. Otherwise we would request forever when the data
                # asked for exceeds the maximum window (1 hour of data).
                break

    def data(self):
        # Fetch one page, record its HTTP status, keep the body as raw text
        # (jsonlines), and raise if the account limit was exceeded.
        fetched = self._get_results()
        self.setStatus(fetched.status_code, fetched)
        self._data = fetched.text
        self.check_limit_exceeded()

        return self._data
44 changes: 10 additions & 34 deletions domaintools_async/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import asyncio

from copy import deepcopy
from httpx import AsyncClient

from domaintools.base_results import Results
from domaintools.constants import FEEDS_PRODUCTS_LIST
from domaintools.exceptions import ServiceUnavailableException, ServiceException
from domaintools.constants import FEEDS_PRODUCTS_LIST, OutputFormat, HEADER_ACCEPT_KEY_CSV_FORMAT
from domaintools.exceptions import ServiceUnavailableException


class _AIter(object):
Expand Down Expand Up @@ -41,26 +42,6 @@ class AsyncResults(Results):
def __await__(self):
return self.__awaitable__().__await__()

async def _get_feeds_async_results_generator(self, session, parameters, headers):
status_code = None
while status_code != 200:
resp_data = await session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
status_code = resp_data.status_code
self.setStatus(status_code, resp_data)

# Check limit exceeded here
if "response" in resp_data.text and "limit_exceeded" in resp_data.text:
self._limit_exceeded = True
self._limit_exceeded_message = "limit exceeded"
yield resp_data

if self._limit_exceeded:
raise ServiceException(503, "Limit Exceeded{}".format(self._limit_exceeded_message))
if not self.kwargs.get("sessionID"):
# we'll only do iterative request for queries that has sessionID.
# Otherwise, we will have an infinite request if sessionID was not provided but the required data asked is more than the maximum (1 hour of data)
break

async def _make_async_request(self, session):
if self.product in ["iris-investigate", "iris-enrich", "iris-detect-escalate-domains"]:
post_data = self.kwargs.copy()
Expand All @@ -71,29 +52,24 @@ async def _make_async_request(self, session):
patch_data.update(self.api.extra_request_params)
results = await session.patch(url=self.url, json=patch_data)
elif self.product in FEEDS_PRODUCTS_LIST:
generator_params = self._get_session_params()
parameters = generator_params.get("parameters")
headers = generator_params.get("headers")
results = await self._get_feeds_async_results_generator(session=session, parameters=parameters, headers=headers)
session_params = self._get_session_params()
parameters = session_params.get("parameters")
headers = session_params.get("headers")
results = await session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
else:
results = await session.get(url=self.url, params=self.kwargs, **self.api.extra_request_params)
if results:
status_code = results.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
self.setStatus(status_code, results)
self.setStatus(results.status_code, results)
if self.kwargs.get("format", "json") == "json":
self._data = results.json()
elif self.product in FEEDS_PRODUCTS_LIST:
self._data = results # Uses generator to handle large data results from feeds endpoint
else:
self._data = results.text()
limit_exceeded, message = self.check_limit_exceeded()

if limit_exceeded:
self._limit_exceeded = True
self._limit_exceeded_message = message
self.check_limit_exceeded()

async def __awaitable__(self):
if self._data is None:

async with AsyncClient(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
wait_time = self._wait_time()
if wait_time is None and self.api:
Expand Down
Loading
Loading