Skip to content

Commit f0ffd6b

Browse files
committed
Cleanup code and add improvements in the implementation.
1 parent 5c98576 commit f0ffd6b

File tree

3 files changed

+29
-34
lines changed

3 files changed

+29
-34
lines changed

domaintools/api.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,15 +147,12 @@ def _handle_api_key_parameters(self, is_rttf_product):
147147

148148
def handle_api_key(self, is_rttf_product, path, parameters):
149149
if self.https and not self.always_sign_api_key:
150-
if self.header_authentication:
151-
parameters["X-Api-Key"] = self.key
152-
else:
153-
parameters["api_key"] = self.key
150+
parameters["api_key"] = self.key
154151
else:
155152
if is_rttf_product:
156153
# As per requirement in IDEV-2272, raise this error when the user explicitly sets signing of API key for RTTF endpoints
157154
raise ValueError("Real Time Threat Feeds do not support signed API keys.")
158-
elif self.key_sign_hash and self.key_sign_hash in AVAILABLE_KEY_SIGN_HASHES:
155+
if self.key_sign_hash and self.key_sign_hash in AVAILABLE_KEY_SIGN_HASHES:
159156
signing_hash = eval(self.key_sign_hash)
160157
else:
161158
raise ValueError(

domaintools/base_results.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -75,45 +75,45 @@ def _wait_time(self):
7575

7676
return wait_for
7777

78-
def _get_session_params(self):
79-
parameters = deepcopy(self.kwargs)
80-
parameters.pop("output_format", None)
81-
parameters.pop(
82-
"format", None
83-
) # For some unknownn reasons, even if "format" is not included in the cli params for feeds endpoint, it is being populated thus we need to remove it. Happens only if using CLI.
78+
def _get_session_params_and_headers_and_headers(self):
8479
headers = {}
85-
if self.kwargs.get("output_format", OutputFormat.JSONL.value) == OutputFormat.CSV.value:
86-
parameters["headers"] = int(bool(self.kwargs.get("headers", False)))
87-
headers["accept"] = HEADER_ACCEPT_KEY_CSV_FORMAT
88-
89-
header_api_key = parameters.pop("X-Api-Key", None)
90-
if header_api_key:
91-
headers["X-Api-Key"] = header_api_key
80+
parameters = deepcopy(self.kwargs)
81+
is_rttf_product = self.product in RTTF_PRODUCTS_LIST
82+
if is_rttf_product:
83+
parameters.pop("output_format", None)
84+
parameters.pop(
85+
"format", None
86+
) # For some unknownn reasons, even if "format" is not included in the cli params for feeds endpoint, it is being populated thus we need to remove it. Happens only if using CLI.
87+
if self.kwargs.get("output_format", OutputFormat.JSONL.value) == OutputFormat.CSV.value:
88+
parameters["headers"] = int(bool(self.kwargs.get("headers", False)))
89+
headers["accept"] = HEADER_ACCEPT_KEY_CSV_FORMAT
90+
91+
if self.api.header_authentication:
92+
header_key_for_api_key = "X-Api-Key" if is_rttf_product else "X-API-Key"
93+
headers[header_key_for_api_key] = parameters.pop("api_key", None)
9294

9395
return {"parameters": parameters, "headers": headers}
9496

9597
def _make_request(self):
9698

9799
with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
100+
session_params_and_headers = self._get_session_params_and_headers_and_headers()
101+
headers = session_params_and_headers.get("headers")
98102
if self.product in [
99103
"iris-investigate",
100104
"iris-enrich",
101105
"iris-detect-escalate-domains",
102106
]:
103107
post_data = self.kwargs.copy()
104108
post_data.update(self.api.extra_request_params)
105-
return session.post(url=self.url, data=post_data)
109+
return session.post(url=self.url, data=post_data, headers=headers)
106110
elif self.product in ["iris-detect-manage-watchlist-domains"]:
107111
patch_data = self.kwargs.copy()
108112
patch_data.update(self.api.extra_request_params)
109-
return session.patch(url=self.url, json=patch_data)
110-
elif self.product in RTTF_PRODUCTS_LIST:
111-
session_params = self._get_session_params()
112-
parameters = session_params.get("parameters")
113-
headers = session_params.get("headers")
114-
return session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
113+
return session.patch(url=self.url, json=patch_data, headers=headers)
115114
else:
116-
return session.get(url=self.url, params=self.kwargs, **self.api.extra_request_params)
115+
parameters = session_params_and_headers.get("parameters")
116+
return session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
117117

118118
def _get_results(self):
119119
wait_for = self._wait_time()

domaintools_async/__init__.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,19 @@ def __await__(self):
4343
return self.__awaitable__().__await__()
4444

4545
async def _make_async_request(self, session):
46+
session_params_and_headers = self._get_session_params_and_headers_and_headers()
47+
headers = session_params_and_headers.get("headers")
4648
if self.product in ["iris-investigate", "iris-enrich", "iris-detect-escalate-domains"]:
4749
post_data = self.kwargs.copy()
4850
post_data.update(self.api.extra_request_params)
49-
results = await session.post(url=self.url, data=post_data)
51+
results = await session.post(url=self.url, data=post_data, headers=headers)
5052
elif self.product in ["iris-detect-manage-watchlist-domains"]:
5153
patch_data = self.kwargs.copy()
52-
patch_data.update(self.api.extra_request_params)
54+
patch_data.update(self.api.extra_request_params, headers=headers)
5355
results = await session.patch(url=self.url, json=patch_data)
54-
elif self.product in RTTF_PRODUCTS_LIST:
55-
session_params = self._get_session_params()
56-
parameters = session_params.get("parameters")
57-
headers = session_params.get("headers")
58-
results = await session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
5956
else:
60-
results = await session.get(url=self.url, params=self.kwargs, **self.api.extra_request_params)
57+
parameters = session_params_and_headers.get("parameters")
58+
results = await session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
6159
if results:
6260
self.setStatus(results.status_code, results)
6361
if self.kwargs.get("format", "json") == "json":

0 commit comments

Comments
 (0)