Merged

Changes from 7 commits
14 changes: 10 additions & 4 deletions domaintools/api.py
@@ -1,6 +1,7 @@
from datetime import datetime, timedelta, timezone
from hashlib import sha1, sha256, md5
from hmac import new as hmac

import re

from domaintools.constants import Endpoint, ENDPOINT_TO_SOURCE_MAP, FEEDS_PRODUCTS_LIST, OutputFormat
@@ -11,6 +12,7 @@
ParsedDomainRdap,
Reputation,
Results,
FeedsResults,
)
from domaintools.filters import (
filter_by_riskscore,
@@ -1065,7 +1067,7 @@ def iris_detect_ignored_domains(
**kwargs,
)

def nod(self, **kwargs):
def nod(self, **kwargs) -> FeedsResults:
"""Returns back list of the newly observed domains feed"""
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
@@ -1078,10 +1080,11 @@ def nod(self, **kwargs):
f"newly-observed-domains-feed-({source.value})",
f"v1/{endpoint}/nod/",
response_path=(),
cls=FeedsResults,
**kwargs,
)

def nad(self, **kwargs):
def nad(self, **kwargs) -> FeedsResults:
"""Returns back list of the newly active domains feed"""
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
@@ -1094,10 +1097,11 @@ def nad(self, **kwargs):
f"newly-active-domains-feed-({source})",
f"v1/{endpoint}/nad/",
response_path=(),
cls=FeedsResults,
**kwargs,
)

def domainrdap(self, **kwargs):
def domainrdap(self, **kwargs) -> FeedsResults:
"""Returns changes to global domain registration information, populated by the Registration Data Access Protocol (RDAP)"""
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
@@ -1107,10 +1111,11 @@ def domainrdap(self, **kwargs):
f"domain-registration-data-access-protocol-feed-({source})",
f"v1/{endpoint}/domainrdap/",
response_path=(),
cls=FeedsResults,
**kwargs,
)

def domaindiscovery(self, **kwargs):
def domaindiscovery(self, **kwargs) -> FeedsResults:
Contributor: nice!

"""Returns new domains as they are either discovered in domain registration information, observed by our global sensor network, or reported by trusted third parties"""
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
@@ -1123,5 +1128,6 @@ def domaindiscovery(self, **kwargs):
f"real-time-domain-discovery-feed-({source})",
f"v1/{endpoint}/domaindiscovery/",
response_path=(),
cls=FeedsResults,
**kwargs,
)
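
For orientation, here is a minimal usage sketch of the new return type. The credentials and the sessionID value are hypothetical, and nod() stands in for any of the four feed methods; the point is that each now returns a FeedsResults whose response() is iterable:

    # Minimal usage sketch; credentials and parameters are hypothetical.
    from domaintools.api import API

    api = API("my_username", "my_api_key")

    # nod() is now annotated to return FeedsResults rather than a plain Results.
    results = api.nod(sessionID="my-feed-session")

    # response() yields one chunk of JSON Lines text per underlying HTTP request.
    for chunk in results.response():
        print(chunk)
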
49 changes: 8 additions & 41 deletions domaintools/base_results.py
@@ -77,29 +77,6 @@ def _wait_time(self):

return wait_for

def _get_feeds_results_generator(self, parameters, headers):
with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
status_code = None
while status_code != 200:
resp_data = session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
status_code = resp_data.status_code
self.setStatus(status_code, resp_data)

# Check limit exceeded here
if "response" in resp_data.text and "limit_exceeded" in resp_data.text:
self._limit_exceeded = True
self._limit_exceeded_message = "limit exceeded"

yield resp_data

if self._limit_exceeded:
raise ServiceException(503, "Limit Exceeded{}".format(self._limit_exceeded_message))

if not self.kwargs.get("sessionID"):
# We only make iterative requests for queries that have a sessionID.
# Otherwise, requests would repeat indefinitely when no sessionID is provided but the data requested exceeds the maximum (one hour of data).
break

def _get_session_params(self):
parameters = deepcopy(self.kwargs)
parameters.pop("output_format", None)
@@ -118,12 +95,6 @@ def _get_session_params(self):
return {"parameters": parameters, "headers": headers}

def _make_request(self):
if self.product in FEEDS_PRODUCTS_LIST:
session_params = self._get_session_params()
parameters = session_params.get("parameters")
headers = session_params.get("headers")

return self._get_feeds_results_generator(parameters=parameters, headers=headers)

with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
if self.product in [
@@ -138,15 +109,19 @@ def _make_request(self):
patch_data = self.kwargs.copy()
patch_data.update(self.api.extra_request_params)
return session.patch(url=self.url, json=patch_data)
elif self.product in FEEDS_PRODUCTS_LIST:
session_params = self._get_session_params()
parameters = session_params.get("parameters")
headers = session_params.get("headers")
return session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
else:
return session.get(url=self.url, params=self.kwargs, **self.api.extra_request_params)

def _get_results(self):
wait_for = self._wait_time()
if self.api.rate_limit and (wait_for is None or self.product == "account-information"):
data = self._make_request()
status_code = data.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
if status_code == 503: # pragma: no cover
if data.status_code == 503: # pragma: no cover
sleeptime = 60
log.info(
"503 encountered for [%s] - sleeping [%s] seconds before retrying request.",
@@ -166,15 +141,12 @@ def _get_results(self):
def data(self):
if self._data is None:
results = self._get_results()
status_code = results.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
self.setStatus(status_code, results)
self.setStatus(results.status_code, results)
if (
self.kwargs.get("format", "json") == "json"
and self.product not in FEEDS_PRODUCTS_LIST # Special handling of feeds products' data to preserve the result in jsonline format
):
self._data = results.json()
elif self.product in FEEDS_PRODUCTS_LIST:
self._data = results # Uses generator to handle large data results from feeds endpoint
else:
self._data = results.text
limit_exceeded, message = self.check_limit_exceeded()
@@ -189,14 +161,9 @@ def data(self):
return self._data

def check_limit_exceeded(self):
if self.product in FEEDS_PRODUCTS_LIST:
# bypass here as this is handled in generator already
return False, ""

if self.kwargs.get("format", "json") == "json" and self.product not in FEEDS_PRODUCTS_LIST:
if "response" in self._data and "limit_exceeded" in self._data["response"] and self._data["response"]["limit_exceeded"] is True:
return True, self._data["response"]["message"]
# TODO: handle html, xml response errors better.
elif "response" in self._data and "limit_exceeded" in self._data:
return True, "limit exceeded"
return False, ""
@@ -249,7 +216,7 @@ def response(self):
return self._response

def items(self):
return self.response().items() if isinstance(self.response(), dict) else self.response()
return self.response().items()

def emails(self):
"""Find and returns all emails mentioned in the response"""
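
For context on the simplified check_limit_exceeded(), the snippet below spells out the JSON shape the method inspects. The payload is a hand-written example, not a captured API response:

    # Hand-written example of the payload shape check_limit_exceeded() looks for.
    data = {"response": {"limit_exceeded": True, "message": "Monthly limit exceeded"}}

    limit_exceeded = (
        "response" in data
        and "limit_exceeded" in data["response"]
        and data["response"]["limit_exceeded"] is True
    )
    message = data["response"]["message"] if limit_exceeded else ""
    print(limit_exceeded, message)  # True Monthly limit exceeded
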
5 changes: 2 additions & 3 deletions domaintools/cli/api.py
@@ -9,7 +9,7 @@
from rich.progress import Progress, SpinnerColumn, TextColumn

from domaintools.api import API
from domaintools.constants import Endpoint, OutputFormat, FEEDS_PRODUCTS_LIST
from domaintools.constants import Endpoint, FEEDS_PRODUCTS_LIST, OutputFormat
from domaintools.cli.utils import get_file_extension
from domaintools.exceptions import ServiceException
from domaintools._version import current as version
@@ -111,8 +111,7 @@ def _get_formatted_output(cls, cmd_name: str, response, out_format: str = "json"
if cmd_name in ("available_api_calls",):
return "\n".join(response)
if response.product in FEEDS_PRODUCTS_LIST:
return "\n".join([data.text for data in response])

return "\n".join([data for data in response.response()])
return str(getattr(response, out_format) if out_format != "list" else response.as_list())

@classmethod
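
Condensing the change above: feed output is now assembled from the response() generator instead of from raw response objects. A standalone sketch of the dispatch, with an invented helper name:

    def format_output(response, out_format: str = "json", is_feed: bool = False) -> str:
        """Invented stand-in mirroring _get_formatted_output()'s branches."""
        if is_feed:
            # Feeds yield chunks of JSON Lines text from the response() generator.
            return "\n".join(data for data in response.response())
        # Other products expose per-format views (json, xml, ...) or a list view.
        return str(getattr(response, out_format) if out_format != "list" else response.as_list())
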
20 changes: 20 additions & 0 deletions domaintools/results.py
@@ -11,6 +11,8 @@
from ordereddict import OrderedDict

from itertools import zip_longest
from typing import Generator

from domaintools_async import AsyncResults as Results


@@ -141,3 +143,21 @@ def flattened(self):
flat[f"contact_{contact_key}_{i}"] = " | ".join(contact_value) if type(contact_value) in (list, tuple) else contact_value

return flat


class FeedsResults(Results):
"""Returns the generator for feeds results"""

def response(self) -> Generator:
status_code = None
while status_code != 200:
resp_data = super().response()
Contributor: We could call self.data() directly and override it here in the subclass, and drop the special-handling case from data() in the base class, to isolate the feeds results. The response path isn't actually used for feeds, since the result is always text rather than a JSON/dict. wdyt?

Contributor Author: I'd rather keep it as is, because I want feeds to go through the same steps a normal endpoint takes (limit-exceeded checks, retries, rate limiting, etc.). If we overrode self.data() we would still need to reimplement those steps, which causes code duplication. What are your thoughts?

Contributor: Yeah, that makes sense. Although I'd call it code isolation for single responsibility rather than duplication; we don't want to pile up if/else conditions if future products need to support different kinds of results.

Side note: I think we can move the raise into check_limit_exceeded so we don't have the "code duplication" in the async request, since it would then raise the ServiceException automatically. Looking at the async request, it doesn't raise the 503 exception unless the response actually returns a 503 status code, and that is caught in setStatus, which raises a different exception.

Contributor Author: Addressed.

status_code = self.status
yield resp_data

self._data = None
self._response = None
if not self.kwargs.get("sessionID"):
# We only make iterative requests for queries that have a sessionID.
# Otherwise, requests would repeat indefinitely when no sessionID is provided but the data requested exceeds the maximum (one hour of data).
break
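
Read in isolation, the polling loop above reduces to the sketch below. The names are invented and fetch stands in for the parent class's request machinery; this is not the library's API:

    from typing import Callable, Generator, Optional, Tuple

    def poll_feed(
        fetch: Callable[[], Tuple[str, int]],
        session_id: Optional[str] = None,
    ) -> Generator[str, None, None]:
        """Yield feed chunks, re-requesting until a 200 arrives."""
        status_code = None
        while status_code != 200:
            body, status_code = fetch()  # one HTTP round trip
            yield body
            if not session_id:
                # Without a sessionID the request is not resumable, so another
                # pass would just re-request the same window indefinitely.
                break
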
39 changes: 9 additions & 30 deletions domaintools_async/__init__.py
@@ -2,11 +2,12 @@

import asyncio

from copy import deepcopy
from httpx import AsyncClient

from domaintools.base_results import Results
from domaintools.constants import FEEDS_PRODUCTS_LIST
from domaintools.exceptions import ServiceUnavailableException, ServiceException
from domaintools.constants import FEEDS_PRODUCTS_LIST, OutputFormat, HEADER_ACCEPT_KEY_CSV_FORMAT
from domaintools.exceptions import ServiceUnavailableException


class _AIter(object):
@@ -41,26 +42,6 @@ class AsyncResults(Results):
def __await__(self):
return self.__awaitable__().__await__()

async def _get_feeds_async_results_generator(self, session, parameters, headers):
status_code = None
while status_code != 200:
resp_data = await session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
status_code = resp_data.status_code
self.setStatus(status_code, resp_data)

# Check limit exceeded here
if "response" in resp_data.text and "limit_exceeded" in resp_data.text:
self._limit_exceeded = True
self._limit_exceeded_message = "limit exceeded"
yield resp_data

if self._limit_exceeded:
raise ServiceException(503, "Limit Exceeded{}".format(self._limit_exceeded_message))
if not self.kwargs.get("sessionID"):
# We only make iterative requests for queries that have a sessionID.
# Otherwise, requests would repeat indefinitely when no sessionID is provided but the data requested exceeds the maximum (one hour of data).
break

async def _make_async_request(self, session):
if self.product in ["iris-investigate", "iris-enrich", "iris-detect-escalate-domains"]:
post_data = self.kwargs.copy()
@@ -71,19 +52,16 @@ async def _make_async_request(self, session):
patch_data.update(self.api.extra_request_params)
results = await session.patch(url=self.url, json=patch_data)
elif self.product in FEEDS_PRODUCTS_LIST:
generator_params = self._get_session_params()
parameters = generator_params.get("parameters")
headers = generator_params.get("headers")
results = await self._get_feeds_async_results_generator(session=session, parameters=parameters, headers=headers)
session_params = self._get_session_params()
parameters = session_params.get("parameters")
headers = session_params.get("headers")
results = await session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
else:
results = await session.get(url=self.url, params=self.kwargs, **self.api.extra_request_params)
if results:
status_code = results.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
self.setStatus(status_code, results)
self.setStatus(results.status_code, results)
if self.kwargs.get("format", "json") == "json":
self._data = results.json()
elif self.product in FEEDS_PRODUCTS_LIST:
self._data = results # Uses generator to handle large data results from feeds endpoint
else:
self._data = results.text()
limit_exceeded, message = self.check_limit_exceeded()
@@ -94,6 +72,7 @@

async def __awaitable__(self):
if self._data is None:

async with AsyncClient(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
wait_time = self._wait_time()
if wait_time is None and self.api:
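
Because results.py aliases AsyncResults as the Results base class, FeedsResults is awaitable through the path above. A minimal async sketch, again with hypothetical credentials and parameters:

    # Minimal async usage sketch; credentials and parameters are hypothetical.
    import asyncio

    from domaintools.api import API

    async def main() -> None:
        api = API("my_username", "my_api_key")
        results = await api.nod(sessionID="my-feed-session")  # runs the request
        for chunk in results.response():
            print(chunk)

    asyncio.run(main())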