Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.5.3
2.6.0
2 changes: 1 addition & 1 deletion domaintools/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@

"""

current = "2.5.3"
current = "2.6.0"
72 changes: 60 additions & 12 deletions domaintools/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def __init__(
self._build_api_url(api_url, api_port)

if not https:
raise Exception("The DomainTools API endpoints no longer support http traffic. Please make sure https=True.")
raise Exception(
"The DomainTools API endpoints no longer support http traffic. Please make sure https=True."
)
if proxy_url and not isinstance(proxy_url, str):
raise Exception("Proxy URL must be a string. For example: '127.0.0.1:8888'")

Expand Down Expand Up @@ -129,7 +131,8 @@ def _results(self, product, path, cls=Results, **kwargs):
always_sign_api_key_previous_value = self.always_sign_api_key
header_authentication_previous_value = self.header_authentication
self._rate_limit()
# Reset always_sign_api_key and header_authentication to its original User-set values as these might be affected when self.account_information() was executed
# Reset always_sign_api_key and header_authentication to its original
# User-set values as these might be affected when self.account_information() was executed
self.always_sign_api_key = always_sign_api_key_previous_value
self.header_authentication = header_authentication_previous_value

Expand All @@ -139,7 +142,13 @@ def _results(self, product, path, cls=Results, **kwargs):
is_rttf_product = product in RTTF_PRODUCTS_LIST
self._handle_api_key_parameters(is_rttf_product)
self.handle_api_key(is_rttf_product, path, parameters)
parameters.update({key: str(value).lower() if value in (True, False) else value for key, value in kwargs.items() if value is not None})
parameters.update(
{
key: str(value).lower() if value in (True, False) else value
for key, value in kwargs.items()
if value is not None
}
)

return cls(self, product, uri, **parameters)

Expand Down Expand Up @@ -189,8 +198,20 @@ def snakecase(string):
string[1:],
)

api_calls = tuple((api_call for api_call in dir(API) if not api_call.startswith("_") and callable(getattr(API, api_call, None))))
return sorted([snakecase(p["id"]) for p in self.account_information()["products"] if snakecase(p["id"]) in api_calls])
api_calls = tuple(
(
api_call
for api_call in dir(API)
if not api_call.startswith("_") and callable(getattr(API, api_call, None))
)
)
return sorted(
[
snakecase(p["id"])
for p in self.account_information()["products"]
if snakecase(p["id"]) in api_calls
]
)

def brand_monitor(self, query, exclude=None, domain_status=None, days_back=None, **kwargs):
"""Pass in one or more terms as a list or separated by the pipe character ( | )"""
Expand Down Expand Up @@ -445,7 +466,16 @@ def iris(
"""Performs a search for the provided search terms ANDed together,
returning the pivot engine row data for the resulting domains.
"""
if not domain and not ip and not email and not nameserver and not registrar and not registrant and not registrant_org and not kwargs:
if (
not domain
and not ip
and not email
and not nameserver
and not registrar
and not registrant
and not registrant_org
and not kwargs
):
raise ValueError("At least one search term must be specified")

return self._results(
Expand Down Expand Up @@ -1069,7 +1099,10 @@ def nod(self, **kwargs) -> FeedsResults:
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
source = ENDPOINT_TO_SOURCE_MAP.get(endpoint)
if endpoint == Endpoint.DOWNLOAD.value or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value:
if (
endpoint == Endpoint.DOWNLOAD.value
or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value
):
# headers param is allowed only in Feed API and CSV format
kwargs.pop("headers", None)

Expand Down Expand Up @@ -1101,7 +1134,10 @@ def nad(self, **kwargs) -> FeedsResults:
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
source = ENDPOINT_TO_SOURCE_MAP.get(endpoint).value
if endpoint == Endpoint.DOWNLOAD.value or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value:
if (
endpoint == Endpoint.DOWNLOAD.value
or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value
):
# headers param is allowed only in Feed API and CSV format
kwargs.pop("headers", None)

Expand Down Expand Up @@ -1162,7 +1198,10 @@ def domaindiscovery(self, **kwargs) -> FeedsResults:
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
source = ENDPOINT_TO_SOURCE_MAP.get(endpoint).value
if endpoint == Endpoint.DOWNLOAD.value or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value:
if (
endpoint == Endpoint.DOWNLOAD.value
or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value
):
# headers param is allowed only in Feed API and CSV format
kwargs.pop("headers", None)

Expand Down Expand Up @@ -1194,7 +1233,10 @@ def noh(self, **kwargs) -> FeedsResults:
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
source = ENDPOINT_TO_SOURCE_MAP.get(endpoint).value
if endpoint == Endpoint.DOWNLOAD.value or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value:
if (
endpoint == Endpoint.DOWNLOAD.value
or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value
):
# headers param is allowed only in Feed API and CSV format
kwargs.pop("headers", None)

Expand Down Expand Up @@ -1225,7 +1267,10 @@ def realtime_domain_risk(self, **kwargs) -> FeedsResults:
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
source = ENDPOINT_TO_SOURCE_MAP.get(endpoint).value
if endpoint == Endpoint.DOWNLOAD.value or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value:
if (
endpoint == Endpoint.DOWNLOAD.value
or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value
):
# headers param is allowed only in Feed API and CSV format
kwargs.pop("headers", None)

Expand Down Expand Up @@ -1256,7 +1301,10 @@ def domainhotlist(self, **kwargs) -> FeedsResults:
validate_feeds_parameters(kwargs)
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
source = ENDPOINT_TO_SOURCE_MAP.get(endpoint).value
if endpoint == Endpoint.DOWNLOAD.value or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value:
if (
endpoint == Endpoint.DOWNLOAD.value
or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value
):
# headers param is allowed only in Feed API and CSV format
kwargs.pop("headers", None)

Expand Down
35 changes: 27 additions & 8 deletions domaintools/base_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from datetime import datetime
from httpx import Client

from domaintools.constants import RTTF_PRODUCTS_LIST, OutputFormat, HEADER_ACCEPT_KEY_CSV_FORMAT
from domaintools.constants import (
RTTF_PRODUCTS_LIST,
OutputFormat,
HEADER_ACCEPT_KEY_CSV_FORMAT,
)
from domaintools.exceptions import (
BadRequestException,
InternalServerErrorException,
Expand Down Expand Up @@ -53,6 +57,7 @@ def __init__(
self._response = None
self._items_list = None
self._data = None
self._status = None

def _wait_time(self):
if not self.api.rate_limit or not self.product in self.api.limits:
Expand Down Expand Up @@ -92,10 +97,10 @@ def _get_session_params_and_headers(self):
header_key_for_api_key = "X-Api-Key" if is_rttf_product else "X-API-Key"
headers[header_key_for_api_key] = self.api.key

return {"parameters": parameters, "headers": headers}
session_param_and_headers = {"parameters": parameters, "headers": headers}
return session_param_and_headers

def _make_request(self):

with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
session_params_and_headers = self._get_session_params_and_headers()
headers = session_params_and_headers.get("headers")
Expand All @@ -113,7 +118,12 @@ def _make_request(self):
return session.patch(url=self.url, json=patch_data, headers=headers)
else:
parameters = session_params_and_headers.get("parameters")
return session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
return session.get(
url=self.url,
params=parameters,
headers=headers,
**self.api.extra_request_params,
)

def _get_results(self):
wait_for = self._wait_time()
Expand Down Expand Up @@ -152,7 +162,9 @@ def data(self):
def check_limit_exceeded(self):
limit_exceeded, reason = False, ""
if isinstance(self._data, dict) and (
"response" in self._data and "limit_exceeded" in self._data["response"] and self._data["response"]["limit_exceeded"] is True
"response" in self._data
and "limit_exceeded" in self._data["response"]
and self._data["response"]["limit_exceeded"] is True
):
limit_exceeded, reason = True, self._data["response"]["message"]
elif "response" in self._data and "limit_exceeded" in self._data:
Expand All @@ -163,12 +175,12 @@ def check_limit_exceeded(self):

@property
def status(self):
if not getattr(self, "_status", None):
if not getattr(self, "_status", None) and not self.product in RTTF_PRODUCTS_LIST:
self._status = self._get_results().status_code

return self._status

def setStatus(self, code, response=None):
def setStatus(self, code, response=None, reason_text=None):
self._status = code
if code == 200 or (self.product in RTTF_PRODUCTS_LIST and code == 206):
return
Expand All @@ -181,6 +193,9 @@ def setStatus(self, code, response=None):
reason = response.text
if callable(reason):
reason = reason()
else: # optionally pass a customize reason of error for better traceback
if reason_text is not None:
reason = reason_text

if code in (400, 422):
raise BadRequestException(code, reason)
Expand Down Expand Up @@ -330,4 +345,8 @@ def as_list(self):
return "\n".join([json.dumps(item, indent=4, separators=(",", ": ")) for item in self._items()])

def __str__(self):
return str(json.dumps(self.data(), indent=4, separators=(",", ": ")) if self.kwargs.get("format", "json") == "json" else self.data())
return str(
json.dumps(self.data(), indent=4, separators=(",", ": "))
if self.kwargs.get("format", "json") == "json"
else self.data()
)
32 changes: 23 additions & 9 deletions domaintools/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def validate_after_or_before_input(value: str):
datetime.strptime(value, "%Y-%m-%dT%H:%M:%SZ")
return value
except:
raise typer.BadParameter(f"{value} is neither an integer or a valid ISO 8601 datetime string in UTC form")
raise typer.BadParameter(
f"{value} is neither an integer or a valid ISO 8601 datetime string in UTC form"
)

@staticmethod
def validate_source_file_extension(value: str):
Expand All @@ -78,7 +80,9 @@ def validate_source_file_extension(value: str):
ext = get_file_extension(value)

if ext.lower() not in VALID_EXTENSIONS:
raise typer.BadParameter(f"{value} is not in valid extensions. Valid file extensions: {VALID_EXTENSIONS}")
raise typer.BadParameter(
f"{value} is not in valid extensions. Valid file extensions: {VALID_EXTENSIONS}"
)

return value

Expand Down Expand Up @@ -111,7 +115,7 @@ def _get_formatted_output(cls, cmd_name: str, response, out_format: str = "json"
if cmd_name in ("available_api_calls",):
return "\n".join(response)
if response.product in RTTF_PRODUCTS_LIST:
return "\n".join([data for data in response.response()])
pass # do nothing
return str(getattr(response, out_format) if out_format != "list" else response.as_list())

@classmethod
Expand Down Expand Up @@ -203,7 +207,7 @@ def run(cls, name: str, params: Optional[Dict] = {}, **kwargs):
transient=True,
) as progress:

progress.add_task(
task_id = progress.add_task(
description=f"Using api credentials with a username of: [cyan]{user}[/cyan]\nExecuting [green]{name}[/green] api call...",
total=None,
)
Expand All @@ -222,23 +226,33 @@ def run(cls, name: str, params: Optional[Dict] = {}, **kwargs):
params = params | kwargs

response = dt_api_func(**params)
progress.add_task(
progress.update(
task_id,
description=f"Preparing results with format of {response_format}...",
total=None,
)

output = cls._get_formatted_output(cmd_name=name, response=response, out_format=response_format)
output = cls._get_formatted_output(
cmd_name=name, response=response, out_format=response_format
)

if isinstance(out_file, _io.TextIOWrapper):
progress.update(
task_id,
description=f"Printing the results with format of {response_format}...",
)
# use rich `print` command to prettify the ouput in sys.stdout
if response.product in RTTF_PRODUCTS_LIST:
print(output)
for feeds in response.response():
print(feeds)
else:
print(response)
else:
progress.update(
task_id,
description=f"Writing results to {out_file}",
)
# if it's a file then write
out_file.write(output if output.endswith("\n") else output + "\n")
time.sleep(0.25)
except Exception as e:
if isinstance(e, ServiceException):
code = typer.style(getattr(e, "code", 400), fg=typer.colors.BRIGHT_RED)
Expand Down
Loading