Skip to content

Commit 37b77c5

Browse files
authored
[client] add basic injector contract client (OpenAEV-Platform/openaev#3334)
Signed-off-by: Antoine MAZEAS <[email protected]>
1 parent a98717a commit 37b77c5

File tree

10 files changed

+273
-23
lines changed

10 files changed

+273
-23
lines changed

pyobas/apis/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .inject_expectation import * # noqa: F401,F403
88
from .inject_expectation_trace import * # noqa: F401,F403
99
from .injector import * # noqa: F401,F403
10+
from .injector_contract import * # noqa: F401,F403
1011
from .kill_chain_phase import * # noqa: F401,F403
1112
from .me import * # noqa: F401,F403
1213
from .organization import * # noqa: F401,F403

pyobas/apis/injector_contract.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from typing import Any, Dict
2+
3+
from pyobas import exceptions as exc
4+
from pyobas.apis.inputs.search import InjectorContractSearchPaginationInput
5+
from pyobas.base import RESTManager, RESTObject
6+
from pyobas.mixins import CreateMixin, DeleteMixin, UpdateMixin
7+
from pyobas.utils import RequiredOptional
8+
9+
10+
class InjectorContract(RESTObject):
11+
pass
12+
13+
14+
class InjectorContractManager(CreateMixin, UpdateMixin, DeleteMixin, RESTManager):
15+
_path = "/injector_contracts"
16+
_obj_cls = InjectorContract
17+
_create_attrs = RequiredOptional(
18+
required=(
19+
"contract_content",
20+
"contract_id",
21+
"contract_labels",
22+
"injector_id",
23+
),
24+
optional=(
25+
"contract_attack_patterns_ids",
26+
"contract_attack_patterns_external_ids",
27+
"contract_vulnerability_external_ids",
28+
"contract_manual",
29+
"contract_platforms",
30+
"external_contract_id",
31+
"is_atomic_testing",
32+
),
33+
)
34+
_update_attrs = RequiredOptional(
35+
required=(
36+
"contract_content",
37+
"contract_labels",
38+
),
39+
optional=(
40+
"contract_attack_patterns_ids",
41+
"contract_vulnerability_ids",
42+
"contract_vulnerability_external_ids",
43+
"contract_manual",
44+
"contract_platforms",
45+
"is_atomic_testing",
46+
),
47+
)
48+
49+
@exc.on_http_error(exc.OpenBASUpdateError)
50+
def search(
51+
self, input: InjectorContractSearchPaginationInput, **kwargs: Any
52+
) -> Dict[str, Any]:
53+
path = f"{self.path}/search"
54+
# force the serialisation here since we only need a naive serialisation to json
55+
result = self.openbas.http_post(path, post_data=input.to_dict(), **kwargs)
56+
return result

pyobas/apis/inputs/__init__.py

Whitespace-only changes.

pyobas/apis/inputs/search.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from typing import Any, Dict, List
2+
3+
4+
class Filter:
5+
def __init__(self, key: str, mode: str, operator: str, values: List[str]):
6+
self.key = key
7+
self.mode = mode
8+
self.operator = operator
9+
self.values = values
10+
11+
def to_dict(self) -> dict[str, Any]:
12+
return {k: v for k, v in self.__dict__.items() if v is not None}
13+
14+
15+
class FilterGroup:
16+
def __init__(self, mode: str, filters: List[Filter]):
17+
self.mode = mode
18+
self.filters = filters
19+
20+
def to_dict(self) -> dict[str, Any]:
21+
dictionary: dict[str, Any] = {"mode": self.mode}
22+
if self.filters:
23+
filter_dicts: List[dict[str, Any]] = []
24+
for filter_ in self.filters:
25+
filter_dicts.append(filter_.to_dict())
26+
dictionary["filters"] = filter_dicts
27+
return dictionary
28+
29+
30+
class SearchPaginationInput:
31+
def __init__(
32+
self,
33+
page: int,
34+
size: int,
35+
filter_group: FilterGroup,
36+
text_search: str,
37+
sorts: Dict[str, str],
38+
):
39+
self.size = size
40+
self.page = page
41+
self.filterGroup = filter_group
42+
self.text_search = text_search
43+
self.sorts = sorts
44+
45+
def to_dict(self) -> dict[str, Any]:
46+
dictionary: dict[str, Any] = {"page": self.page, "size": self.size}
47+
if self.sorts:
48+
dictionary["sorts"] = self.sorts
49+
if self.text_search:
50+
dictionary["textSearch"] = self.text_search
51+
if self.filterGroup:
52+
dictionary["filterGroup"] = self.filterGroup.to_dict()
53+
return dictionary
54+
55+
56+
class InjectorContractSearchPaginationInput(SearchPaginationInput):
57+
def __init__(
58+
self,
59+
page: int,
60+
size: int,
61+
filter_group: FilterGroup,
62+
text_search: str = None,
63+
sorts: Dict[str, str] = None,
64+
include_full_details: bool = True,
65+
):
66+
super().__init__(page, size, filter_group, text_search, sorts)
67+
self.include_full_details = include_full_details
68+
69+
def to_dict(self) -> dict[str, Any]:
70+
dictionary: dict[str, Any] = super().to_dict()
71+
dictionary["include_full_details"] = self.include_full_details
72+
return dictionary

pyobas/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
6161
self.collector = apis.CollectorManager(self)
6262
self.cve = apis.CveManager(self)
6363
self.inject = apis.InjectManager(self)
64+
self.injector_contract = apis.InjectorContractManager(self)
6465
self.document = apis.DocumentManager(self)
6566
self.kill_chain_phase = apis.KillChainPhaseManager(self)
6667
self.attack_pattern = apis.AttackPatternManager(self)

pyobas/contracts/contract_config.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,45 @@ class Contract:
132132
+ VariableHelper.uri_variables()
133133
)
134134
contract_attack_patterns_external_ids: List[str] = field(default_factory=list)
135+
contract_vulnerability_external_ids: List[str] = field(default_factory=list)
135136
is_atomic_testing: bool = True
136137
platforms: List[str] = field(default_factory=list)
138+
external_id: str = None
137139

138140
def add_attack_pattern(self, var: str):
139141
self.contract_attack_patterns_external_ids.append(var)
140142

143+
def add_vulnerability(self, var: str):
144+
self.contract_vulnerability_external_ids.append(var)
145+
141146
def add_variable(self, var: ContractVariable):
142147
self.variables.append(var)
143148

149+
def to_contract_add_input(self, source_id: str):
150+
return {
151+
"contract_id": self.contract_id,
152+
"external_contract_id": self.external_id,
153+
"injector_id": source_id,
154+
"contract_manual": self.manual,
155+
"contract_labels": self.label,
156+
"contract_attack_patterns_external_ids": self.contract_attack_patterns_external_ids,
157+
"contract_vulnerability_external_ids": self.contract_vulnerability_external_ids,
158+
"contract_content": json.dumps(self, cls=utils.EnhancedJSONEncoder),
159+
"is_atomic_testing": self.is_atomic_testing,
160+
"contract_platforms": self.platforms,
161+
}
162+
163+
def to_contract_update_input(self):
164+
return {
165+
"contract_manual": self.manual,
166+
"contract_labels": self.label,
167+
"contract_attack_patterns_external_ids": self.contract_attack_patterns_external_ids,
168+
"contract_vulnerability_external_ids": self.contract_vulnerability_external_ids,
169+
"contract_content": json.dumps(self, cls=utils.EnhancedJSONEncoder),
170+
"is_atomic_testing": self.is_atomic_testing,
171+
"contract_platforms": self.platforms,
172+
}
173+
144174

145175
@dataclass
146176
class ContractTeam(ContractCardinalityElement):

pyobas/mixins.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,27 @@ def create(
216216
assert not isinstance(server_data, requests.Response)
217217
assert self._obj_cls is not None
218218
return self._obj_cls(self, server_data)
219+
220+
221+
class DeleteMixin(_RestManagerBase):
222+
_computed_path: Optional[str]
223+
_from_parent_attrs: Dict[str, Any]
224+
_obj_cls: Optional[Type[base.RESTObject]]
225+
_parent: Optional[base.RESTObject]
226+
_parent_attrs: Dict[str, Any]
227+
_path: Optional[str]
228+
openbas: pyobas.OpenBAS
229+
230+
@exc.on_http_error(exc.OpenBASCreateError)
231+
def delete(
232+
self, id: Optional[Union[str, int]] = None, **kwargs: Any
233+
) -> requests.Response:
234+
if id is None:
235+
path = self.path
236+
else:
237+
path = f"{self.path}/{utils.EncodedId(id)}"
238+
239+
result = self.openbas.http_delete(path, **kwargs)
240+
if TYPE_CHECKING:
241+
assert isinstance(result, requests.Response)
242+
return result

pyobas/utils.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,7 @@ def add_fields(self, log_record, record, message_dict):
129129
log_record["level"] = record.levelname
130130

131131

132-
def logger(level, json_logging=True):
133-
# Exceptions
132+
def setup_logging_config(level, json_logging=True):
134133
logging.getLogger("urllib3").setLevel(logging.WARNING)
135134
logging.getLogger("pika").setLevel(logging.ERROR)
136135
# Exceptions
@@ -143,34 +142,43 @@ def logger(level, json_logging=True):
143142
else:
144143
logging.basicConfig(level=level)
145144

146-
class AppLogger:
147-
def __init__(self, name):
148-
self.local_logger = logging.getLogger(name)
149145

150-
@staticmethod
151-
def prepare_meta(meta=None):
152-
return None if meta is None else {"attributes": meta}
146+
class AppLogger:
147+
def __init__(self, level, json_logging=True, name: str = __name__):
148+
self.log_level = level
149+
self.json_logging = json_logging
150+
setup_logging_config(self.log_level, self.json_logging)
151+
self.local_logger = logging.getLogger(name)
152+
153+
def __call__(self, name):
154+
self.local_logger = logging.getLogger(name)
155+
return self
153156

154-
@staticmethod
155-
def setup_logger_level(lib, log_level):
156-
logging.getLogger(lib).setLevel(log_level)
157+
@staticmethod
158+
def prepare_meta(meta=None):
159+
return None if meta is None else {"attributes": meta}
157160

158-
def debug(self, message, meta=None):
159-
self.local_logger.debug(message, extra=AppLogger.prepare_meta(meta))
161+
@staticmethod
162+
def setup_logger_level(lib, log_level):
163+
logging.getLogger(lib).setLevel(log_level)
160164

161-
def info(self, message, meta=None):
162-
self.local_logger.info(message, extra=AppLogger.prepare_meta(meta))
165+
def debug(self, message, meta=None):
166+
self.local_logger.debug(message, extra=AppLogger.prepare_meta(meta))
163167

164-
def warning(self, message, meta=None):
165-
self.local_logger.warning(message, extra=AppLogger.prepare_meta(meta))
168+
def info(self, message, meta=None):
169+
self.local_logger.info(message, extra=AppLogger.prepare_meta(meta))
166170

167-
def error(self, message, meta=None):
168-
# noinspection PyTypeChecker
169-
self.local_logger.error(
170-
message, exc_info=1, extra=AppLogger.prepare_meta(meta)
171-
)
171+
def warning(self, message, meta=None):
172+
self.local_logger.warning(message, extra=AppLogger.prepare_meta(meta))
172173

173-
return AppLogger
174+
def error(self, message, meta=None):
175+
# noinspection PyTypeChecker
176+
self.local_logger.error(message, exc_info=1, extra=AppLogger.prepare_meta(meta))
177+
178+
179+
# DEPRECATED: compatibility
180+
def logger(level, json_logging=True):
181+
return AppLogger(level, json_logging)
174182

175183

176184
class PingAlive(threading.Thread):

test/apis/injector_contract/__init__.py

Whitespace-only changes.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from unittest import TestCase, main, mock
2+
from unittest.mock import ANY
3+
4+
from pyobas import OpenBAS
5+
from pyobas.apis.inputs.search import (
6+
Filter,
7+
FilterGroup,
8+
InjectorContractSearchPaginationInput,
9+
)
10+
11+
12+
def mock_response(**kwargs):
13+
class MockResponse:
14+
def __init__(self, json_data, status_code):
15+
self.json_data = json_data
16+
self.status_code = status_code
17+
self.history = None
18+
self.content = None
19+
self.headers = {"Content-Type": "application/json"}
20+
21+
def json(self):
22+
return self.json_data
23+
24+
return MockResponse(None, 200)
25+
26+
27+
class TestInjectorContract(TestCase):
28+
@mock.patch("requests.Session.request", side_effect=mock_response)
29+
def test_search_input_correctly_serialised(self, mock_request):
30+
api_client = OpenBAS("url", "token")
31+
32+
search_input = InjectorContractSearchPaginationInput(
33+
0,
34+
20,
35+
FilterGroup("or", [Filter("prop", "and", "eq", ["titi", "toto"])]),
36+
None,
37+
None,
38+
)
39+
40+
expected_json = search_input.to_dict()
41+
api_client.injector_contract.search(search_input)
42+
43+
mock_request.assert_called_once_with(
44+
method="post",
45+
url="url/api/injector_contracts/search",
46+
params={},
47+
data=None,
48+
timeout=None,
49+
stream=False,
50+
verify=True,
51+
json=expected_json,
52+
headers=ANY,
53+
auth=ANY,
54+
)
55+
56+
57+
if __name__ == "__main__":
58+
main()

0 commit comments

Comments
 (0)