Skip to content

Commit bb0e087

Browse files
authored
Merge pull request #146 from DomainTools/IDEV-2020-handle-large-response-feeds
IDEV-2020: Handle partial response from RTUF endpoints.
2 parents 763a9b1 + ef7619d commit bb0e087

File tree

7 files changed

+44215
-99
lines changed

7 files changed

+44215
-99
lines changed

domaintools/base_results.py

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,53 @@ def _wait_time(self):
7777

7878
return wait_for
7979

80+
def _get_feeds_results_generator(self, parameters, headers):
    """Yield the raw response of every request made against a feeds endpoint.

    One ``Client`` session is opened and kept for the whole iteration. The
    same ``parameters`` and ``headers`` are sent on each request, and requests
    are repeated until a response with status code 200 is received.

    Args:
        parameters: Query parameters passed as ``params`` on every request.
        headers: HTTP headers passed on every request.

    Yields:
        The response object returned by ``session.get`` for each request.

    Raises:
        ServiceException: With code 503, when iteration is resumed after a
            response whose text contains both ``"response"`` and
            ``"limit_exceeded"`` has been yielded.

    NOTE(review): ``setStatus`` is called for every response and presumably
    raises for error status codes; only its 200/206 early return is visible
    from this block, so confirm against its full body.
    """
    with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as http:
        while True:
            page = http.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
            code = page.status_code
            self.setStatus(code, page)

            # Limit detection is a plain substring check on the response text.
            body = page.text
            if "response" in body and "limit_exceeded" in body:
                self._limit_exceeded = True
                self._limit_exceeded_message = "limit exceeded"

            # The response is handed to the consumer first; everything below
            # only runs once the consumer asks for the next item.
            yield page

            if self._limit_exceeded:
                raise ServiceException(503, "Limit Exceeded{}".format(self._limit_exceeded_message))

            # Only queries that carry a sessionID are requested iteratively.
            # Without one, a request asking for more than the maximum
            # (1 hour of data) would otherwise be repeated forever.
            if not self.kwargs.get("sessionID") or code == 200:
                break
102+
103+
def _get_session_params(self):
    """Split ``self.kwargs`` into query parameters and HTTP headers for a feeds request.

    ``self.kwargs`` itself is left untouched; the work is done on a deep copy.

    Returns:
        dict: ``{"parameters": <query params>, "headers": <HTTP headers>}``.
            ``"output_format"`` and ``"format"`` are never part of the query
            parameters. When the requested output format is CSV, the query
            parameter ``"headers"`` is set to 0 or 1 and the ``"accept"``
            header is set. A truthy ``"X-Api-Key"`` entry is moved from the
            query parameters into the headers; a falsy one is simply dropped.
    """
    query = deepcopy(self.kwargs)

    # "format" is populated for feeds endpoints even when it was not passed
    # in the CLI params (only happens when using the CLI), so it is removed
    # together with "output_format".
    for client_only_key in ("output_format", "format"):
        query.pop(client_only_key, None)

    request_headers = {}
    wants_csv = self.kwargs.get("output_format", OutputFormat.JSONL.value) == OutputFormat.CSV.value
    if wants_csv:
        # Normalise the caller's "headers" flag to 0/1 for the query string.
        query["headers"] = int(bool(self.kwargs.get("headers", False)))
        request_headers["accept"] = HEADER_ACCEPT_KEY_CSV_FORMAT

    api_key = query.pop("X-Api-Key", None)
    if api_key:
        request_headers["X-Api-Key"] = api_key

    return {"parameters": query, "headers": request_headers}
119+
80120
def _make_request(self):
121+
if self.product in FEEDS_PRODUCTS_LIST:
122+
session_params = self._get_session_params()
123+
parameters = session_params.get("parameters")
124+
headers = session_params.get("headers")
125+
126+
return self._get_feeds_results_generator(parameters=parameters, headers=headers)
81127

82128
with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
83129
if self.product in [
@@ -92,30 +138,15 @@ def _make_request(self):
92138
patch_data = self.kwargs.copy()
93139
patch_data.update(self.api.extra_request_params)
94140
return session.patch(url=self.url, json=patch_data)
95-
elif self.product in FEEDS_PRODUCTS_LIST:
96-
parameters = deepcopy(self.kwargs)
97-
parameters.pop("output_format", None)
98-
parameters.pop(
99-
"format", None
100-
) # For some unknown reason, "format" is populated for feeds endpoints even when it is not included in the CLI params, so we need to remove it. This only happens when using the CLI.
101-
headers = {}
102-
if self.kwargs.get("output_format", OutputFormat.JSONL.value) == OutputFormat.CSV.value:
103-
parameters["headers"] = int(bool(self.kwargs.get("headers", False)))
104-
headers["accept"] = HEADER_ACCEPT_KEY_CSV_FORMAT
105-
106-
header_api_key = parameters.pop("X-Api-Key", None)
107-
if header_api_key:
108-
headers["X-Api-Key"] = header_api_key
109-
110-
return session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
111141
else:
112142
return session.get(url=self.url, params=self.kwargs, **self.api.extra_request_params)
113143

114144
def _get_results(self):
115145
wait_for = self._wait_time()
116146
if self.api.rate_limit and (wait_for is None or self.product == "account-information"):
117147
data = self._make_request()
118-
if data.status_code == 503: # pragma: no cover
148+
status_code = data.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
149+
if status_code == 503: # pragma: no cover
119150
sleeptime = 60
120151
log.info(
121152
"503 encountered for [%s] - sleeping [%s] seconds before retrying request.",
@@ -135,12 +166,15 @@ def _get_results(self):
135166
def data(self):
136167
if self._data is None:
137168
results = self._get_results()
138-
self.setStatus(results.status_code, results)
169+
status_code = results.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
170+
self.setStatus(status_code, results)
139171
if (
140172
self.kwargs.get("format", "json") == "json"
141173
and self.product not in FEEDS_PRODUCTS_LIST # Special handling of feeds products' data to preserve the result in jsonline format
142174
):
143175
self._data = results.json()
176+
elif self.product in FEEDS_PRODUCTS_LIST:
177+
self._data = results # Uses generator to handle large data results from feeds endpoint
144178
else:
145179
self._data = results.text
146180
limit_exceeded, message = self.check_limit_exceeded()
@@ -155,6 +189,10 @@ def data(self):
155189
return self._data
156190

157191
def check_limit_exceeded(self):
192+
if self.product in FEEDS_PRODUCTS_LIST:
193+
# bypass here as this is handled in generator already
194+
return False, ""
195+
158196
if self.kwargs.get("format", "json") == "json" and self.product not in FEEDS_PRODUCTS_LIST:
159197
if "response" in self._data and "limit_exceeded" in self._data["response"] and self._data["response"]["limit_exceeded"] is True:
160198
return True, self._data["response"]["message"]
@@ -172,7 +210,7 @@ def status(self):
172210

173211
def setStatus(self, code, response=None):
174212
self._status = code
175-
if code == 200:
213+
if code == 200 or (self.product in FEEDS_PRODUCTS_LIST and code == 206):
176214
return
177215

178216
reason = None
@@ -211,7 +249,7 @@ def response(self):
211249
return self._response
212250

213251
def items(self):
214-
return self.response().items()
252+
return self.response().items() if isinstance(self.response(), dict) else self.response()
215253

216254
def emails(self):
217255
"""Find and returns all emails mentioned in the response"""

domaintools/cli/api.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from rich.progress import Progress, SpinnerColumn, TextColumn
1010

1111
from domaintools.api import API
12-
from domaintools.constants import Endpoint, OutputFormat
12+
from domaintools.constants import Endpoint, OutputFormat, FEEDS_PRODUCTS_LIST
1313
from domaintools.cli.utils import get_file_extension
1414
from domaintools.exceptions import ServiceException
1515
from domaintools._version import current as version
@@ -110,6 +110,9 @@ def args_to_dict(*args) -> Dict:
110110
def _get_formatted_output(cls, cmd_name: str, response, out_format: str = "json"):
    """Render an API response as the string the CLI will output.

    Args:
        cmd_name: Name of the CLI command that produced ``response``.
        response: For ``available_api_calls`` an iterable of strings;
            otherwise a result object exposing ``product``.
        out_format: Attribute of ``response`` to render, or ``"list"`` to
            render ``response.as_list()``.

    Returns:
        str: Newline-joined entries for ``available_api_calls``;
            newline-joined ``.text`` of every item iterated from ``response``
            for feeds products; otherwise ``str()`` of the selected format.
    """
    if cmd_name == "available_api_calls":
        return "\n".join(response)

    if response.product in FEEDS_PRODUCTS_LIST:
        # Iterating a feeds result yields one object per request made;
        # their texts are concatenated line-wise.
        return "\n".join(chunk.text for chunk in response)

    if out_format == "list":
        rendered = response.as_list()
    else:
        rendered = getattr(response, out_format)
    return str(rendered)
114117

115118
@classmethod
@@ -227,7 +230,10 @@ def run(cls, name: str, params: Optional[Dict] = {}, **kwargs):
227230

228231
if isinstance(out_file, _io.TextIOWrapper):
229232
# use rich `print` command to prettify the output in sys.stdout
230-
print(response)
233+
if response.product in FEEDS_PRODUCTS_LIST:
234+
print(output)
235+
else:
236+
print(response)
231237
else:
232238
# if it's a file then write
233239
out_file.write(output if output.endswith("\n") else output + "\n")

domaintools_async/__init__.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
import asyncio
44

5-
from copy import deepcopy
65
from httpx import AsyncClient
76

87
from domaintools.base_results import Results
9-
from domaintools.constants import FEEDS_PRODUCTS_LIST, OutputFormat, HEADER_ACCEPT_KEY_CSV_FORMAT
10-
from domaintools.exceptions import ServiceUnavailableException
8+
from domaintools.constants import FEEDS_PRODUCTS_LIST
9+
from domaintools.exceptions import ServiceUnavailableException, ServiceException
1110

1211

1312
class _AIter(object):
@@ -42,6 +41,26 @@ class AsyncResults(Results):
4241
def __await__(self):
    """Make the instance awaitable by delegating to the coroutine from ``__awaitable__``."""
    return self.__awaitable__().__await__()
4443

44+
async def _get_feeds_async_results_generator(self, session, parameters, headers):
    """Asynchronously yield the raw response of every request made against a feeds endpoint.

    The same ``parameters`` and ``headers`` are sent on each request, and
    requests are repeated until a response with status code 200 is received.

    Args:
        session: Client whose awaitable ``get`` performs the HTTP request.
        parameters: Query parameters passed as ``params`` on every request.
        headers: HTTP headers passed on every request.

    Yields:
        The response object returned by ``session.get`` for each request.

    Raises:
        ServiceException: With code 503, when iteration is resumed after a
            response whose text contains both ``"response"`` and
            ``"limit_exceeded"`` has been yielded.

    NOTE(review): this is an async generator, so it has to be consumed with
    ``async for``. The call site in ``_make_async_request`` applies ``await``
    to the returned object, which async generators do not support - confirm
    and fix at the call site.
    """
    last_status = None
    while last_status != 200:
        reply = await session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
        last_status = reply.status_code
        self.setStatus(last_status, reply)

        # Limit detection is a plain substring check on the response text.
        payload = reply.text
        limit_hit = "response" in payload and "limit_exceeded" in payload
        if limit_hit:
            self._limit_exceeded = True
            self._limit_exceeded_message = "limit exceeded"

        # Hand the response over first; the checks below run only when the
        # consumer requests the next item.
        yield reply

        if self._limit_exceeded:
            raise ServiceException(503, "Limit Exceeded{}".format(self._limit_exceeded_message))

        # Only queries that carry a sessionID are requested iteratively.
        # Without one, a request asking for more than the maximum
        # (1 hour of data) would otherwise be repeated forever.
        session_id = self.kwargs.get("sessionID")
        if not session_id:
            break
63+
4564
async def _make_async_request(self, session):
4665
if self.product in ["iris-investigate", "iris-enrich", "iris-detect-escalate-domains"]:
4766
post_data = self.kwargs.copy()
@@ -52,27 +71,19 @@ async def _make_async_request(self, session):
5271
patch_data.update(self.api.extra_request_params)
5372
results = await session.patch(url=self.url, json=patch_data)
5473
elif self.product in FEEDS_PRODUCTS_LIST:
55-
parameters = deepcopy(self.kwargs)
56-
parameters.pop("output_format", None)
57-
parameters.pop(
58-
"format", None
59-
) # For some unknown reason, "format" is populated for feeds endpoints even when it is not included in the CLI params, so we need to remove it. This only happens when using the CLI.
60-
headers = {}
61-
if self.kwargs.get("output_format", OutputFormat.JSONL.value) == OutputFormat.CSV.value:
62-
parameters["headers"] = int(bool(self.kwargs.get("headers", False)))
63-
headers["accept"] = HEADER_ACCEPT_KEY_CSV_FORMAT
64-
65-
header_api_key = parameters.pop("X-Api-Key", None)
66-
if header_api_key:
67-
headers["X-Api-Key"] = header_api_key
68-
69-
results = await session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
74+
generator_params = self._get_session_params()
75+
parameters = generator_params.get("parameters")
76+
headers = generator_params.get("headers")
77+
results = await self._get_feeds_async_results_generator(session=session, parameters=parameters, headers=headers)
7078
else:
7179
results = await session.get(url=self.url, params=self.kwargs, **self.api.extra_request_params)
7280
if results:
73-
self.setStatus(results.status_code, results)
81+
status_code = results.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
82+
self.setStatus(status_code, results)
7483
if self.kwargs.get("format", "json") == "json":
7584
self._data = results.json()
85+
elif self.product in FEEDS_PRODUCTS_LIST:
86+
self._data = results # Uses generator to handle large data results from feeds endpoint
7687
else:
7788
self._data = results.text()
7889
limit_exceeded, message = self.check_limit_exceeded()
@@ -83,7 +94,6 @@ async def _make_async_request(self, session):
8394

8495
async def __awaitable__(self):
8596
if self._data is None:
86-
8797
async with AsyncClient(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
8898
wait_time = self._wait_time()
8999
if wait_time is None and self.api:

0 commit comments

Comments
 (0)