Skip to content

Commit d75c343

Browse files
committed
#442 - Review observable endpoints
1 parent 5127056 commit d75c343

File tree

5 files changed

+212
-50
lines changed

5 files changed

+212
-50
lines changed

tests/test_observable_endpoint.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from typing import List
44

55
import pytest
6+
7+
from tests.utils import TestConfig
68
from thehive4py import TheHiveApi
79
from thehive4py.errors import TheHiveError
810
from thehive4py.query.sort import Asc
@@ -16,7 +18,7 @@
1618

1719

1820
class TestObservableEndpoint:
19-
def test_create_in_alert_and_get(
21+
def test_create_single_in_alert_and_get(
2022
self, thehive: TheHiveApi, test_alert: OutputAlert
2123
):
2224
created_observable = thehive.observable.create_in_alert(
@@ -33,6 +35,25 @@ def test_create_in_alert_and_get(
3335
)
3436
assert created_observable == fetched_observable
3537

38+
def test_create_multiple_in_alert(
39+
self, thehive: TheHiveApi, test_alert: OutputAlert
40+
):
41+
42+
observable_count = 3
43+
created_observables = thehive.observable.create_in_alert(
44+
alert_id=test_alert["_id"],
45+
observable={
46+
"dataType": "domain",
47+
"data": [f"{i}.example.com" for i in range(observable_count)],
48+
"message": "test observable",
49+
},
50+
)
51+
52+
fetched_observables = thehive.alert.find_observables(alert_id=test_alert["_id"])
53+
54+
for created_observable in created_observables:
55+
assert created_observable in fetched_observables
56+
3657
def test_create_in_alert_from_file_and_download_as_zip(
3758
self, thehive: TheHiveApi, test_alert: OutputAlert, tmp_path: Path
3859
):
@@ -66,7 +87,9 @@ def test_create_in_alert_from_file_and_download_as_zip(
6687
with archive_fp.open(observable_filename, pwd=b"malware") as downloaded_fp:
6788
assert downloaded_fp.read().decode() == observable_content
6889

69-
def test_create_in_case_and_get(self, thehive: TheHiveApi, test_case: OutputCase):
90+
def test_create_single_in_case_and_get(
91+
self, thehive: TheHiveApi, test_case: OutputCase
92+
):
7093
created_observable = thehive.observable.create_in_case(
7194
case_id=test_case["_id"],
7295
observable={
@@ -81,6 +104,23 @@ def test_create_in_case_and_get(self, thehive: TheHiveApi, test_case: OutputCase
81104
)
82105
assert created_observable == fetched_observable
83106

107+
def test_create_multiple_in_case(self, thehive: TheHiveApi, test_case: OutputCase):
108+
109+
observable_count = 3
110+
created_observables = thehive.observable.create_in_case(
111+
case_id=test_case["_id"],
112+
observable={
113+
"dataType": "domain",
114+
"data": [f"{i}.example.com" for i in range(observable_count)],
115+
"message": "test observable",
116+
},
117+
)
118+
119+
fetched_observables = thehive.case.find_observables(case_id=test_case["_id"])
120+
121+
for created_observable in created_observables:
122+
assert created_observable in fetched_observables
123+
84124
def test_create_in_case_from_file_and_download_as_is(
85125
self, thehive: TheHiveApi, test_case: OutputCase, tmp_path: Path
86126
):
@@ -153,25 +193,25 @@ def test_bulk_update(
153193
for key, value in expected_fields.items():
154194
assert updated_task.get(key) == value
155195

156-
@pytest.mark.skip(
157-
reason="documentation is unclear and implementation might be changed"
158-
)
159196
def test_share_and_unshare(
160-
self, thehive: TheHiveApi, test_observable: OutputObservable
197+
self,
198+
thehive: TheHiveApi,
199+
test_observable: OutputObservable,
200+
test_config: TestConfig,
161201
):
162-
organisation = "share-org"
163202

164203
thehive.observable.share(
165-
observable_id=test_observable["_id"], organisations=[organisation]
166-
)
167-
assert (
168-
len(thehive.observable.list_shares(observable_id=test_observable["_id"]))
169-
== 1
204+
observable_id=test_observable["_id"], organisations=[test_config.main_org]
170205
)
171206

172-
thehive.observable.unshare(
173-
observable_id=test_observable["_id"], organisations=[organisation]
174-
)
207+
# TODO: test `unshare` once a second organisation is allowed by the license
208+
# thehive.observable.unshare(
209+
# observable_id=test_observable["_id"], organisations=[test_config.main_org]
210+
# )
211+
212+
# TODO: test `list_shares` better once a second organisation is
213+
# allowed by the license
214+
175215
assert (
176216
len(thehive.observable.list_shares(observable_id=test_observable["_id"]))
177217
== 0

thehive4py/endpoints/observable.py

Lines changed: 140 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,77 +15,198 @@
1515

1616

1717
class ObservableEndpoint(EndpointBase):
18-
def create_in_alert(
18+
def create_in_case(
1919
self,
20-
alert_id: str,
20+
case_id: str,
2121
observable: InputObservable,
2222
observable_path: Optional[str] = None,
2323
) -> List[OutputObservable]:
24+
"""Create one or more observables in a case.
25+
26+
Args:
27+
case_id: The id of the case.
28+
observable: The fields of the observable to create.
29+
observable_path: Optional path in case of file based observables.
30+
31+
Returns:
32+
The created case observables.
33+
"""
2434
kwargs = self._build_observable_kwargs(
2535
observable=observable, observable_path=observable_path
2636
)
2737
return self._session.make_request(
28-
"POST", path=f"/api/v1/alert/{alert_id}/observable", **kwargs
38+
"POST", path=f"/api/v1/case/{case_id}/observable", **kwargs
2939
)
3040

31-
def create_in_case(
41+
def create_in_alert(
3242
self,
33-
case_id: str,
43+
alert_id: str,
3444
observable: InputObservable,
3545
observable_path: Optional[str] = None,
3646
) -> List[OutputObservable]:
47+
"""Create one or more observables in an alert.
48+
49+
Args:
50+
alert_id: The id of the alert.
51+
observable: The fields of the observable to create.
52+
observable_path: Optional path in case of file based observables.
53+
54+
Returns:
55+
The created alert observables.
56+
"""
3757
kwargs = self._build_observable_kwargs(
3858
observable=observable, observable_path=observable_path
3959
)
4060
return self._session.make_request(
41-
"POST", path=f"/api/v1/case/{case_id}/observable", **kwargs
61+
"POST", path=f"/api/v1/alert/{alert_id}/observable", **kwargs
4262
)
4363

4464
def get(self, observable_id: str) -> OutputObservable:
65+
"""Get an observable by id.
66+
67+
Args:
68+
observable_id: The id of the observable.
69+
70+
Returns:
71+
The observable specified by the id.
72+
"""
4573
return self._session.make_request(
4674
"GET", path=f"/api/v1/observable/{observable_id}"
4775
)
4876

4977
def delete(self, observable_id: str) -> None:
78+
"""Delete an observable.
79+
80+
Args:
81+
observable_id: The id of the observable.
82+
83+
Returns:
84+
N/A
85+
"""
5086
return self._session.make_request(
5187
"DELETE", path=f"/api/v1/observable/{observable_id}"
5288
)
5389

5490
def update(self, observable_id: str, fields: InputUpdateObservable) -> None:
91+
"""Update an observable.
92+
93+
Args:
94+
observable_id: The id of the observable.
95+
fields: The fields of the observable to update.
96+
97+
Returns:
98+
N/A
99+
"""
55100
return self._session.make_request(
56101
"PATCH", path=f"/api/v1/observable/{observable_id}", json=fields
57102
)
58103

59104
def bulk_update(self, fields: InputBulkUpdateObservable) -> None:
105+
"""Update multiple observables with the same values.
106+
107+
Args:
108+
fields: The ids and the fields of the observables to update.
109+
110+
Returns:
111+
N/A
112+
"""
60113
return self._session.make_request(
61114
"PATCH", path="/api/v1/observable/_bulk", json=fields
62115
)
63116

117+
def download_attachment(
118+
self,
119+
observable_id: str,
120+
attachment_id: str,
121+
observable_path: str,
122+
as_zip: bool = False,
123+
) -> None:
124+
"""Download an observable attachment.
125+
126+
Args:
127+
observable_id: The id of the observable.
128+
attachment_id: The id of the observable attachment.
129+
observable_path: The local path to download the observable attachment to.
130+
as_zip: If `True`, the attachment will be sent as a zip file with a
131+
password. Default password is 'malware'
132+
133+
Returns:
134+
N/A
135+
"""
136+
return self._session.make_request(
137+
"GET",
138+
path=(
139+
f"/api/v1/observable/{observable_id}"
140+
f"/attachment/{attachment_id}/download"
141+
),
142+
params={"asZip": as_zip},
143+
download_path=observable_path,
144+
)
145+
146+
def list_shares(self, observable_id: str) -> List[OutputShare]:
147+
"""List all organisation shares of an observable.
148+
149+
Args:
150+
observable_id: The id of the observable.
151+
152+
Returns:
153+
The list of organisation shares of the observable.
154+
"""
155+
return self._session.make_request(
156+
"GET", path=f"/api/v1/case/{observable_id}/shares"
157+
)
158+
64159
def share(self, observable_id: str, organisations: List[str]) -> None:
160+
"""Share the observable with other organisations.
161+
162+
The case that owns the observable must already be shared with the target
163+
organisations.
164+
165+
Args:
166+
observable_id: The id of the observable.
167+
organisations: The list of organisation names or ids.
168+
169+
Returns:
170+
The list of organisation shares of the observable.
171+
"""
65172
return self._session.make_request(
66173
"POST",
67174
path=f"/api/v1/observable/{observable_id}/shares",
68175
json={"organisations": organisations},
69176
)
70177

71178
def unshare(self, observable_id: str, organisations: List[str]) -> None:
179+
"""Unshare an observable from other organisations.
180+
181+
Args:
182+
observable_id: The id of the observable.
183+
organisations: The list of organisation names or ids.
184+
185+
Returns:
186+
N/A
187+
"""
72188
return self._session.make_request(
73189
"DELETE",
74190
path=f"/api/v1/observable/{observable_id}/shares",
75191
json={"organisations": organisations},
76192
)
77193

78-
def list_shares(self, observable_id: str) -> List[OutputShare]:
79-
return self._session.make_request(
80-
"GET", path=f"/api/v1/case/{observable_id}/shares"
81-
)
82-
83194
def find(
84195
self,
85196
filters: Optional[FilterExpr] = None,
86197
sortby: Optional[SortExpr] = None,
87198
paginate: Optional[Paginate] = None,
88199
) -> List[OutputObservable]:
200+
"""Find multiple observables.
201+
202+
Args:
203+
filters: The filter expressions to apply in the query.
204+
sortby: The sort expressions to apply in the query.
205+
paginate: The pagination experssion to apply in the query.
206+
207+
Returns:
208+
The list of observables matched by the query or an empty list.
209+
"""
89210
query: QueryExpr = [
90211
{"_name": "listObservable"},
91212
*self._build_subquery(filters=filters, sortby=sortby, paginate=paginate),
@@ -99,6 +220,14 @@ def find(
99220
)
100221

101222
def count(self, filters: Optional[FilterExpr] = None) -> int:
223+
"""Count observables.
224+
225+
Args:
226+
filters: The filter expressions to apply in the query.
227+
228+
Returns:
229+
The count of observables matched by the query.
230+
"""
102231
query: QueryExpr = [
103232
{"_name": "listObservable"},
104233
*self._build_subquery(filters=filters),
@@ -111,20 +240,3 @@ def count(self, filters: Optional[FilterExpr] = None) -> int:
111240
params={"name": "observable.count"},
112241
json={"query": query},
113242
)
114-
115-
def download_attachment(
116-
self,
117-
observable_id: str,
118-
attachment_id: str,
119-
observable_path: str,
120-
as_zip=False,
121-
) -> None:
122-
return self._session.make_request(
123-
"GET",
124-
path=(
125-
f"/api/v1/observable/{observable_id}"
126-
f"/attachment/{attachment_id}/download"
127-
),
128-
params={"asZip": as_zip},
129-
download_path=observable_path,
130-
)

thehive4py/endpoints/user.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def delete(self, user_id: str, organisation: Optional[str] = None) -> None:
113113
114114
Args:
115115
user_id: The id of the user.
116-
organisation: The organisation from which to delete the user from. Optional.
116+
organisation: The organisation from which the user should be deleted.
117117
118118
Returns:
119119
N/A

0 commit comments

Comments
 (0)