Skip to content

Commit 957904d

Browse files
committed
#437 - Review task endpoints
1 parent c405d7c commit 957904d

File tree

2 files changed

+166
-13
lines changed

2 files changed

+166
-13
lines changed

tests/test_task_endpoint.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,22 @@ def test_set_as_required_and_done(
6666
actions = thehive.task.get_required_actions(task_id=test_task["_id"])
6767
assert actions[organisation] is False
6868

69-
@pytest.mark.skip(
70-
reason="documentation is unclear and implementation might be changed"
71-
)
72-
def test_share_and_unshare(self, thehive: TheHiveApi, test_task: OutputTask):
73-
pass
69+
def test_share_and_unshare(
70+
self, thehive: TheHiveApi, test_task: OutputTask, test_config: TestConfig
71+
):
72+
73+
thehive.task.share(
74+
task_id=test_task["_id"], organisations=[test_config.main_org]
75+
)
76+
77+
# TODO: test `unshare` once a second organisation is allowed by the license
78+
# thehive.task.unshare(
79+
# task_id=test_task["_id"], organisations=[test_config.main_org]
80+
# )
81+
82+
# TODO: test `list_shares` better once a second organisation is
83+
# allowed by the license
84+
assert len(thehive.task.list_shares(task_id=test_task["_id"])) == 0
7485

7586
def test_find_and_count(self, thehive: TheHiveApi, test_tasks: List[OutputTask]):
7687
filters = Eq("title", test_tasks[0]["title"]) | Eq(

thehive4py/endpoints/task.py

Lines changed: 150 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from typing import List, Optional
1+
from typing import Dict, List, Optional
22

33
from thehive4py.endpoints._base import EndpointBase
44
from thehive4py.query import QueryExpr
55
from thehive4py.query.filters import FilterExpr
66
from thehive4py.query.page import Paginate
77
from thehive4py.query.sort import SortExpr
8+
from thehive4py.types.share import OutputShare
89
from thehive4py.types.task import (
910
InputBulkUpdateTask,
1011
InputTask,
@@ -16,56 +17,170 @@
1617

1718
class TaskEndpoint(EndpointBase):
1819
def create(self, case_id: str, task: InputTask) -> OutputTask:
20+
"""Create a task.
21+
22+
Args:
23+
case_id: The id of the case to create the task for.
24+
task: The body of the task.
25+
26+
Returns:
27+
The created task.
28+
"""
1929
return self._session.make_request(
2030
"POST", path=f"/api/v1/case/{case_id}/task", json=task
2131
)
2232

2333
def get(self, task_id: str) -> OutputTask:
34+
"""Get a task by id.
35+
36+
Args:
37+
task_id: The id of the task.
38+
39+
Returns:
40+
The task specified by the id.
41+
"""
2442
return self._session.make_request("GET", path=f"/api/v1/task/{task_id}")
2543

2644
def delete(self, task_id: str) -> None:
45+
"""Delete a task.
46+
47+
Args:
48+
task_id: The id of the task.
49+
50+
Returns:
51+
N/A
52+
"""
2753
return self._session.make_request("DELETE", path=f"/api/v1/task/{task_id}")
2854

2955
def update(self, task_id: str, fields: InputUpdateTask) -> None:
56+
"""Update a task.
57+
58+
Args:
59+
task_id: The id of the task.
60+
fields: The fields of the task to update.
61+
62+
Returns:
63+
N/A
64+
"""
3065
return self._session.make_request(
3166
"PATCH", path=f"/api/v1/task/{task_id}", json=fields
3267
)
3368

3469
def bulk_update(self, fields: InputBulkUpdateTask) -> None:
70+
"""Update multiple tasks with the same values.
71+
72+
Args:
73+
fields: The ids and the fields of the tasks to update.
74+
75+
Returns:
76+
N/A
77+
"""
3578
return self._session.make_request(
3679
"PATCH", path="/api/v1/task/_bulk", json=fields
3780
)
3881

39-
def get_required_actions(self, task_id: str) -> dict:
82+
def get_required_actions(self, task_id: str) -> Dict[str, bool]:
83+
"""Get the required actions per organization for a specific task.
84+
85+
Args:
86+
task_id: The id of the task.
87+
88+
Returns:
89+
A dictionary where the keys are organization ids and the values are
90+
booleans indicating whether the task is required for that organization.
91+
"""
4092
return self._session.make_request(
4193
"GET", path=f"/api/v1/task/{task_id}/actionRequired"
4294
)
4395

4496
def set_as_required(self, task_id: str, org_id: str) -> None:
97+
"""Set a task as required.
98+
99+
Args:
100+
task_id: The id of the task.
101+
org_id: The id of the organization where the task is required.
102+
103+
Returns:
104+
N/A
105+
"""
45106
return self._session.make_request(
46107
"PUT", f"/api/v1/task/{task_id}/actionRequired/{org_id}"
47108
)
48109

49110
def set_as_done(self, task_id: str, org_id: str) -> None:
111+
"""Set a task as done.
112+
Args:
113+
task_id: The id of the task.
114+
org_id: The id of the organization where the task is done.
115+
Returns:
116+
N/A
117+
"""
50118
return self._session.make_request(
51119
"PUT", f"/api/v1/task/{task_id}/actionDone/{org_id}"
52120
)
53121

54-
def share(self):
55-
raise NotImplementedError()
122+
def list_shares(self, task_id: str) -> List[OutputShare]:
123+
"""List the shares of a task.
124+
125+
Args:
126+
task_id: The id of the task.
127+
128+
Returns:
129+
A list of shares associated with the task.
130+
"""
131+
return self._session.make_request("GET", f"/api/v1/task/{task_id}/shares")
132+
133+
def share(self, task_id: str, organisations: List[str]) -> None:
134+
"""Share the task with other organisations.
56135
57-
def list_shares(self):
58-
raise NotImplementedError()
136+
The case that owns the observable must already be shared with the
137+
target organisations.
59138
60-
def unshare(self):
61-
raise NotImplementedError()
139+
Args:
140+
task_id: The id of the task to share.
141+
organisations: The list of organisation ids or names.
142+
143+
Returns:
144+
N/A
145+
"""
146+
return self._session.make_request(
147+
"POST",
148+
f"/api/v1/task/{task_id}/shares",
149+
json={"organisations": organisations},
150+
)
151+
152+
def unshare(self, task_id: str, organisations: List[str]) -> None:
153+
"""Unshare the task with other organisations.
154+
155+
Args:
156+
task_id: The id of the task to unshare.
157+
organisations: The list of organisation ids or names.
158+
159+
Returns:
160+
N/A
161+
"""
162+
return self._session.make_request(
163+
"DELETE",
164+
f"/api/v1/task/{task_id}/shares",
165+
json={"organisations": organisations},
166+
)
62167

63168
def find(
64169
self,
65170
filters: Optional[FilterExpr] = None,
66171
sortby: Optional[SortExpr] = None,
67172
paginate: Optional[Paginate] = None,
68173
) -> List[OutputTask]:
174+
"""Find multiple tasks.
175+
176+
Args:
177+
filters: The filter expressions to apply in the query.
178+
sortby: The sort expressions to apply in the query.
179+
paginate: The pagination expression to apply in the query.
180+
181+
Returns:
182+
The list of tasks matched by the query or an empty list.
183+
"""
69184
query: QueryExpr = [
70185
{"_name": "listTask"},
71186
*self._build_subquery(filters=filters, sortby=sortby, paginate=paginate),
@@ -79,6 +194,14 @@ def find(
79194
)
80195

81196
def count(self, filters: Optional[FilterExpr] = None) -> int:
197+
"""Count tasks.
198+
199+
Args:
200+
filters: The filter expressions to apply in the query.
201+
202+
Returns:
203+
The count of tasks matched by the query.
204+
"""
82205
query: QueryExpr = [
83206
{"_name": "listTask"},
84207
*self._build_subquery(filters=filters),
@@ -93,6 +216,14 @@ def count(self, filters: Optional[FilterExpr] = None) -> int:
93216
)
94217

95218
def create_log(self, task_id: str, task_log: InputTaskLog) -> OutputTaskLog:
219+
"""Create a task log.
220+
221+
Args:
222+
task_id: The id of the task to create the log for.
223+
task_log: The body of the task log.
224+
Returns:
225+
The created task log.
226+
"""
96227
return self._session.make_request(
97228
"POST", path=f"/api/v1/task/{task_id}/log", json=task_log
98229
)
@@ -104,6 +235,17 @@ def find_logs(
104235
sortby: Optional[SortExpr] = None,
105236
paginate: Optional[Paginate] = None,
106237
) -> List[OutputTaskLog]:
238+
"""Find task logs.
239+
240+
Args:
241+
task_id: The id of the task to find logs for.
242+
filters: The filter expressions to apply in the query.
243+
sortby: The sort expressions to apply in the query.
244+
paginate: The pagination expression to apply in the query.
245+
246+
Returns:
247+
The list of task logs matched by the query or an empty list.
248+
"""
107249
query: QueryExpr = [
108250
{"_name": "getTask", "idOrName": task_id},
109251
{"_name": "logs"},

0 commit comments

Comments
 (0)