Skip to content

Commit 17dcbf9

Browse files
committed
Implement streaming request for RTTF endpoints
1 parent fd36d96 commit 17dcbf9

File tree

5 files changed

+272
-58
lines changed

5 files changed

+272
-58
lines changed

domaintools/api.py

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ def __init__(
9292
self._build_api_url(api_url, api_port)
9393

9494
if not https:
95-
raise Exception("The DomainTools API endpoints no longer support http traffic. Please make sure https=True.")
95+
raise Exception(
96+
"The DomainTools API endpoints no longer support http traffic. Please make sure https=True."
97+
)
9698
if proxy_url and not isinstance(proxy_url, str):
9799
raise Exception("Proxy URL must be a string. For example: '127.0.0.1:8888'")
98100

@@ -129,7 +131,8 @@ def _results(self, product, path, cls=Results, **kwargs):
129131
always_sign_api_key_previous_value = self.always_sign_api_key
130132
header_authentication_previous_value = self.header_authentication
131133
self._rate_limit()
132-
# Reset always_sign_api_key and header_authentication to its original User-set values as these might be affected when self.account_information() was executed
134+
# Reset always_sign_api_key and header_authentication to their original
135+
# user-set values, as these might be affected when self.account_information() was executed
133136
self.always_sign_api_key = always_sign_api_key_previous_value
134137
self.header_authentication = header_authentication_previous_value
135138

@@ -139,7 +142,13 @@ def _results(self, product, path, cls=Results, **kwargs):
139142
is_rttf_product = product in RTTF_PRODUCTS_LIST
140143
self._handle_api_key_parameters(is_rttf_product)
141144
self.handle_api_key(is_rttf_product, path, parameters)
142-
parameters.update({key: str(value).lower() if value in (True, False) else value for key, value in kwargs.items() if value is not None})
145+
parameters.update(
146+
{
147+
key: str(value).lower() if value in (True, False) else value
148+
for key, value in kwargs.items()
149+
if value is not None
150+
}
151+
)
143152

144153
return cls(self, product, uri, **parameters)
145154

@@ -189,8 +198,20 @@ def snakecase(string):
189198
string[1:],
190199
)
191200

192-
api_calls = tuple((api_call for api_call in dir(API) if not api_call.startswith("_") and callable(getattr(API, api_call, None))))
193-
return sorted([snakecase(p["id"]) for p in self.account_information()["products"] if snakecase(p["id"]) in api_calls])
201+
api_calls = tuple(
202+
(
203+
api_call
204+
for api_call in dir(API)
205+
if not api_call.startswith("_") and callable(getattr(API, api_call, None))
206+
)
207+
)
208+
return sorted(
209+
[
210+
snakecase(p["id"])
211+
for p in self.account_information()["products"]
212+
if snakecase(p["id"]) in api_calls
213+
]
214+
)
194215

195216
def brand_monitor(self, query, exclude=None, domain_status=None, days_back=None, **kwargs):
196217
"""Pass in one or more terms as a list or separated by the pipe character ( | )"""
@@ -445,7 +466,16 @@ def iris(
445466
"""Performs a search for the provided search terms ANDed together,
446467
returning the pivot engine row data for the resulting domains.
447468
"""
448-
if not domain and not ip and not email and not nameserver and not registrar and not registrant and not registrant_org and not kwargs:
469+
if (
470+
not domain
471+
and not ip
472+
and not email
473+
and not nameserver
474+
and not registrar
475+
and not registrant
476+
and not registrant_org
477+
and not kwargs
478+
):
449479
raise ValueError("At least one search term must be specified")
450480

451481
return self._results(
@@ -1069,7 +1099,10 @@ def nod(self, **kwargs) -> FeedsResults:
10691099
validate_feeds_parameters(kwargs)
10701100
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
10711101
source = ENDPOINT_TO_SOURCE_MAP.get(endpoint)
1072-
if endpoint == Endpoint.DOWNLOAD.value or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value:
1102+
if (
1103+
endpoint == Endpoint.DOWNLOAD.value
1104+
or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value
1105+
):
10731106
# headers param is allowed only in Feed API and CSV format
10741107
kwargs.pop("headers", None)
10751108

@@ -1101,7 +1134,10 @@ def nad(self, **kwargs) -> FeedsResults:
11011134
validate_feeds_parameters(kwargs)
11021135
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
11031136
source = ENDPOINT_TO_SOURCE_MAP.get(endpoint).value
1104-
if endpoint == Endpoint.DOWNLOAD.value or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value:
1137+
if (
1138+
endpoint == Endpoint.DOWNLOAD.value
1139+
or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value
1140+
):
11051141
# headers param is allowed only in Feed API and CSV format
11061142
kwargs.pop("headers", None)
11071143

@@ -1162,7 +1198,10 @@ def domaindiscovery(self, **kwargs) -> FeedsResults:
11621198
validate_feeds_parameters(kwargs)
11631199
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
11641200
source = ENDPOINT_TO_SOURCE_MAP.get(endpoint).value
1165-
if endpoint == Endpoint.DOWNLOAD.value or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value:
1201+
if (
1202+
endpoint == Endpoint.DOWNLOAD.value
1203+
or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value
1204+
):
11661205
# headers param is allowed only in Feed API and CSV format
11671206
kwargs.pop("headers", None)
11681207

@@ -1194,7 +1233,10 @@ def noh(self, **kwargs) -> FeedsResults:
11941233
validate_feeds_parameters(kwargs)
11951234
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
11961235
source = ENDPOINT_TO_SOURCE_MAP.get(endpoint).value
1197-
if endpoint == Endpoint.DOWNLOAD.value or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value:
1236+
if (
1237+
endpoint == Endpoint.DOWNLOAD.value
1238+
or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value
1239+
):
11981240
# headers param is allowed only in Feed API and CSV format
11991241
kwargs.pop("headers", None)
12001242

@@ -1225,7 +1267,10 @@ def realtime_domain_risk(self, **kwargs) -> FeedsResults:
12251267
validate_feeds_parameters(kwargs)
12261268
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
12271269
source = ENDPOINT_TO_SOURCE_MAP.get(endpoint).value
1228-
if endpoint == Endpoint.DOWNLOAD.value or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value:
1270+
if (
1271+
endpoint == Endpoint.DOWNLOAD.value
1272+
or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value
1273+
):
12291274
# headers param is allowed only in Feed API and CSV format
12301275
kwargs.pop("headers", None)
12311276

@@ -1256,7 +1301,10 @@ def domainhotlist(self, **kwargs) -> FeedsResults:
12561301
validate_feeds_parameters(kwargs)
12571302
endpoint = kwargs.pop("endpoint", Endpoint.FEED.value)
12581303
source = ENDPOINT_TO_SOURCE_MAP.get(endpoint).value
1259-
if endpoint == Endpoint.DOWNLOAD.value or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value:
1304+
if (
1305+
endpoint == Endpoint.DOWNLOAD.value
1306+
or kwargs.get("output_format", OutputFormat.JSONL.value) != OutputFormat.CSV.value
1307+
):
12601308
# headers param is allowed only in Feed API and CSV format
12611309
kwargs.pop("headers", None)
12621310

domaintools/base_results.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
from datetime import datetime
1010
from httpx import Client
1111

12-
from domaintools.constants import RTTF_PRODUCTS_LIST, OutputFormat, HEADER_ACCEPT_KEY_CSV_FORMAT
12+
from domaintools.constants import (
13+
RTTF_PRODUCTS_LIST,
14+
OutputFormat,
15+
HEADER_ACCEPT_KEY_CSV_FORMAT,
16+
)
1317
from domaintools.exceptions import (
1418
BadRequestException,
1519
InternalServerErrorException,
@@ -92,10 +96,10 @@ def _get_session_params_and_headers(self):
9296
header_key_for_api_key = "X-Api-Key" if is_rttf_product else "X-API-Key"
9397
headers[header_key_for_api_key] = self.api.key
9498

95-
return {"parameters": parameters, "headers": headers}
99+
session_param_and_headers = {"parameters": parameters, "headers": headers}
100+
return session_param_and_headers
96101

97102
def _make_request(self):
98-
99103
with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
100104
session_params_and_headers = self._get_session_params_and_headers()
101105
headers = session_params_and_headers.get("headers")
@@ -113,7 +117,12 @@ def _make_request(self):
113117
return session.patch(url=self.url, json=patch_data, headers=headers)
114118
else:
115119
parameters = session_params_and_headers.get("parameters")
116-
return session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
120+
return session.get(
121+
url=self.url,
122+
params=parameters,
123+
headers=headers,
124+
**self.api.extra_request_params,
125+
)
117126

118127
def _get_results(self):
119128
wait_for = self._wait_time()
@@ -152,7 +161,9 @@ def data(self):
152161
def check_limit_exceeded(self):
153162
limit_exceeded, reason = False, ""
154163
if isinstance(self._data, dict) and (
155-
"response" in self._data and "limit_exceeded" in self._data["response"] and self._data["response"]["limit_exceeded"] is True
164+
"response" in self._data
165+
and "limit_exceeded" in self._data["response"]
166+
and self._data["response"]["limit_exceeded"] is True
156167
):
157168
limit_exceeded, reason = True, self._data["response"]["message"]
158169
elif "response" in self._data and "limit_exceeded" in self._data:
@@ -168,7 +179,7 @@ def status(self):
168179

169180
return self._status
170181

171-
def setStatus(self, code, response=None):
182+
def setStatus(self, code, response=None, reason_text=None):
172183
self._status = code
173184
if code == 200 or (self.product in RTTF_PRODUCTS_LIST and code == 206):
174185
return
@@ -181,6 +192,9 @@ def setStatus(self, code, response=None):
181192
reason = response.text
182193
if callable(reason):
183194
reason = reason()
195+
else: # optionally pass a customized error reason for a better traceback
196+
if reason_text is not None:
197+
reason = reason_text
184198

185199
if code in (400, 422):
186200
raise BadRequestException(code, reason)
@@ -330,4 +344,8 @@ def as_list(self):
330344
return "\n".join([json.dumps(item, indent=4, separators=(",", ": ")) for item in self._items()])
331345

332346
def __str__(self):
333-
return str(json.dumps(self.data(), indent=4, separators=(",", ": ")) if self.kwargs.get("format", "json") == "json" else self.data())
347+
return str(
348+
json.dumps(self.data(), indent=4, separators=(",", ": "))
349+
if self.kwargs.get("format", "json") == "json"
350+
else self.data()
351+
)

domaintools/cli/api.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def validate_after_or_before_input(value: str):
5858
datetime.strptime(value, "%Y-%m-%dT%H:%M:%SZ")
5959
return value
6060
except:
61-
raise typer.BadParameter(f"{value} is neither an integer or a valid ISO 8601 datetime string in UTC form")
61+
raise typer.BadParameter(
62+
f"{value} is neither an integer or a valid ISO 8601 datetime string in UTC form"
63+
)
6264

6365
@staticmethod
6466
def validate_source_file_extension(value: str):
@@ -78,7 +80,9 @@ def validate_source_file_extension(value: str):
7880
ext = get_file_extension(value)
7981

8082
if ext.lower() not in VALID_EXTENSIONS:
81-
raise typer.BadParameter(f"{value} is not in valid extensions. Valid file extensions: {VALID_EXTENSIONS}")
83+
raise typer.BadParameter(
84+
f"{value} is not in valid extensions. Valid file extensions: {VALID_EXTENSIONS}"
85+
)
8286

8387
return value
8488

@@ -111,7 +115,7 @@ def _get_formatted_output(cls, cmd_name: str, response, out_format: str = "json"
111115
if cmd_name in ("available_api_calls",):
112116
return "\n".join(response)
113117
if response.product in RTTF_PRODUCTS_LIST:
114-
return "\n".join([data for data in response.response()])
118+
pass # do nothing
115119
return str(getattr(response, out_format) if out_format != "list" else response.as_list())
116120

117121
@classmethod
@@ -203,7 +207,7 @@ def run(cls, name: str, params: Optional[Dict] = {}, **kwargs):
203207
transient=True,
204208
) as progress:
205209

206-
progress.add_task(
210+
task_id = progress.add_task(
207211
description=f"Using api credentials with a username of: [cyan]{user}[/cyan]\nExecuting [green]{name}[/green] api call...",
208212
total=None,
209213
)
@@ -222,23 +226,33 @@ def run(cls, name: str, params: Optional[Dict] = {}, **kwargs):
222226
params = params | kwargs
223227

224228
response = dt_api_func(**params)
225-
progress.add_task(
229+
progress.update(
230+
task_id,
226231
description=f"Preparing results with format of {response_format}...",
227-
total=None,
228232
)
229233

230-
output = cls._get_formatted_output(cmd_name=name, response=response, out_format=response_format)
234+
output = cls._get_formatted_output(
235+
cmd_name=name, response=response, out_format=response_format
236+
)
231237

232238
if isinstance(out_file, _io.TextIOWrapper):
239+
progress.update(
240+
task_id,
241+
description=f"Printing the results with format of {response_format}...",
242+
)
233243
# use rich `print` command to prettify the output in sys.stdout
234244
if response.product in RTTF_PRODUCTS_LIST:
235-
print(output)
245+
for feeds in response.response():
246+
print(feeds)
236247
else:
237248
print(response)
238249
else:
250+
progress.update(
251+
task_id,
252+
description=f"Writing results to {out_file}",
253+
)
239254
# if it's a file then write
240255
out_file.write(output if output.endswith("\n") else output + "\n")
241-
time.sleep(0.25)
242256
except Exception as e:
243257
if isinstance(e, ServiceException):
244258
code = typer.style(getattr(e, "code", 400), fg=typer.colors.BRIGHT_RED)

0 commit comments

Comments
 (0)