Skip to content

Commit 35d4600

Browse files
committed
IDEV-2020: Implement handling of iterative HTTP 206 response from RTUF.
1 parent 763a9b1 commit 35d4600

File tree

3 files changed

+98
-42
lines changed

3 files changed

+98
-42
lines changed

domaintools/base_results.py

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,53 @@ def _wait_time(self):
7777

7878
return wait_for
7979

80+
def _get_feeds_results_generator(self, parameters, headers):
    """Yield every HTTP response produced by a feeds request.

    The same GET request is re-issued until a response with status 200 is
    received, and each response object is yielded as it arrives.  Per the
    commit description this is meant for the iterative HTTP 206 responses of
    the feeds endpoint.

    Args:
        parameters: Query-string parameters forwarded to ``session.get``.
        headers: HTTP headers forwarded to ``session.get``.

    Yields:
        The response object returned by ``session.get`` for each request.

    Raises:
        ServiceException: With code 503 once a response whose text mentions
            both "response" and "limit_exceeded" has been yielded.
    """
    with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
        status_code = None
        while status_code != 200:
            resp_data = session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
            status_code = resp_data.status_code
            # NOTE(review): setStatus presumably raises for non-2xx codes, which is
            # what ends this loop on an error response - confirm against setStatus.
            self.setStatus(status_code, resp_data)

            # Check limit exceeded here.  This is a plain substring test on the
            # body; read the text once so the payload is not fetched twice.
            body_text = resp_data.text
            if "response" in body_text and "limit_exceeded" in body_text:
                self._limit_exceeded = True
                self._limit_exceeded_message = "limit exceeded"

            # The response is handed to the consumer before the limit error is
            # raised, so the caller still sees the payload that reported it.
            yield resp_data

            if self._limit_exceeded:
                # A separator is required here: without it the message read
                # "Limit Exceededlimit exceeded".
                raise ServiceException(503, "Limit Exceeded: {}".format(self._limit_exceeded_message))

            if not self.kwargs.get("sessionID"):
                # we'll only do iterative request for queries that has sessionID.
                # Otherwise, we will have an infinite request if sessionID was not provided
                # but the required data asked is more than the maximum (1 hour of data)
                break
102+
103+
def _get_session_params(self):
    """Build the query parameters and HTTP headers for a feeds request.

    Works on a deep copy of ``self.kwargs`` so the stored kwargs are never
    modified.

    Returns:
        dict: ``{"parameters": <query params>, "headers": <http headers>}``.
    """
    query_params = deepcopy(self.kwargs)
    # "format" is dropped alongside "output_format": when the request comes
    # from the CLI it gets populated even though it was never passed for the
    # feeds endpoint, so it has to be removed here.
    for unwanted_key in ("output_format", "format"):
        query_params.pop(unwanted_key, None)

    request_headers = {}
    requested_format = self.kwargs.get("output_format", OutputFormat.JSONL.value)
    if requested_format == OutputFormat.CSV.value:
        # CSV output: send the "headers" switch as 0/1 and ask for CSV content.
        query_params["headers"] = int(bool(self.kwargs.get("headers", False)))
        request_headers["accept"] = HEADER_ACCEPT_KEY_CSV_FORMAT

    # The API key, when supplied, is sent as a header rather than a query parameter.
    api_key = query_params.pop("X-Api-Key", None)
    if api_key:
        request_headers["X-Api-Key"] = api_key

    return {"parameters": query_params, "headers": request_headers}
119+
80120
def _make_request(self):
121+
if self.product in FEEDS_PRODUCTS_LIST:
122+
session_params = self._get_session_params()
123+
parameters = session_params.get("parameters")
124+
headers = session_params.get("headers")
125+
126+
return self._get_feeds_results_generator(parameters=parameters, headers=headers)
81127

82128
with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
83129
if self.product in [
@@ -92,30 +138,16 @@ def _make_request(self):
92138
patch_data = self.kwargs.copy()
93139
patch_data.update(self.api.extra_request_params)
94140
return session.patch(url=self.url, json=patch_data)
95-
elif self.product in FEEDS_PRODUCTS_LIST:
96-
parameters = deepcopy(self.kwargs)
97-
parameters.pop("output_format", None)
98-
parameters.pop(
99-
"format", None
100-
) # For some unknownn reasons, even if "format" is not included in the cli params for feeds endpoint, it is being populated thus we need to remove it. Happens only if using CLI.
101-
headers = {}
102-
if self.kwargs.get("output_format", OutputFormat.JSONL.value) == OutputFormat.CSV.value:
103-
parameters["headers"] = int(bool(self.kwargs.get("headers", False)))
104-
headers["accept"] = HEADER_ACCEPT_KEY_CSV_FORMAT
105-
106-
header_api_key = parameters.pop("X-Api-Key", None)
107-
if header_api_key:
108-
headers["X-Api-Key"] = header_api_key
109-
110-
return session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
111141
else:
112142
return session.get(url=self.url, params=self.kwargs, **self.api.extra_request_params)
113143

114144
def _get_results(self):
115145
wait_for = self._wait_time()
146+
print(f"JD: waitfor: {wait_for}")
116147
if self.api.rate_limit and (wait_for is None or self.product == "account-information"):
117148
data = self._make_request()
118-
if data.status_code == 503: # pragma: no cover
149+
status_code = data.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
150+
if status_code == 503: # pragma: no cover
119151
sleeptime = 60
120152
log.info(
121153
"503 encountered for [%s] - sleeping [%s] seconds before retrying request.",
@@ -129,18 +161,22 @@ def _get_results(self):
129161

130162
if wait_for > 0:
131163
log.info("Sleeping for [%s] prior to requesting [%s].", wait_for, self.product)
164+
print("Sleeping for [%s] prior to requesting [%s].", wait_for, self.product)
132165
time.sleep(wait_for)
133166
return self._make_request()
134167

135168
def data(self):
136169
if self._data is None:
137170
results = self._get_results()
138-
self.setStatus(results.status_code, results)
171+
status_code = results.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
172+
self.setStatus(status_code, results)
139173
if (
140174
self.kwargs.get("format", "json") == "json"
141175
and self.product not in FEEDS_PRODUCTS_LIST # Special handling of feeds products' data to preserve the result in jsonline format
142176
):
143177
self._data = results.json()
178+
elif self.product in FEEDS_PRODUCTS_LIST:
179+
self._data = results # Uses generator to handle large data results from feeds endpoint
144180
else:
145181
self._data = results.text
146182
limit_exceeded, message = self.check_limit_exceeded()
@@ -155,6 +191,10 @@ def data(self):
155191
return self._data
156192

157193
def check_limit_exceeded(self):
194+
if self.product in FEEDS_PRODUCTS_LIST:
195+
# bypass here as this is handled in generator already
196+
return False, ""
197+
158198
if self.kwargs.get("format", "json") == "json" and self.product not in FEEDS_PRODUCTS_LIST:
159199
if "response" in self._data and "limit_exceeded" in self._data["response"] and self._data["response"]["limit_exceeded"] is True:
160200
return True, self._data["response"]["message"]
@@ -172,7 +212,7 @@ def status(self):
172212

173213
def setStatus(self, code, response=None):
174214
self._status = code
175-
if code == 200:
215+
if code == 200 or code == 206:
176216
return
177217

178218
reason = None
@@ -211,7 +251,7 @@ def response(self):
211251
return self._response
212252

213253
def items(self):
    """Return the response's items.

    Returns:
        ``response.items()`` when ``self.response()`` is a dict; otherwise the
        response object itself, unchanged.
    """
    # Call response() once and reuse the result instead of evaluating it for
    # the type check and again for the return value.
    response = self.response()
    return response.items() if isinstance(response, dict) else response
215255

216256
def emails(self):
217257
"""Find and returns all emails mentioned in the response"""

domaintools/cli/api.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from rich.progress import Progress, SpinnerColumn, TextColumn
1010

1111
from domaintools.api import API
12-
from domaintools.constants import Endpoint, OutputFormat
12+
from domaintools.constants import Endpoint, OutputFormat, FEEDS_PRODUCTS_LIST
1313
from domaintools.cli.utils import get_file_extension
1414
from domaintools.exceptions import ServiceException
1515
from domaintools._version import current as version
@@ -110,6 +110,9 @@ def args_to_dict(*args) -> Dict:
110110
def _get_formatted_output(cls, cmd_name: str, response, out_format: str = "json"):
    """Render a command's response as the text the CLI will emit.

    Args:
        cmd_name: Name of the CLI command that produced ``response``.
        response: The command result; a plain iterable of strings for
            ``available_api_calls``, otherwise a results object.
        out_format: Attribute of ``response`` to render, or ``"list"`` to use
            ``response.as_list()``.

    Returns:
        str: The formatted output.
    """
    if cmd_name in ("available_api_calls",):
        return "\n".join(response)

    if response.product in FEEDS_PRODUCTS_LIST:
        # Iterating a feeds response yields objects exposing ``.text``; join
        # those texts with newlines.
        return "\n".join([chunk.text for chunk in response])

    if out_format == "list":
        rendered = response.as_list()
    else:
        rendered = getattr(response, out_format)
    return str(rendered)
114117

115118
@classmethod
@@ -227,7 +230,10 @@ def run(cls, name: str, params: Optional[Dict] = {}, **kwargs):
227230

228231
if isinstance(out_file, _io.TextIOWrapper):
229232
# use rich `print` command to prettify the ouput in sys.stdout
230-
print(response)
233+
if response.product in FEEDS_PRODUCTS_LIST:
234+
print(output)
235+
else:
236+
print(response)
231237
else:
232238
# if it's a file then write
233239
out_file.write(output if output.endswith("\n") else output + "\n")

domaintools_async/__init__.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
import asyncio
44

5-
from copy import deepcopy
65
from httpx import AsyncClient
76

87
from domaintools.base_results import Results
9-
from domaintools.constants import FEEDS_PRODUCTS_LIST, OutputFormat, HEADER_ACCEPT_KEY_CSV_FORMAT
10-
from domaintools.exceptions import ServiceUnavailableException
8+
from domaintools.constants import FEEDS_PRODUCTS_LIST
9+
from domaintools.exceptions import ServiceUnavailableException, ServiceException
1110

1211

1312
class _AIter(object):
@@ -42,6 +41,26 @@ class AsyncResults(Results):
4241
def __await__(self):
4342
return self.__awaitable__().__await__()
4443

44+
async def _get_feeds_async_results_generator(self, session, parameters, headers):
    """Asynchronously yield every HTTP response produced by a feeds request.

    The same GET request is re-issued on ``session`` until a response with
    status 200 is received, and each response object is yielded as it arrives.
    Being an async generator, it must be consumed with ``async for``.

    Args:
        session: Async HTTP client whose ``get`` coroutine performs the request.
        parameters: Query-string parameters forwarded to ``session.get``.
        headers: HTTP headers forwarded to ``session.get``.

    Yields:
        The response object returned by ``session.get`` for each request.

    Raises:
        ServiceException: With code 503 once a response whose text mentions
            both "response" and "limit_exceeded" has been yielded.
    """
    status_code = None
    while status_code != 200:
        resp_data = await session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
        status_code = resp_data.status_code
        # NOTE(review): setStatus presumably raises for non-2xx codes, which is
        # what ends this loop on an error response - confirm against setStatus.
        self.setStatus(status_code, resp_data)

        # Check limit exceeded here.  This is a plain substring test on the
        # body; read the text once so the payload is not fetched twice.
        body_text = resp_data.text
        if "response" in body_text and "limit_exceeded" in body_text:
            self._limit_exceeded = True
            self._limit_exceeded_message = "limit exceeded"

        # The response is handed to the consumer before the limit error is
        # raised, so the caller still sees the payload that reported it.
        yield resp_data

        if self._limit_exceeded:
            # A separator is required here: without it the message read
            # "Limit Exceededlimit exceeded".
            raise ServiceException(503, "Limit Exceeded: {}".format(self._limit_exceeded_message))

        if not self.kwargs.get("sessionID"):
            # we'll only do iterative request for queries that has sessionID.
            # Otherwise, we will have an infinite request if sessionID was not provided
            # but the required data asked is more than the maximum (1 hour of data)
            break
63+
4564
async def _make_async_request(self, session):
4665
if self.product in ["iris-investigate", "iris-enrich", "iris-detect-escalate-domains"]:
4766
post_data = self.kwargs.copy()
@@ -52,27 +71,19 @@ async def _make_async_request(self, session):
5271
patch_data.update(self.api.extra_request_params)
5372
results = await session.patch(url=self.url, json=patch_data)
5473
elif self.product in FEEDS_PRODUCTS_LIST:
55-
parameters = deepcopy(self.kwargs)
56-
parameters.pop("output_format", None)
57-
parameters.pop(
58-
"format", None
59-
) # For some unknownn reasons, even if "format" is not included in the cli params for feeds endpoint, it is being populated thus we need to remove it. Happens only if using CLI.
60-
headers = {}
61-
if self.kwargs.get("output_format", OutputFormat.JSONL.value) == OutputFormat.CSV.value:
62-
parameters["headers"] = int(bool(self.kwargs.get("headers", False)))
63-
headers["accept"] = HEADER_ACCEPT_KEY_CSV_FORMAT
64-
65-
header_api_key = parameters.pop("X-Api-Key", None)
66-
if header_api_key:
67-
headers["X-Api-Key"] = header_api_key
68-
69-
results = await session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
74+
generator_params = self._get_session_params()
75+
parameters = generator_params.get("parameters")
76+
headers = generator_params.get("headers")
77+
results = await self._get_feeds_async_results_generator(session=session, parameters=parameters, headers=headers)
7078
else:
7179
results = await session.get(url=self.url, params=self.kwargs, **self.api.extra_request_params)
7280
if results:
73-
self.setStatus(results.status_code, results)
81+
status_code = results.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
82+
self.setStatus(status_code, results)
7483
if self.kwargs.get("format", "json") == "json":
7584
self._data = results.json()
85+
elif self.product in FEEDS_PRODUCTS_LIST:
86+
self._data = results # Uses generator to handle large data results from feeds endpoint
7687
else:
7788
self._data = results.text()
7889
limit_exceeded, message = self.check_limit_exceeded()
@@ -83,7 +94,6 @@ async def _make_async_request(self, session):
8394

8495
async def __awaitable__(self):
8596
if self._data is None:
86-
8797
async with AsyncClient(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
8898
wait_time = self._wait_time()
8999
if wait_time is None and self.api:

0 commit comments

Comments
 (0)