diff --git a/roles/importer/files/importer/azure2022ff/azure_getter.py b/roles/importer/files/importer/azure2022ff/azure_getter.py index 861d99847b..c623d024ec 100644 --- a/roles/importer/files/importer/azure2022ff/azure_getter.py +++ b/roles/importer/files/importer/azure2022ff/azure_getter.py @@ -1,14 +1,13 @@ # library for API get functions import base64 +from typing import Any from fwo_log import getFwoLogger -import requests.packages import requests import json import fwo_globals from fwo_exceptions import FwLoginFailed - -def api_call(url, params = {}, headers = {}, data = {}, azure_jwt = '', show_progress=False, method='get'): +def api_call(url: str, params: dict[str, Any] = {}, headers: dict[str, Any] = {}, data: dict[str, Any] | str = {}, azure_jwt: str = '', show_progress: bool = False, method: str = 'get') -> tuple[dict[str, Any], dict[str, Any]]: logger = getFwoLogger() request_headers = {} if not 'Content-Type' in headers: @@ -31,13 +30,6 @@ def api_call(url, params = {}, headers = {}, data = {}, azure_jwt = '', show_pro # error handling: exception_text = '' - if response is None: - if 'password' in json.dumps(data): - exception_text = "error while sending api_call containing credential information to url '" + \ - str(url) - else: - exception_text = "error while sending api_call to url '" + str(url) + "' with payload '" + json.dumps( - data, indent=2) + "' and headers: '" + json.dumps(request_headers, indent=2) if not response.ok: exception_text = 'error code: {error_code}, error={error}'.format(error_code=response.status_code, error=response.content) #logger.error(response.content) @@ -58,13 +50,13 @@ def api_call(url, params = {}, headers = {}, data = {}, azure_jwt = '', show_pro logger.debug("api_call to url '" + str(url) + "' with payload '" + json.dumps( data, indent=2) + "' and headers: '" + json.dumps(request_headers, indent=2)) - return response.headers, body_json + return dict(response.headers), body_json -def login(azure_user, azure_password, tenant_id, client_id, client_secret): +def login(azure_user: str, azure_password: str, tenant_id: str, client_id: str, client_secret: str) -> str | None: base_url = 'https://login.microsoftonline.com/{tenant_id}/oauth2/token'.format(tenant_id=tenant_id) try: - headers, body = api_call(base_url, method="post", + _, body = api_call(base_url, method="post", headers={'Content-Type': 'application/x-www-form-urlencoded'}, data={ "grant_type" : "client_credentials", @@ -78,7 +70,7 @@ def login(azure_user, azure_password, tenant_id, client_id, client_secret): raise FwLoginFailed("Azure login ERROR for client_id id=" + str(client_id) + " Message: " + str(e)) from None if body.get("access_token") == None: # leaving out payload as it contains pwd - raise FwLoginFailed("Azure login ERROR for client_id=" + str(client_id) + " Message: " + str(e)) from None + raise FwLoginFailed("Azure login ERROR for client_id=" + str(client_id) + " Message: None") from None if fwo_globals.debug_level > 2: logger = getFwoLogger() @@ -87,12 +79,12 @@ def login(azure_user, azure_password, tenant_id, client_id, client_secret): return body["access_token"] -def update_config_with_azure_api_call(azure_jwt, api_base_url, config, api_path, key, parameters={}, payload={}, show_progress=False, limit: int=1000, method="get"): - offset = 0 - limit = 1000 +def update_config_with_azure_api_call(azure_jwt: str, api_base_url: str, config: dict[str, Any], api_path: str, key: str, parameters: dict[str, Any]={}, payload: dict[str, Any]={}, show_progress: bool=False, limit: int=1000, method: str="get") -> None: + _ = 0 + __ = 1000 returned_new_data = True - - full_result = [] + + full_result: list[Any] = [] #while returned_new_data: # parameters["offset"] = offset # parameters["limit"] = limit diff --git a/roles/importer/files/importer/azure2022ff/azure_network.py b/roles/importer/files/importer/azure2022ff/azure_network.py index 7ffa720d8c..194179eca4 100644 --- a/roles/importer/files/importer/azure2022ff/azure_network.py +++ b/roles/importer/files/importer/azure2022ff/azure_network.py @@ -1,11 +1,11 @@ -from asyncio.log import logger -from fwo_log import getFwoLogger +from typing import Any +from netaddr import IPAddress from fwo_const import list_delimiter import ipaddress -def normalize_nwobjects(full_config, config2import, import_id, jwt=None, mgm_id=None): - nw_objects = [] +def normalize_nwobjects(full_config: dict[str, Any], config2import: dict[str, Any], import_id: str, jwt: str | None = None, mgm_id: str | None = None) -> None: + nw_objects: list[dict[str, Any]] = [] for obj_orig in full_config["networkObjects"]: nw_objects.append(parse_object(obj_orig, import_id, config2import, nw_objects)) for obj_grp_orig in full_config["networkObjectGroups"]: @@ -16,8 +16,8 @@ def normalize_nwobjects(full_config, config2import, import_id, jwt=None, mgm_id= config2import['network_objects'] = nw_objects -def extract_base_object_infos(obj_orig, import_id, config2import, nw_objects): - obj = {} +def extract_base_object_infos(obj_orig: dict[str, Any], import_id: str, config2import: dict[str, Any], nw_objects: list[dict[str, Any]]) -> dict[str, Any]: + obj: dict[str, Any] = {} if "type" in obj_orig: obj["obj_name"] = obj_orig["name"] @@ -30,9 +30,9 @@ def extract_base_object_infos(obj_orig, import_id, config2import, nw_objects): return obj -def parse_obj_group(orig_grp, import_id, nw_objects, config2import, id = None): - refs = [] - names = [] +def parse_obj_group(orig_grp: dict[str, Any], import_id: str, nw_objects: list[dict[str, Any]], config2import: dict[str, Any], id: str | None = None) -> tuple[str, str]: + refs: list[str] = [] + names: list[str] = [] if "properties" in orig_grp: if 'ipAddresses' in orig_grp['properties']: for ip in orig_grp['properties']['ipAddresses']: @@ -43,14 +43,14 @@ def parse_obj_group(orig_grp, import_id, nw_objects, config2import, id = None): return list_delimiter.join(refs), list_delimiter.join(names) -def parse_obj_list(ip_list, import_id, config, id): - refs = [] - names = [] +def parse_obj_list(ip_list: list[str], import_id: str, config: dict[str, Any], id: str | None = None) -> tuple[str, str]: + refs: list[str] = [] + names: list[str] = [] for ip in ip_list: # TODO: lookup ip in network_objects and re-use - ip_obj = {} + ip_obj: dict[str, Any] = {} ip_obj['obj_name'] = ip - ip_obj['obj_uid'] = ip_obj['obj_name'] + "_" + id + ip_obj['obj_uid'] = ip_obj['obj_name'] + "_" + (id if id is not None else "") try: ipaddress.ip_network(ip) # valid ip @@ -73,13 +73,13 @@ def parse_obj_list(ip_list, import_id, config, id): ip_obj['control_id'] = import_id - config.append(ip_obj) + config.append(ip_obj) # type: ignore # TODO: config is dict[str, Any], not list refs.append(ip_obj['obj_uid']) names.append(ip_obj['obj_name']) return list_delimiter.join(refs), list_delimiter.join(names) -def parse_object(obj_orig, import_id, config2import, nw_objects): +def parse_object(obj_orig: dict[str, Any], import_id: str, config2import: dict[str, Any], nw_objects: list[dict[str, Any]]) -> dict[str, Any]: obj = extract_base_object_infos(obj_orig, import_id, config2import, nw_objects) if obj_orig["type"] == "network": # network obj["obj_typ"] = "network" @@ -113,7 +113,7 @@ def parse_object(obj_orig, import_id, config2import, nw_objects): return obj -def add_network_object(config2import, ip=None): +def add_network_object(config2import: dict[str, Any], ip: str | None = None) -> dict[str, Any]: if "-" in str(ip): type = 'ip_range' else: diff --git a/roles/importer/files/importer/azure2022ff/azure_rule.py b/roles/importer/files/importer/azure2022ff/azure_rule.py index 8ac6ab9266..c5ff1978ba 100644 --- a/roles/importer/files/importer/azure2022ff/azure_rule.py +++ b/roles/importer/files/importer/azure2022ff/azure_rule.py @@ -1,25 +1,25 @@ +from typing import Any, Literal from azure_service import parse_svc_list from azure_network import parse_obj_list -from fwo_log import getFwoLogger import hashlib import base64 -def make_hash_sha256(o): +def make_hash_sha256(o: Any) -> str: hasher = hashlib.sha256() hasher.update(repr(make_hashable(o)).encode()) return base64.b64encode(hasher.digest()).decode() -def make_hashable(o): +def make_hashable(o: Any) -> tuple[Any, ...] | Any: if isinstance(o, (tuple, list)): - return tuple((make_hashable(e) for e in o)) + return tuple([make_hashable(e) for e in o]) # type: ignore if isinstance(o, dict): - return tuple(sorted((k,make_hashable(v)) for k,v in o.items())) + return tuple(sorted((k, make_hashable(v)) for k, v in o.items())) # type: ignore if isinstance(o, (set, frozenset)): - return tuple(sorted(make_hashable(e) for e in o)) + return tuple(sorted(make_hashable(e) for e in o)) # type: ignore return o @@ -32,27 +32,28 @@ def make_hashable(o): # rule_scope = rule_access_scope + rule_nat_scope -def normalize_access_rules(full_config, config2import, import_id, mgm_details={}): - rules = [] +def normalize_access_rules(full_config: dict[str, Any], config2import: dict[str, Any], import_id: str, mgm_details: dict[str, Any] = {}) -> None: + rules: list[dict[str, Any]] = [] - nw_obj_names = [] + nw_obj_names: list[str] = [] for o in config2import['network_objects']: nw_obj_names.append(o["obj_name"]) - for device in full_config["devices"]: + for _ in full_config["devices"]: rule_number = 0 for policy_name in full_config['devices'].keys(): for rule_prop in full_config['devices'][policy_name]['rules']: rule_coll_container = rule_prop['properties'] if 'ruleCollections' in rule_coll_container: for rule_coll in rule_coll_container['ruleCollections']: + rule_action: Literal["accept", "deny"] | None = None if 'ruleCollectionType' in rule_coll and rule_coll['ruleCollectionType'] == 'FirewallPolicyFilterRuleCollection': rule_action = "accept" if rule_coll['action']['type'] == 'Deny': rule_action = "deny" for rule_orig in rule_coll['rules']: - rule = {'rule_src': 'any', 'rule_dst': 'any', 'rule_svc': 'any', + rule: dict[str, Any] = {'rule_src': 'any', 'rule_dst': 'any', 'rule_svc': 'any', 'rule_src_refs': 'any_obj_placeholder', 'rule_dst_refs': 'any_obj_placeholder', 'rule_svc_refs': 'any_svc_placeholder'} rule['rulebase_name'] = policy_name @@ -73,7 +74,7 @@ def normalize_access_rules(full_config, config2import, import_id, mgm_details={} if "sourceAddresses" in rule_orig: rule['rule_src_refs'], rule["rule_src"] = parse_obj_list(rule_orig["sourceAddresses"], import_id, config2import['network_objects'], rule["rule_uid"]) if "destinationAddresses" in rule_orig: - undefObjects = [] + undefObjects: list[str] = [] for obj in rule_orig['destinationAddresses']: if "obj_name" in obj: diff --git a/roles/importer/files/importer/azure2022ff/azure_service.py b/roles/importer/files/importer/azure2022ff/azure_service.py index d537e314a5..64b11af1bd 100644 --- a/roles/importer/files/importer/azure2022ff/azure_service.py +++ b/roles/importer/files/importer/azure2022ff/azure_service.py @@ -1,21 +1,22 @@ import random +from typing import Any from fwo_const import list_delimiter -def normalize_svcobjects(full_config, config2import, import_id): - svc_objects = [] +def normalize_svcobjects(full_config: dict[str, Any], config2import: dict[str, Any], import_id: str) -> None: + svc_objects: list[dict[str, Any]] = [] for svc_orig in full_config["serviceObjects"]: svc_objects.append(parse_svc(svc_orig, import_id)) for svc_grp_orig in full_config["serviceObjectGroups"]: svc_grp = extract_base_svc_infos(svc_grp_orig, import_id) svc_grp["svc_typ"] = "group" - svc_grp["svc_member_refs"] , svc_grp["svc_member_names"] = parse_svc_group(svc_grp_orig, import_id, svc_objects) + svc_grp["svc_member_refs"] , svc_grp["svc_member_names"] = parse_svc_group(svc_grp_orig, import_id, svc_objects) # type: ignore # TODO: parse_svc_group is not defined svc_objects.append(svc_grp) config2import['service_objects'] = svc_objects - -def extract_base_svc_infos(svc_orig, import_id): - svc = {} + +def extract_base_svc_infos(svc_orig: dict[str, Any], import_id: str) -> dict[str, Any]: + svc: dict[str, Any] = {} if "id" in svc_orig: svc["svc_uid"] = svc_orig["id"] else: @@ -36,7 +37,7 @@ def extract_base_svc_infos(svc_orig, import_id): return svc -def parse_svc(orig_svc, import_id): +def parse_svc(orig_svc: dict[str, Any], import_id: str) -> dict[str, Any]: svc = extract_base_svc_infos(orig_svc, import_id) svc["svc_typ"] = "simple" parse_port(orig_svc, svc) @@ -58,7 +59,7 @@ def parse_svc(orig_svc, import_id): return svc -def parse_port(orig_svc, svc): +def parse_port(orig_svc: dict[str, Any], svc: dict[str, Any]) -> None: if "port" in orig_svc: if orig_svc["port"].find("-") != -1: # port range port_range = orig_svc["port"].split("-") @@ -68,18 +69,16 @@ def parse_port(orig_svc, svc): svc["svc_port"] = orig_svc["port"] svc["svc_port_end"] = None - -def parse_svc_list(ports, ip_protos, import_id, svc_objects, id = None): - refs = [] - names = [] + +def parse_svc_list(ports: list[str], ip_protos: list[str], import_id: str, svc_objects: list[dict[str, Any]], id: str | None = None) -> tuple[str, str]: + refs: list[str] = [] + names: list[str] = [] for port in ports: for ip_proto in ip_protos: # TODO: lookup port in svc_objects and re-use - svc = {} - - + svc: dict[str, Any] = {} - if id == None: + if id is None: id = str(random.random()) svc['svc_name'] = ip_proto + "_" + port diff --git a/roles/importer/files/importer/azure2022ff/fwcommon.py b/roles/importer/files/importer/azure2022ff/fwcommon.py index b38954b4d9..baa2941714 100644 --- a/roles/importer/files/importer/azure2022ff/fwcommon.py +++ b/roles/importer/files/importer/azure2022ff/fwcommon.py @@ -1,6 +1,7 @@ # import sys # from common import importer_base_dir # sys.path.append(importer_base_dir + '/azure2022ff') +from typing import Any from azure_service import normalize_svcobjects from azure_rule import normalize_access_rules from azure_network import normalize_nwobjects @@ -8,12 +9,12 @@ from fwo_log import getFwoLogger from azure_base import azure_api_version_str -def has_config_changed(full_config, mgm_details, force=False): +def has_config_changed(full_config: dict[str, Any], mgm_details: dict[str, Any], force: bool=False): # dummy - may be filled with real check later on return True -def get_config(config2import, full_config, current_import_id, mgm_details, limit=1000, force=False, jwt=''): +def get_config(config2import: dict[str, Any], full_config: dict[str, Any], current_import_id: str, mgm_details: dict[str, Any], limit: int=1000, force: bool=False, jwt: str=''): logger = getFwoLogger() if full_config == {}: # no native config was passed in, so getting it from Azzure parsing_config_only = False @@ -41,7 +42,7 @@ def get_config(config2import, full_config, current_import_id, mgm_details, limit # login azure_jwt = login(azure_user, azure_password, azure_tenant_id, azure_client_id, azure_client_secret) - if azure_jwt == None or azure_jwt == "": + if azure_jwt is None or azure_jwt == "": logger.error('Did not succeed in logging in to Azure API, no jwt returned.') return 1 @@ -87,9 +88,9 @@ def get_config(config2import, full_config, current_import_id, mgm_details, limit normalize_nwobjects(full_config, config2import, current_import_id, jwt=jwt, mgm_id=mgm_details['id']) normalize_svcobjects(full_config, config2import, current_import_id) - any_nw_svc = {"svc_uid": "any_svc_placeholder", "svc_name": "Any", "svc_comment": "Placeholder service.", + any_nw_svc: dict[str, Any] = {"svc_uid": "any_svc_placeholder", "svc_name": "Any", "svc_comment": "Placeholder service.", "svc_typ": "simple", "ip_proto": -1, "svc_port": 0, "svc_port_end": 65535, "control_id": current_import_id} - any_nw_object = {"obj_uid": "any_obj_placeholder", "obj_name": "Any", "obj_comment": "Placeholder object.", + any_nw_object: dict[str, Any] = {"obj_uid": "any_obj_placeholder", "obj_name": "Any", "obj_comment": "Placeholder object.", "obj_typ": "network", "obj_ip": "0.0.0.0/0", "control_id": current_import_id} config2import["service_objects"].append(any_nw_svc) config2import["network_objects"].append(any_nw_object) @@ -101,15 +102,15 @@ def get_config(config2import, full_config, current_import_id, mgm_details, limit return 0 -def extract_nw_objects(rule, config): +def extract_nw_objects(rule: str, config: dict[str, Any]): pass -def extract_svc_objects(rule, config): +def extract_svc_objects(rule: str, config: dict[str, Any]): pass -def extract_user_objects(rule, config): +def extract_user_objects(rule: str, config: dict[str, Any]): pass diff --git a/roles/importer/files/importer/checkpointR8x/cp_gateway.py b/roles/importer/files/importer/checkpointR8x/cp_gateway.py index f1d0383ae1..b135777849 100644 --- a/roles/importer/files/importer/checkpointR8x/cp_gateway.py +++ b/roles/importer/files/importer/checkpointR8x/cp_gateway.py @@ -1,22 +1,20 @@ from fwo_log import getFwoLogger -import fwo_globals from typing import Any +from model_controllers.import_state_controller import ImportStateController + """ normalize all gateway details """ -def normalize_gateways (nativeConfig, importState, normalizedConfig): - if fwo_globals.debug_level>0: - logger = getFwoLogger() - +def normalize_gateways (nativeConfig: dict[str, Any], importState: ImportStateController, normalizedConfig: dict[str, Any]): normalizedConfig['gateways'] = [] normalize_rulebase_links (nativeConfig, importState, normalizedConfig) normalize_interfaces (nativeConfig, importState, normalizedConfig) normalize_routing (nativeConfig, importState, normalizedConfig) -def normalize_rulebase_links (nativeConfig, importState, normalizedConfig): +def normalize_rulebase_links (nativeConfig: dict[str, Any], importState: ImportStateController, normalizedConfig: dict[str, Any]): gwRange = range(len(nativeConfig['gateways'])) for gwId in gwRange: gwUid = nativeConfig['gateways'][gwId]['uid'] @@ -29,7 +27,7 @@ def normalize_rulebase_links (nativeConfig, importState, normalizedConfig): break -def get_normalized_rulebase_link(nativeConfig, gwId): +def get_normalized_rulebase_link(nativeConfig: dict[str, Any], gwId: int) -> list[dict[str, Any]]: links = nativeConfig.get('gateways', {})[gwId].get('rulebase_links') for link in links: if 'type' in link: @@ -48,8 +46,8 @@ def get_normalized_rulebase_link(nativeConfig, gwId): return links -def create_normalized_gateway(nativeConfig, gwId) -> dict[str, Any]: - gw = {} +def create_normalized_gateway(nativeConfig: dict[str, Any], gwId: int) -> dict[str, Any]: + gw: dict[str, Any] = {} gw['Uid'] = nativeConfig['gateways'][gwId]['uid'] gw['Name'] = nativeConfig['gateways'][gwId]['name'] gw['Interfaces'] = [] @@ -58,17 +56,17 @@ def create_normalized_gateway(nativeConfig, gwId) -> dict[str, Any]: return gw -def normalize_interfaces (nativeConfig, importState, normalizedConfig): +def normalize_interfaces (nativeConfig: dict[str, Any], importState: ImportStateController, normalizedConfig: dict[str, Any]): # TODO: Implement this pass -def normalize_routing (nativeConfig, importState, normalizedConfig): +def normalize_routing (nativeConfig: dict[str, Any], importState: ImportStateController, normalizedConfig: dict[str, Any]): # TODO: Implement this pass -def gw_in_normalized_config(normalizedConfig, gwUid) -> bool: +def gw_in_normalized_config(normalizedConfig: dict[str, Any], gwUid: str) -> bool: for gw in normalizedConfig['gateways']: if gw['Uid'] == gwUid: return True diff --git a/roles/importer/files/importer/checkpointR8x/cp_getter.py b/roles/importer/files/importer/checkpointR8x/cp_getter.py index c0bc68e701..2be3f469cb 100644 --- a/roles/importer/files/importer/checkpointR8x/cp_getter.py +++ b/roles/importer/files/importer/checkpointR8x/cp_getter.py @@ -17,10 +17,10 @@ from services.enums import Services from fwo_api_call import FwoApiCall, FwoApi -def cp_api_call(url, command, json_payload, sid, show_progress=False): +def cp_api_call(url: str, command: str, json_payload: dict[str, Any], sid: str | None, show_progress: bool=False): url += command request_headers = {'Content-Type' : 'application/json'} - if sid != '': # only not set for login + if sid: # only not set for login request_headers.update({'X-chkp-sid' : sid}) if fwo_globals.debug_level>8: @@ -30,7 +30,7 @@ def cp_api_call(url, command, json_payload, sid, show_progress=False): try: r = requests.post(url, json=json_payload, headers=request_headers, verify=fwo_globals.verify_certs) - except requests.exceptions.RequestException as e: + except requests.exceptions.RequestException as _: if 'password' in json.dumps(json_payload): exception_text = "\nerror while sending api_call containing credential information to url '" + str(url) else: @@ -50,7 +50,7 @@ def login(mgm_details: ManagementController): logger = getFwoLogger() payload = {'user': mgm_details.ImportUser, 'password': mgm_details.Secret} domain = mgm_details.getDomainString() - if domain is not None and domain != '': + if domain is not None and domain != '': # type: ignore # TODO: shouldnt be None payload.update({'domain': domain}) base_url = mgm_details.buildFwApiString() if int(fwo_globals.debug_level)>2: @@ -62,7 +62,7 @@ def login(mgm_details: ManagementController): return response["sid"] -def logout(url, sid): +def logout(url: str, sid: str): logger = getFwoLogger() if int(fwo_globals.debug_level)>2: logger.debug("logout from url " + url) @@ -70,7 +70,7 @@ def logout(url, sid): return response -def get_changes(sid,api_host,api_port,fromdate): +def get_changes(sid: str, api_host: str, api_port: str, fromdate: str) -> int: logger = getFwoLogger() dt_object = datetime.fromisoformat(fromdate) @@ -118,7 +118,7 @@ def get_changes(sid,api_host,api_port,fromdate): return 0 -def get_policy_structure(api_v_url, sid, show_params_policy_structure, managerDetails, policy_structure = None): +def get_policy_structure(api_v_url: str, sid: str, show_params_policy_structure: dict[str, Any], managerDetails: ManagementController, policy_structure: list[dict[str, Any]] | None = None) -> int: if policy_structure is None: policy_structure = [] @@ -142,7 +142,7 @@ def get_policy_structure(api_v_url, sid, show_params_policy_structure, managerDe return 0 -def get_show_packages_via_api(api_v_url, sid, show_params_policy_structure): +def get_show_packages_via_api(api_v_url: str, sid: str, show_params_policy_structure: dict[str, Any]) -> tuple[dict[str, Any], int, int]: try: packages = cp_api_call(api_v_url, 'show-packages', show_params_policy_structure, sid) except Exception: @@ -169,13 +169,13 @@ def get_show_packages_via_api(api_v_url, sid, show_params_policy_structure): raise FwApiError('packages do not contain to field') return packages, current, total -def parse_package(package, managerDetails): +def parse_package(package: dict[str, Any], managerDetails: ManagementController) -> tuple[dict[str, Any], bool]: alreadyFetchedPackage = False currentPackage = {} if 'installation-targets' in package and package['installation-targets'] == 'all': if not alreadyFetchedPackage: - currentPackage = { 'name': package['name'], + currentPackage: dict[str, Any] = { 'name': package['name'], 'uid': package['uid'], 'targets': [{'name': 'all', 'uid': 'all'}], 'access-layers': []} @@ -198,7 +198,7 @@ def parse_package(package, managerDetails): logger.warning ( 'installation target in package: ' + package['uid'] + ' is missing name or uid') return currentPackage, alreadyFetchedPackage -def is_valid_installation_target(installationTarget, managerDetails): +def is_valid_installation_target(installationTarget: dict[str, Any], managerDetails: ManagementController) -> bool: """ensures that target is defined as gateway in database""" if 'target-name' in installationTarget and 'target-uid' in installationTarget: for device in managerDetails.Devices: @@ -206,7 +206,7 @@ def is_valid_installation_target(installationTarget, managerDetails): return True return False -def add_access_layers_to_current_package(package, currentPackage): +def add_access_layers_to_current_package(package: dict[str, Any], currentPackage: dict[str, Any]) -> None: if 'access-layers' in package: for accessLayer in package['access-layers']: @@ -217,12 +217,12 @@ def add_access_layers_to_current_package(package, currentPackage): else: raise FwApiError('access layer in package: ' + package['uid'] + ' is missing name or uid') -def get_global_assignments(api_v_url, sid, show_params_policy_structure) -> list[Any]: +def get_global_assignments(api_v_url: str, sid: str, show_params_policy_structure: dict[str, Any]) -> list[Any]: logger = getFwoLogger() current=0 total=current+1 show_params_policy_structure.update({'offset': current}) - global_assignments = [] + global_assignments: list[dict[str, Any]] = [] while (current list for assignment in assignments['objects']: if 'type' not in assignment and assignment['type'] != 'global-assignment': raise FwoImporterError ('global assignment with unexpected type') - global_assignment = { + global_assignment: dict[str, Any] = { 'uid': assignment['uid'], 'global-domain': { 'uid': assignment['global-domain']['uid'], @@ -271,7 +271,9 @@ def get_global_assignments(api_v_url, sid, show_params_policy_structure) -> list return global_assignments -def get_rulebases(api_v_url, sid, show_params_rules, nativeConfigDomain, deviceConfig, policy_rulebases_uid_list, is_global=False, access_type='access', rulebaseUid=None, rulebaseName=None): +def get_rulebases(api_v_url: str, sid: str | None, show_params_rules: dict[str, Any], nativeConfigDomain: dict[str, Any] | None, + deviceConfig: dict[str, Any] | None, policy_rulebases_uid_list: list[str], is_global: bool = False, + access_type: str = 'access', rulebaseUid: str | None = None, rulebaseName: str | None = None) -> list[str]: # access_type: access / nat logger = getFwoLogger() @@ -297,10 +299,10 @@ def get_rulebases(api_v_url, sid, show_params_rules, nativeConfigDomain, deviceC rulebaseUid = get_uid_of_rulebase(rulebaseName, api_v_url, access_type, sid) else: logger.error('must provide either rulebaseUid or rulebaseName') - policy_rulebases_uid_list.append(rulebaseUid) + policy_rulebases_uid_list.append(rulebaseUid) #type: ignore # TODO: get_uid_of_rulebase can return None but in theory should not # search all rulebases in nativeConfigDomain and import if rulebase is not already fetched - fetchedRulebaseList = [] + fetchedRulebaseList: list[str] = [] for fetchedRulebase in nativeConfigDomain[nativeConfigRulebaseKey]: fetchedRulebaseList.append(fetchedRulebase['uid']) if fetchedRulebase['uid'] == rulebaseUid: @@ -309,7 +311,7 @@ def get_rulebases(api_v_url, sid, show_params_rules, nativeConfigDomain, deviceC # get rulebase in chunks if rulebaseUid not in fetchedRulebaseList: - current_rulebase = get_rulebases_in_chunks(rulebaseUid, show_params_rules, api_v_url, access_type, sid, nativeConfigDomain) + current_rulebase = get_rulebases_in_chunks(rulebaseUid, show_params_rules, api_v_url, access_type, sid, nativeConfigDomain) #type: ignore # TODO: rulebaseUid can be None but in theory should not nativeConfigDomain[nativeConfigRulebaseKey].append(current_rulebase) # use recursion to get inline layers @@ -319,9 +321,9 @@ def get_rulebases(api_v_url, sid, show_params_rules, nativeConfigDomain, deviceC return policy_rulebases_uid_list -def get_uid_of_rulebase(rulebaseName, api_v_url, access_type, sid): +def get_uid_of_rulebase(rulebaseName: str, api_v_url: str, access_type: str, sid: str | None) -> str | None: # TODO: what happens if rulebaseUid None? Error? rulebaseUid = None - get_rulebase_uid_params = { + get_rulebase_uid_params: dict[str, Any] = { 'name': rulebaseName, 'limit': 1, 'use-object-dictionary': False, @@ -337,9 +339,9 @@ def get_uid_of_rulebase(rulebaseName, api_v_url, access_type, sid): return rulebaseUid -def get_rulebases_in_chunks(rulebaseUid, show_params_rules, api_v_url, access_type, sid, nativeConfigDomain): +def get_rulebases_in_chunks(rulebaseUid: str, show_params_rules: dict[str, Any], api_v_url: str, access_type: str, sid: str, nativeConfigDomain: dict[str, Any]) -> dict[str, Any]: - current_rulebase = {'uid': rulebaseUid, 'name': '', 'chunks': []} + current_rulebase: dict[str, Any] = {'uid': rulebaseUid, 'name': '', 'chunks': []} show_params_rules.update({'uid': rulebaseUid}) current=0 total=current+1 @@ -369,9 +371,9 @@ def get_rulebases_in_chunks(rulebaseUid, show_params_rules, api_v_url, access_ty return current_rulebase -def resolve_checkpoint_uids_via_object_dict(rulebase, nativeConfigDomain, - current_rulebase, - rulebaseUid, show_params_rules): +def resolve_checkpoint_uids_via_object_dict(rulebase: dict[str, Any], nativeConfigDomain: dict[str, Any], + current_rulebase: dict[str, Any], + rulebaseUid: str, show_params_rules: dict[str, Any]) -> None: """ Checkpoint stores some rulefields as uids, function translates them to names """ @@ -387,7 +389,7 @@ def resolve_checkpoint_uids_via_object_dict(rulebase, nativeConfigDomain, + rulebaseUid + ", params: " + str(show_params_rules)) -def control_while_loop_in_get_rulebases_in_chunks(current_rulebase, rulebase, sid, api_v_url, show_params_rules): +def control_while_loop_in_get_rulebases_in_chunks(current_rulebase: dict[str, Any], rulebase: dict[str, Any], sid: str, api_v_url: str, show_params_rules: dict[str, Any]) -> tuple[int, int]: total=0 if 'total' in rulebase: total=rulebase['total'] @@ -410,7 +412,7 @@ def control_while_loop_in_get_rulebases_in_chunks(current_rulebase, rulebase, si return total, current -def get_inline_layers_recursively(current_rulebase, deviceConfig, nativeConfigDomain, api_v_url, sid, show_params_rules, is_global, policy_rulebases_uid_list): +def get_inline_layers_recursively(current_rulebase: dict[str, Any], deviceConfig: dict[str, Any], nativeConfigDomain: dict[str, Any], api_v_url: str, sid: str | None, show_params_rules: dict[str, Any], is_global: bool, policy_rulebases_uid_list: list[str]) -> list[str]: """Takes current_rulebase, splits sections into sub-rulebases and searches for layerguards to fetch """ current_rulebase_uid = current_rulebase['uid'] @@ -445,7 +447,7 @@ def get_inline_layers_recursively(current_rulebase, deviceConfig, nativeConfigDo return policy_rulebases_uid_list -def section_traversal_and_links(section, current_rulebase_uid, deviceConfig, is_global): +def section_traversal_and_links(section: dict[str, Any], current_rulebase_uid: str, deviceConfig: dict[str, Any], is_global: bool) -> tuple[dict[str, Any], str]: """If section is actually rule, fake it to be section and link sections as self-contained rulebases """ @@ -481,7 +483,7 @@ def section_traversal_and_links(section, current_rulebase_uid, deviceConfig, is_ return section, current_rulebase_uid -def get_placeholder_in_rulebase(rulebase): +def get_placeholder_in_rulebase(rulebase: dict[str, Any]) -> tuple[str | None, str | None]: placeholder_rule_uid = None placeholder_rulebase_uid = None @@ -492,7 +494,7 @@ def get_placeholder_in_rulebase(rulebase): # if no section is used, use dummy section if section['type'] != 'access-section': - section = { + section: dict[str, Any] = { 'type': 'access-section', 'rulebase': [section] } @@ -504,7 +506,7 @@ def get_placeholder_in_rulebase(rulebase): return placeholder_rule_uid, placeholder_rulebase_uid -def assign_placeholder_uids(rulebase, section, rule, placeholder_rule_uid, placeholder_rulebase_uid): +def assign_placeholder_uids(rulebase: dict[str, Any], section: dict[str, Any], rule: dict[str, Any], placeholder_rule_uid: str | None, placeholder_rulebase_uid: str | None) -> tuple[str | None, str | None]: if rule['type'] == 'place-holder': placeholder_rule_uid = rule['uid'] if 'uid' in section: @@ -514,9 +516,9 @@ def assign_placeholder_uids(rulebase, section, rule, placeholder_rule_uid, place return placeholder_rule_uid, placeholder_rulebase_uid -def get_nat_rules_from_api_as_dict (api_v_url, sid, show_params_rules, nativeConfigDomain={}): +def get_nat_rules_from_api_as_dict (api_v_url: str, sid: str, show_params_rules: dict[str, Any], nativeConfigDomain: dict[str, Any]={}): logger = getFwoLogger() - nat_rules = { "nat_rule_chunks": [] } + nat_rules: dict[str, list[Any]] = { "nat_rule_chunks": [] } current=0 total=current+1 while (current dict[str, Any] | None: for el in array: if 'uid' in el and el['uid']==uid: return el return None -def resolve_ref_from_object_dictionary(uid, objDict, native_config_domain={}, field_name=None): +def resolve_ref_from_object_dictionary(uid: str | None, objDict: list[dict[str, Any]], native_config_domain: dict[str, Any]={}, field_name: str | None=None) -> dict[str, Any] | None: matched_obj = find_element_by_uid(objDict, uid) @@ -582,21 +584,22 @@ def resolve_ref_from_object_dictionary(uid, objDict, native_config_domain={}, fi # resolving all uid references using the object dictionary # dealing with a single chunk -def resolve_ref_list_from_object_dictionary(rulebase, value, objDict={}, nativeConfigDomain={}): - if 'objects-dictionary' in rulebase: - objDict = rulebase['objects-dictionary'] +def resolve_ref_list_from_object_dictionary(rulebase: list[dict[str, Any]] | dict[str, Any], value: str, objDicts: list[dict[str, Any]]=[], nativeConfigDomain: dict[str, Any]={}): # TODO: what is objDict: I think it should be a list of dicts + if isinstance(rulebase, dict): + if 'objects-dictionary' in rulebase: + objDicts = rulebase['objects-dictionary'] if isinstance(rulebase, list): # found a list of rules for rule in rulebase: if value in rule: - categorize_value_for_resolve_ref(rule, value, objDict, nativeConfigDomain) + categorize_value_for_resolve_ref(rule, value, objDicts, nativeConfigDomain) if 'rulebase' in rule: - resolve_ref_list_from_object_dictionary(rule['rulebase'], value, objDict=objDict, nativeConfigDomain=nativeConfigDomain) + resolve_ref_list_from_object_dictionary(rule['rulebase'], value, objDicts=objDicts, nativeConfigDomain=nativeConfigDomain) elif 'rulebase' in rulebase: - resolve_ref_list_from_object_dictionary(rulebase['rulebase'], value, objDict=objDict, nativeConfigDomain=nativeConfigDomain) + resolve_ref_list_from_object_dictionary(rulebase['rulebase'], value, objDicts=objDicts, nativeConfigDomain=nativeConfigDomain) -def categorize_value_for_resolve_ref(rule, value, objDict, nativeConfigDomain): - value_list = [] +def categorize_value_for_resolve_ref(rule: dict[str, Any], value: str, objDict: list[dict[str, Any]], nativeConfigDomain: dict[str, Any]): + value_list: list[Any] = [] if isinstance(rule[value], str): # assuming single uid rule[value] = resolve_ref_from_object_dictionary(rule[value], objDict, native_config_domain=nativeConfigDomain, field_name=value) else: @@ -608,7 +611,7 @@ def categorize_value_for_resolve_ref(rule, value, objDict, nativeConfigDomain): rule[value] = value_list # replace ref list with object list -def getObjectDetailsFromApi(uid_missing_obj, sid='', apiurl='') -> dict[str, Any]: +def getObjectDetailsFromApi(uid_missing_obj: str, sid: str='', apiurl: str='') -> dict[str, Any]: logger = getFwoLogger() if fwo_globals.debug_level>5: logger.debug(f"getting {uid_missing_obj} from API") diff --git a/roles/importer/files/importer/checkpointR8x/cp_network.py b/roles/importer/files/importer/checkpointR8x/cp_network.py index 4b1ca74342..90ba8fd109 100644 --- a/roles/importer/files/importer/checkpointR8x/cp_network.py +++ b/roles/importer/files/importer/checkpointR8x/cp_network.py @@ -1,3 +1,4 @@ +from typing import Any from fwo_log import getFwoLogger import json import cp_const @@ -12,8 +13,8 @@ from fwo_api_call import FwoApiCall, FwoApi -def normalize_network_objects(full_config, config2import, import_id, mgm_id=0): - nw_objects = [] +def normalize_network_objects(full_config: dict[str, Any], config2import: dict[str, Any], import_id: int, mgm_id: int=0): + nw_objects: list[dict[str, Any]] = [] logger = getFwoLogger() global_domain = initialize_global_domain(full_config['objects']) @@ -35,14 +36,14 @@ def normalize_network_objects(full_config, config2import, import_id, mgm_id=0): config2import.update({'network_objects': nw_objects}) -def set_dummy_ip_for_object_without_ip(nw_obj): +def set_dummy_ip_for_object_without_ip(nw_obj: dict[str, Any]) -> None: logger = getFwoLogger() if nw_obj['obj_typ']!='group' and (nw_obj['obj_ip'] is None or nw_obj['obj_ip'] == ''): logger.warning("found object without IP :" + nw_obj['obj_name'] + " (type=" + nw_obj['obj_typ'] + ") - setting dummy IP") nw_obj.update({'obj_ip': fwo_const.dummy_ip}) nw_obj.update({'obj_ip_end': fwo_const.dummy_ip}) -def initialize_global_domain(objects : list[dict]): +def initialize_global_domain(objects : list[dict[str, Any]]) -> dict[str, Any]: """Returns CP Global Domain for MDS and standalone domain otherwise """ @@ -61,7 +62,7 @@ def initialize_global_domain(objects : list[dict]): return global_domain -def collect_nw_objects(object_table, nw_objects, global_domain, mgm_id=0): +def collect_nw_objects(object_table: dict[str, Any], nw_objects: list[dict[str, Any]], global_domain: dict[str, Any], mgm_id: int=0) -> None: """Collect nw_objects from object tables and write them into global nw_objects dict """ @@ -84,18 +85,18 @@ def collect_nw_objects(object_table, nw_objects, global_domain, mgm_id=0): 'obj_member_refs': member_refs, 'obj_member_names': member_names}) -def get_domain_uid(obj, global_domain): +def get_domain_uid(obj: dict[str, Any], global_domain: dict[str, Any]) -> str | dict[str, Any] | None: """Returns the domain UID for the given object. If the object has a 'domain' key with a 'uid', it returns that UID. Otherwise, it returns the global domain UID. """ if 'domain' not in obj or 'uid' not in obj['domain']: - return obj.update({'domain': global_domain}) + return obj.update({'domain': global_domain}) #TODO: check if the None value is wanted else: return obj['domain']['uid'] -def is_obj_already_collected(nw_objects, obj): +def is_obj_already_collected(nw_objects: list[dict[str, Any]], obj: dict[str, Any]) -> bool: logger = getFwoLogger() if 'uid' not in obj: logger.warning("found nw_object without uid: " + str(obj)) @@ -111,7 +112,7 @@ def is_obj_already_collected(nw_objects, obj): return False -def handle_members(obj): +def handle_members(obj: dict[str, Any]) -> tuple[str | None, str | None]: """Gets group member uids, currently no member_names """ member_refs = None @@ -126,7 +127,7 @@ def handle_members(obj): obj['members'] = None return member_refs, member_names -def handle_object_type_and_ip(obj, ip_addr): +def handle_object_type_and_ip(obj: dict[str, Any], ip_addr: str | None) -> tuple[str, str | None, str | None]: logger = getFwoLogger() obj_type = 'undef' ipArray = cidrToRange(ip_addr) @@ -176,7 +177,7 @@ def handle_object_type_and_ip(obj, ip_addr): return obj_type, first_ip, last_ip -def get_comment_and_color_of_obj(obj): +def get_comment_and_color_of_obj(obj: dict[str, Any]) -> str | None: """Returns comment and sets missing color to black """ if 'comments' not in obj or obj['comments'] == '': @@ -188,7 +189,7 @@ def get_comment_and_color_of_obj(obj): return comments # for members of groups, the name of the member obj needs to be fetched separately (starting from API v1.?) -def resolve_nw_uid_to_name(uid, nw_objects): +def resolve_nw_uid_to_name(uid: str, nw_objects: list[dict[str, Any]]) -> str: # return name of nw_objects element where obj_uid = uid for obj in nw_objects: if obj['obj_uid'] == uid: @@ -196,7 +197,7 @@ def resolve_nw_uid_to_name(uid, nw_objects): return 'ERROR: uid "' + uid + '" not found' -def add_member_names_for_nw_group(idx, nw_objects): +def add_member_names_for_nw_group(idx: int, nw_objects: list[dict[str, Any]]) -> None: group = nw_objects.pop(idx) if group['obj_member_refs'] == '' or group['obj_member_refs'] is None: #member_names = None @@ -213,7 +214,7 @@ def add_member_names_for_nw_group(idx, nw_objects): nw_objects.insert(idx, group) -def validate_ip_address(address): +def validate_ip_address(address: str) -> bool: try: # ipaddress.ip_address(address) ipaddress.ip_network(address) @@ -224,7 +225,7 @@ def validate_ip_address(address): # print("IP address {} is not valid".format(address)) -def get_ip_of_obj(obj, mgm_id=None): +def get_ip_of_obj(obj: dict[str, Any], mgm_id: int | None = None) -> str | None: if 'ipv4-address' in obj: ip_addr = obj['ipv4-address'] elif 'ipv6-address' in obj: @@ -255,11 +256,11 @@ def get_ip_of_obj(obj, mgm_id=None): return ip_addr -def make_host(ip_in) -> str | None: - ip_obj = ipaddress.ip_address(ip_in) +def make_host(ip_in: str) -> str | None: + ip_obj: ipaddress.IPv4Address | ipaddress.IPv6Address = ipaddress.ip_address(ip_in) # If it's a valid address, append the appropriate CIDR notation if isinstance(ip_obj, ipaddress.IPv4Address): return f"{ip_in}/32" - elif isinstance(ip_obj, ipaddress.IPv6Address): + elif isinstance(ip_obj, ipaddress.IPv6Address): # TODO: check if just else is sufficient # type: ignore return f"{ip_in}/128" diff --git a/roles/importer/files/importer/checkpointR8x/cp_rule.py b/roles/importer/files/importer/checkpointR8x/cp_rule.py index 2c5806fc5e..dc67feee7f 100644 --- a/roles/importer/files/importer/checkpointR8x/cp_rule.py +++ b/roles/importer/files/importer/checkpointR8x/cp_rule.py @@ -4,16 +4,16 @@ import ast from fwo_log import getFwoLogger -import fwo_const import fwo_globals from fwo_const import list_delimiter, default_section_header_text from fwo_base import sanitize -from fwo_exceptions import ImportRecursionLimitReached, FwoImporterErrorInconsistencies +from fwo_exceptions import FwoImporterErrorInconsistencies from models.rulebase import Rulebase from models.rule import RuleNormalized from models.rule_enforced_on_gateway import RuleEnforcedOnGatewayNormalized +from roles.importer.files.importer.model_controllers.import_state_controller import ImportStateController -uid_to_name_map = {} +uid_to_name_map: dict[str, str] = {} """ new import format which takes the following cases into account without duplicating any rules in the DB: @@ -23,8 +23,8 @@ - migrate section headers from rule to ordering element ... """ -def normalize_rulebases (nativeConfig, native_config_global, importState, normalized_config_dict, - normalized_config_global, is_global_loop_iteration): +def normalize_rulebases (nativeConfig: dict[str, Any], native_config_global: dict[str, Any] | None, importState: ImportStateController, normalized_config_dict: dict[str, Any], + normalized_config_global: dict[str, Any] | None, is_global_loop_iteration: bool): normalized_config_dict['policies'] = [] @@ -32,7 +32,7 @@ def normalize_rulebases (nativeConfig, native_config_global, importState, normal for nw_obj in normalized_config_dict['network_objects']: uid_to_name_map[nw_obj['obj_uid']] = nw_obj['obj_name'] - fetched_rulebase_uids = [] + fetched_rulebase_uids: list[str] = [] if normalized_config_global is not None and normalized_config_global != {}: for normalized_rulebase_global in normalized_config_global['policies']: fetched_rulebase_uids.append(normalized_rulebase_global.uid) @@ -40,13 +40,13 @@ def normalize_rulebases (nativeConfig, native_config_global, importState, normal normalize_rulebases_for_each_link_destination( gateway, fetched_rulebase_uids, nativeConfig, native_config_global, is_global_loop_iteration, importState, normalized_config_dict, - normalized_config_global) + normalized_config_global) #type: ignore # TODO: check if normalized_config_global can be None, I am pretty sure it cannot be None here # todo: parse nat rulebase here def normalize_rulebases_for_each_link_destination( - gateway, fetched_rulebase_uids, nativeConfig, - native_config_global, is_global_loop_iteration, importState, normalized_config_dict, normalized_config_global): + gateway: dict[str, Any], fetched_rulebase_uids: list[str], nativeConfig: dict[str, Any], + native_config_global: dict[str, Any] | None, is_global_loop_iteration: bool, importState: ImportStateController, normalized_config_dict: dict[str, Any], normalized_config_global: dict[str, Any]): logger = getFwoLogger() for rulebase_link in gateway['rulebase_links']: if rulebase_link['to_rulebase_uid'] not in fetched_rulebase_uids and rulebase_link['to_rulebase_uid'] != '': @@ -71,7 +71,7 @@ def normalize_rulebases_for_each_link_destination( else: normalized_config_dict['policies'].append(normalized_rulebase) -def find_rulebase_to_parse(rulebase_list, rulebase_uid): +def find_rulebase_to_parse(rulebase_list: list[dict[str, Any]], rulebase_uid: str) -> tuple[dict[str, Any], bool, bool]: """ decide if input rulebase is true rulebase, section or placeholder """ @@ -85,7 +85,7 @@ def find_rulebase_to_parse(rulebase_list, rulebase_uid): # handle case: no rulebase found return {}, False, False -def find_rulebase_to_parse_in_case_of_chunk(rulebase, rulebase_uid): +def find_rulebase_to_parse_in_case_of_chunk(rulebase: dict[str, Any], rulebase_uid: str) -> tuple[dict[str, Any], bool, bool]: is_section = False rulebase_to_parse = {} for chunk in rulebase['chunks']: @@ -97,7 +97,7 @@ def find_rulebase_to_parse_in_case_of_chunk(rulebase, rulebase_uid): rulebase_to_parse, is_section = find_rulebase_to_parse_in_case_of_section(is_section, rulebase_to_parse, section) return rulebase_to_parse, is_section, False -def find_rulebase_to_parse_in_case_of_section(is_section, rulebase_to_parse, section): +def find_rulebase_to_parse_in_case_of_section(is_section: bool, rulebase_to_parse: dict[str, Any], section: dict[str, Any]) -> tuple[dict[str, Any], bool]: if is_section: rulebase_to_parse = concatenat_sections_across_chunks(rulebase_to_parse, section) else: @@ -105,7 +105,7 @@ def find_rulebase_to_parse_in_case_of_section(is_section, rulebase_to_parse, sec rulebase_to_parse = section return rulebase_to_parse, is_section -def concatenat_sections_across_chunks(rulebase_to_parse, section): +def concatenat_sections_across_chunks(rulebase_to_parse: dict[str, Any], section: dict[str, Any]) -> dict[str, Any]: if 'to' in rulebase_to_parse and 'from' in section: if rulebase_to_parse['to'] + 1 == section['from']: if rulebase_to_parse['name'] == section['name']: @@ -121,13 +121,13 @@ def concatenat_sections_across_chunks(rulebase_to_parse, section): return rulebase_to_parse -def initialize_normalized_rulebase(rulebase_to_parse, mgm_uid): +def initialize_normalized_rulebase(rulebase_to_parse: dict[str, Any], mgm_uid: str) -> Rulebase: rulebaseName = rulebase_to_parse['name'] rulebaseUid = rulebase_to_parse['uid'] normalized_rulebase = Rulebase(uid=rulebaseUid, name=rulebaseName, mgm_uid=mgm_uid, rules={}) return normalized_rulebase -def parse_rulebase(rulebase_to_parse, is_section, is_placeholder, normalized_rulebase, gateway, policy_structure): +def parse_rulebase(rulebase_to_parse: dict[str, Any], is_section: bool, is_placeholder: bool, normalized_rulebase: Rulebase, gateway: dict[str, Any], policy_structure: list[dict[str, Any]]): logger = getFwoLogger() if is_section: @@ -143,7 +143,7 @@ def parse_rulebase(rulebase_to_parse, is_section, is_placeholder, normalized_rul else: parse_rulebase_chunk(rulebase_to_parse, normalized_rulebase, gateway, policy_structure) -def parse_rulebase_chunk(rulebase_to_parse, normalized_rulebase, gateway, policy_structure): +def parse_rulebase_chunk(rulebase_to_parse: dict[str, Any], normalized_rulebase: Rulebase, gateway: dict[str, Any], policy_structure: list[dict[str, Any]]): logger = getFwoLogger() for chunk in rulebase_to_parse['chunks']: for rule in chunk['rulebase']: @@ -154,7 +154,7 @@ def parse_rulebase_chunk(rulebase_to_parse, normalized_rulebase, gateway, policy return -def acceptMalformedParts(objects: dict, part: str ='') -> dict[str, Any]: +def acceptMalformedParts(objects: dict[str, Any] | list[dict[str, Any]], part: str ='') -> dict[str, Any]: if fwo_globals.debug_level>9: logger.debug(f'about to accept malformed rule part ({part}): {str(objects)}') @@ -179,39 +179,34 @@ def acceptMalformedParts(objects: dict, part: str ='') -> dict[str, Any]: return {} -def parseRulePart (objects: dict, part: str = 'source') -> dict[str, Any]: +def parseRulePart (objects: dict[str, Any] | list[dict[str, Any] | None] | None, part: str = 'source') -> dict[str, Any]: addressObjects: dict[str, Any] = {} if objects is None: logger.debug(f"rule part {part} is None: {str(objects)}, which is normal for track field in inline layer guards") - return None + return None # type: ignore #TODO: check if this is ok or should raise an Exception if 'chunks' in objects: # for chunks of actions?! - addressObjects.update(parseRulePart(objects['chunks'], part=part)) # need to parse chunk first - return addressObjects + addressObjects.update(parseRulePart(objects['chunks'], part=part)) # need to parse chunk first # type: ignore # TODO: This Has to be refactored if isinstance(objects, dict): return _parse_single_address_object(addressObjects, objects, part) # assuming list of objects - if objects is None: - logger.error(f'rule part {part} is None: {str(objects)}') - return None for obj in objects: if obj is None: logger.warning(f'found list with a single None obj: {str(objects)}') continue - if 'chunks' in obj: - addressObjects.update(parseRulePart(obj['chunks'], part=part)) # need to parse chunk first + addressObjects.update(parseRulePart(obj['chunks'], part=part)) # need to parse chunk first # type: ignore # TODO: check if this is ok or should raise an Exception elif 'objects' in obj: for o in obj['objects']: - addressObjects.update(parseRulePart(o, part=part)) # need to parse chunk first + addressObjects.update(parseRulePart(o, part=part)) # need to parse chunk first # type: ignore # TODO: check if this is ok or should raise an Exception return addressObjects else: if 'type' in obj: # found checkpoint object _parse_obj_with_type(obj, addressObjects) else: - return acceptMalformedParts(objects, part=part) + return acceptMalformedParts(objects, part=part) # type: ignore # TODO: check if this is ok or should raise an Exception if '' in addressObjects.values(): logger.warning('found empty name in one rule part (' + part + '): ' + str(addressObjects)) @@ -256,7 +251,7 @@ def _parse_obj_with_access_role(obj: dict[str,Any], addressObjects: dict[str,Any addressObjects[obj['uid']] = obj['name'] + '@' + nw_resolved -def parse_single_rule(nativeRule, rulebase, layer_name, parent_uid, gateway, policy_structure): +def parse_single_rule(nativeRule: dict[str, Any], rulebase: Rulebase, layer_name: str, parent_uid: str | None, gateway: dict[str, Any], policy_structure: list[dict[str, Any]]): logger = getFwoLogger() # reference to domain rule layer, filling up basic fields @@ -282,7 +277,7 @@ def parse_single_rule(nativeRule, rulebase, layer_name, parent_uid, gateway, pol rule_track = _parse_track(native_rule=nativeRule) actionObjects = parseRulePart (nativeRule['action'], 'action') - if actionObjects is not None: + if actionObjects is not None: # type: ignore # TODO: this should be never None rule_action = list_delimiter.join(actionObjects.values()) # expecting only a single action else: rule_action = None @@ -318,7 +313,7 @@ def parse_single_rule(nativeRule, rulebase, layer_name, parent_uid, gateway, pol else: last_hit = None - rule = { + rule: dict[str, Any] = { "rule_num": 0, "rule_num_numeric": 0, "rulebase_name": sanitize(layer_name), @@ -352,7 +347,7 @@ def parse_single_rule(nativeRule, rulebase, layer_name, parent_uid, gateway, pol return -def _parse_parent_rule_uid(parent_uid: str, native_rule: dict[str,Any]) -> str|None: +def _parse_parent_rule_uid(parent_uid: str | None, native_rule: dict[str,Any]) -> str | None: # new in v5.1.17: if 'parent_rule_uid' in native_rule: @@ -372,15 +367,15 @@ def _parse_track(native_rule: dict[str, Any]) -> str: if isinstance(native_rule['track'],str): rule_track = native_rule['track'] else: - trackObjects = parseRulePart (native_rule['track'], 'track') - if trackObjects is None: + trackObjects = parseRulePart(native_rule['track'], 'track') + if trackObjects is None: # type: ignore # TODO: should never be None rule_track = 'none' else: rule_track = list_delimiter.join(trackObjects.values()) return rule_track -def parse_rule_enforced_on_gateway(gateway, policy_structure, native_rule: dict) -> list[RuleEnforcedOnGatewayNormalized]: +def parse_rule_enforced_on_gateway(gateway: dict[str, Any], policy_structure: list[dict[str, Any]], native_rule: dict[str, Any]) -> list[RuleEnforcedOnGatewayNormalized]: """Parse rule enforcement information from native rule. Args: @@ -395,7 +390,7 @@ def parse_rule_enforced_on_gateway(gateway, policy_structure, native_rule: dict) if not native_rule: raise ValueError('Native rule cannot be empty') - enforce_entries = [] + enforce_entries: list[RuleEnforcedOnGatewayNormalized] = [] all_target_gw_names_dict = parseRulePart(native_rule['install-on'], 'install-on') for targetUid in all_target_gw_names_dict: @@ -410,8 +405,8 @@ def parse_rule_enforced_on_gateway(gateway, policy_structure, native_rule: dict) enforce_entries.append(enforceEntry) return enforce_entries -def find_devices_for_current_policy(gateway, policy_structure): - device_uid_list = [] +def find_devices_for_current_policy(gateway: dict[str, Any], policy_structure: list[dict[str, Any]]) -> list[str]: + device_uid_list: list[str] = [] for policy in policy_structure: for target in policy['targets']: if target['uid'] == gateway['uid']: @@ -420,7 +415,7 @@ def find_devices_for_current_policy(gateway, policy_structure): return device_uid_list -def resolveNwObjUidToName(nw_obj_uid): +def resolveNwObjUidToName(nw_obj_uid: str) -> str: if nw_obj_uid in uid_to_name_map: return uid_to_name_map[nw_obj_uid] else: @@ -430,7 +425,7 @@ def resolveNwObjUidToName(nw_obj_uid): # delete_v: left here only for nat case -def check_and_add_section_header(src_rulebase, target_rulebase, layer_name, import_id, section_header_uids, parent_uid): +def check_and_add_section_header(src_rulebase: dict[str, Any], target_rulebase: Rulebase, layer_name: str, import_id: str, section_header_uids: set[str], parent_uid: str): # if current rulebase starts a new section, add section header, but only if it does not exist yet (can happen by chunking a section) if 'type' in src_rulebase and src_rulebase['type'] == 'access-section' and 'uid' in src_rulebase: # and not src_rulebase['uid'] in section_header_uids: section_name = default_section_header_text @@ -445,7 +440,7 @@ def check_and_add_section_header(src_rulebase, target_rulebase, layer_name, impo return -def insert_section_header_rule(target_rulebase, section_name, layer_name, import_id, src_rulebase_uid, section_header_uids, parent_uid): +def insert_section_header_rule(target_rulebase: Rulebase, section_name: str, layer_name: str, import_id: str, src_rulebase_uid: str, section_header_uids: set[str], parent_uid: str): # TODO: re-implement return diff --git a/roles/importer/files/importer/checkpointR8x/cp_service.py b/roles/importer/files/importer/checkpointR8x/cp_service.py index 4bf29e802f..5e8c4c4a03 100644 --- a/roles/importer/files/importer/checkpointR8x/cp_service.py +++ b/roles/importer/files/importer/checkpointR8x/cp_service.py @@ -1,12 +1,12 @@ import re +from typing import Any import cp_const from fwo_const import list_delimiter -from fwo_log import getFwoLogger from fwo_exceptions import FwoImporterErrorInconsistencies # collect_svcobjects writes svc info into global users dict -def collect_svc_objects(object_table, svc_objects): +def collect_svc_objects(object_table: dict[str, Any], svc_objects: list[dict[str, Any]]): if object_table['type'] in cp_const.svc_obj_table_names: typ = 'undef' if object_table['type'] in cp_const.group_svc_obj_types: @@ -28,7 +28,7 @@ def collect_svc_objects(object_table, svc_objects): }) -def _set_default_values(obj): +def _set_default_values(obj: dict[str, Any]): """ Set default values for color, comments, and domain_uid. """ @@ -41,7 +41,7 @@ def _set_default_values(obj): obj['domain_uid'] = get_obj_domain_uid(obj) -def _get_rpc_number(obj): +def _get_rpc_number(obj: dict[str, Any]) -> str | None: """ Extract RPC number from interface-uuid or program-number. Returns RPC number or None. @@ -53,7 +53,7 @@ def _get_rpc_number(obj): return None -def _get_session_timeout(obj): +def _get_session_timeout(obj: dict[str, Any]) -> str | None: """ Extract and stringify session timeout. Returns session timeout as string or None. @@ -63,7 +63,7 @@ def _get_session_timeout(obj): return None -def _get_member_references(obj): +def _get_member_references(obj: dict[str, Any]) -> str | None: """ Process members list and return concatenated member references. Returns member reference string or None. @@ -80,7 +80,7 @@ def _get_member_references(obj): return member_refs[:-1] if member_refs else None -def _get_protocol_number(obj): +def _get_protocol_number(obj: dict[str, Any]) -> int | None: """ Extract and validate protocol number from object. Returns validated protocol number or None. @@ -100,7 +100,7 @@ def _get_protocol_number(obj): return proto if proto is None or proto >= 0 else None -def collect_single_svc_object(obj): +def collect_single_svc_object(obj: dict[str, Any]) -> None: """ Collects a single service object and appends its details to the svc_objects list. Handles different types of service objects and normalizes port information. @@ -117,7 +117,7 @@ def collect_single_svc_object(obj): _set_default_values(obj) -def normalize_port(obj) -> tuple[str|None, str|None]: +def normalize_port(obj: dict[str, Any]) -> tuple[str|None, str|None]: """ Normalizes the port information in the given object. If the 'port' key exists, it processes the port value to handle ranges and special cases. @@ -158,7 +158,7 @@ def normalize_port(obj) -> tuple[str|None, str|None]: return port, port_end -def get_obj_domain_uid(obj): +def get_obj_domain_uid(obj: dict[str, Any]) -> str: """ Returns the domain UID for the given object. If the object has a 'domain' key with a 'uid', it returns that UID. @@ -171,14 +171,14 @@ def get_obj_domain_uid(obj): # return name of nw_objects element where obj_uid = uid -def resolve_svc_uid_to_name(uid, svc_objects): +def resolve_svc_uid_to_name(uid: str, svc_objects: list[dict[str, Any]]) -> str: for obj in svc_objects: if obj['svc_uid'] == uid: return obj['svc_name'] raise FwoImporterErrorInconsistencies('Service object member uid ' + uid + ' not found') -def add_member_names_for_svc_group(idx, svc_objects): +def add_member_names_for_svc_group(idx: int, svc_objects: list[dict[str, Any]]) -> None: member_names = '' group = svc_objects.pop(idx) @@ -192,8 +192,8 @@ def add_member_names_for_svc_group(idx, svc_objects): svc_objects.insert(idx, group) -def normalize_service_objects(full_config, config2import, import_id, debug_level=0): - svc_objects = [] +def normalize_service_objects(full_config: dict[str, Any], config2import: dict[str, Any], import_id: int, debug_level: int = 0) -> None: + svc_objects: list[dict[str, Any]] = [] for obj_dict in full_config['objects']: collect_svc_objects(obj_dict, svc_objects) for obj in svc_objects: diff --git a/roles/importer/files/importer/checkpointR8x/cp_user.py b/roles/importer/files/importer/checkpointR8x/cp_user.py index 3ee5b1ab8d..dd5c901d12 100644 --- a/roles/importer/files/importer/checkpointR8x/cp_user.py +++ b/roles/importer/files/importer/checkpointR8x/cp_user.py @@ -1,7 +1,8 @@ +from typing import Any from fwo_log import getFwoLogger import json -def collect_users_from_rule(rule, users): #, objDict): +def collect_users_from_rule(rule: dict[str, Any], users: dict[str, Any]): #, objDict): if 'rule-number' in rule: # standard rule logger = getFwoLogger() if 'type' in rule and rule['type'] != 'place-holder': @@ -45,7 +46,7 @@ def collect_users_from_rule(rule, users): #, objDict): # collect_users writes user info into global users dict -def collect_users_from_rulebase(rulebase, users): +def collect_users_from_rulebase(rulebase: dict[str, Any], users: dict[str, Any]) -> None: if 'rulebase_chunks' in rulebase: for chunk in rulebase['rulebase_chunks']: if 'rulebase' in chunk: @@ -53,24 +54,24 @@ def collect_users_from_rulebase(rulebase, users): collect_users_from_rule(rule, users) else: for rule in rulebase: - collect_users_from_rule(rule, users) + collect_users_from_rule(rule, users) # type: ignore #TODO refactor this # the following is only used within new python-only importer: -def parse_user_objects_from_rulebase(rulebase, users, import_id): +def parse_user_objects_from_rulebase(rulebase: dict[str, Any], users: dict[str, Any], import_id: str) -> None: collect_users_from_rulebase(rulebase, users) for user_name in users.keys(): # TODO: get user info via API - userUid = getUserUidFromCpApi(user_name) + _ = getUserUidFromCpApi(user_name) # finally add the import id users[user_name]['control_id'] = import_id - -def getUserUidFromCpApi (userName): +def getUserUidFromCpApi (userName: str) -> str: # show-object with UID # dummy implementation returning the name as uid return userName -def normalizeUsersLegacy(): + +def normalizeUsersLegacy() -> None: raise NotImplementedError diff --git a/roles/importer/files/importer/checkpointR8x/fwcommon.py b/roles/importer/files/importer/checkpointR8x/fwcommon.py index 0f3e96481f..a7ed41b9f3 100644 --- a/roles/importer/files/importer/checkpointR8x/fwcommon.py +++ b/roles/importer/files/importer/checkpointR8x/fwcommon.py @@ -12,16 +12,16 @@ from model_controllers.fwconfigmanagerlist_controller import FwConfigManagerListController from models.fwconfig_normalized import FwConfigNormalized from model_controllers.import_state_controller import ImportStateController -from fwo_base import ConfigAction, ConfFormat +from fwo_base import ConfigAction import fwo_const import fwo_globals from model_controllers.fwconfig_normalized_controller import FwConfigNormalizedController from fwo_exceptions import ImportInterruption, FwoImporterError -from models.management import Management from models.import_state import ImportState +from model_controllers.management_controller import ManagementController -def has_config_changed (full_config, importState: ImportState, force=False): +def has_config_changed (full_config: dict[str, Any], importState: ImportState, force: bool = False): if full_config != {}: # a config was passed in (read from file), so we assume that an import has to be done (simulating changes here) return 1 @@ -56,6 +56,9 @@ def get_config(config_in: FwConfigManagerListController, importState: ImportStat start_time_temp = int(time.time()) logger.debug ( "checkpointR8x/get_config/getting objects ...") + if config_in.native_config is None: + raise FwoImporterError("native_config is None in get_config") + result_get_objects = get_objects(config_in.native_config, importState) if result_get_objects>0: raise FwLoginFailed( "checkpointR8x/get_config/error while gettings objects") @@ -87,9 +90,10 @@ def initialize_native_config(config_in: FwConfigManagerListController, importSta """ manager_details_list = create_ordered_manager_list(importState) + if config_in.native_config is None: + raise FwoImporterError("native_config is None in initialize_native_config") config_in.native_config.update({'domains': []}) for managerDetails in manager_details_list: - config_in.native_config['domains'].append({ 'domain_name': managerDetails.DomainName, 'domain_uid': managerDetails.DomainUid, @@ -102,9 +106,9 @@ def initialize_native_config(config_in: FwConfigManagerListController, importSta 'gateways': []}) -def normalize_config(import_state, config_in: FwConfigManagerListController, parsing_config_only: bool, sid: str) -> FwConfigManagerListController: +def normalize_config(import_state: ImportStateController, config_in: FwConfigManagerListController, parsing_config_only: bool, sid: str) -> FwConfigManagerListController: - native_and_normalized_config_dict_list = [] + native_and_normalized_config_dict_list: list[dict[str, Any]] = [] if config_in.native_config is None: raise FwoImporterError("Did not get a native config to normalize.") @@ -115,7 +119,7 @@ def normalize_config(import_state, config_in: FwConfigManagerListController, par # in case of mds, first nativ config domain is global is_global_loop_iteration = False - native_config_global = {} + native_config_global: dict[str, Any] = {} normalized_config_global = {} if config_in.native_config['domains'][0]['is-super-manager']: native_config_global = config_in.native_config['domains'][0] @@ -156,8 +160,8 @@ def normalize_config(import_state, config_in: FwConfigManagerListController, par return config_in -def normalize_single_manager_config(nativeConfig: dict, native_config_global: dict, normalized_config_dict: dict, - normalized_config_global: dict, importState: ImportStateController, +def normalize_single_manager_config(nativeConfig: dict[str, Any], native_config_global: dict[str, Any], normalized_config_dict: dict[str, Any], + normalized_config_global: dict[str, Any], importState: ImportStateController, parsing_config_only: bool, sid: str, is_global_loop_iteration: bool): logger = getFwoLogger() cp_network.normalize_network_objects(nativeConfig, normalized_config_dict, importState.ImportId, mgm_id=importState.MgmDetails.Id) @@ -171,11 +175,11 @@ def normalize_single_manager_config(nativeConfig: dict, native_config_global: di logger.info("completed normalizing rulebases") -def get_rules(nativeConfig: dict, importState: ImportStateController) -> int: +def get_rules(nativeConfig: dict[str, Any], importState: ImportStateController) -> int: """ Main function to get rules. Divided into smaller sub-tasks for better readability and maintainability. """ - show_params_policy_structure = { + show_params_policy_structure: dict[str, Any] = { 'limit': importState.FwoConfig.ApiFetchSize, 'details-level': 'full' } @@ -192,14 +196,14 @@ def get_rules(nativeConfig: dict, importState: ImportStateController) -> int: ) sid: str = cp_getter.login(managerDetails) - policy_structure = [] + policy_structure: list[dict[str, Any]] = [] cp_getter.get_policy_structure( cpManagerApiBaseUrl, sid, show_params_policy_structure, managerDetails, policy_structure=policy_structure ) process_devices( managerDetails, policy_structure, globalAssignments, global_policy_structure, - globalDomain, globalSid, cpManagerApiBaseUrl, sid, nativeConfig['domains'][manager_index], + globalDomain, globalSid, cpManagerApiBaseUrl, sid, nativeConfig['domains'][manager_index], # globalSid should not be None but is when the first manager is not supermanager nativeConfig['domains'][0], importState ) nativeConfig['domains'][manager_index].update({'policies': policy_structure}) @@ -208,18 +212,18 @@ def get_rules(nativeConfig: dict, importState: ImportStateController) -> int: return 0 -def create_ordered_manager_list(importState): +def create_ordered_manager_list(importState: ImportStateController) -> list[ManagementController]: """ creates list of manager details, supermanager is first """ - manager_details_list = [deepcopy(importState.MgmDetails)] + manager_details_list: list[ManagementController] = [deepcopy(importState.MgmDetails)] if importState.MgmDetails.IsSuperManager: for subManager in importState.MgmDetails.SubManagers: - manager_details_list.append(deepcopy(subManager)) + manager_details_list.append(deepcopy(subManager)) # type: ignore TODO: why we are adding submanagers as ManagementController? return manager_details_list -def handle_super_manager(managerDetails, cpManagerApiBaseUrl, show_params_policy_structure):# -> tuple[list[Any], list[Any] | None, Any | Literal[''] | No...: +def handle_super_manager(managerDetails: ManagementController, cpManagerApiBaseUrl: str, show_params_policy_structure: dict[str, Any]) -> tuple[list[Any], None, Any | None, str]: # global assignments are fetched from mds domain mdsSid: str = cp_getter.login(managerDetails) @@ -246,9 +250,9 @@ def handle_super_manager(managerDetails, cpManagerApiBaseUrl, show_params_policy return global_assignments, global_policy_structure, global_domain, global_sid def process_devices( - managerDetails, policy_structure, globalAssignments, global_policy_structure, - globalDomain, globalSid, cpManagerApiBaseUrl, sid, nativeConfigDomain, - nativeConfigGlobalDomain, importState + managerDetails: ManagementController, policy_structure: list[dict[str, Any]], globalAssignments: list[Any] | None, global_policy_structure: list[dict[str, Any]] | None, + globalDomain: str | None, globalSid: str | None, cpManagerApiBaseUrl: str, sid: str, nativeConfigDomain: dict[str, Any], + nativeConfigGlobalDomain: dict[str, Any], importState: ImportStateController ) -> None: logger = getFwoLogger() for device in managerDetails.Devices: @@ -280,7 +284,7 @@ def process_devices( nativeConfigDomain['gateways'].append(deviceConfig) -def initialize_device_config(device) -> dict[str, Any]: +def initialize_device_config(device: dict[str, Any]) -> dict[str, Any]: if 'name' in device and 'uid' in device: return {'name': device['name'], 'uid': device['uid'], 'rulebase_links': []} else: @@ -288,12 +292,18 @@ def initialize_device_config(device) -> dict[str, Any]: def handle_global_rulebase_links( - managerDetails, import_state, deviceConfig, globalAssignments, global_policy_structure, globalDomain, - globalSid, orderedLayerUids, nativeConfigGlobalDomain, cpManagerApiBaseUrl): + managerDetails: ManagementController, import_state: ImportStateController, deviceConfig: dict[str, Any], globalAssignments: list[Any] | None, global_policy_structure: list[dict[str, Any]] | None, globalDomain: str | None, + globalSid: str | None, orderedLayerUids: list[str], nativeConfigGlobalDomain: dict[str, Any], cpManagerApiBaseUrl: str) -> int: """Searches for global access policy for current device policy, adds global ordered layers and defines global rulebase link """ + if globalAssignments is None: + raise FwoImporterError("Global assignments is None in handle_global_rulebase_links") + + if global_policy_structure is None: + raise FwoImporterError("Global policy structure is None in handle_global_rulebase_links") + logger = getFwoLogger() for globalAssignment in globalAssignments: if globalAssignment['dependent-domain']['uid'] == managerDetails.getDomainString(): @@ -311,9 +321,11 @@ def handle_global_rulebase_links( define_global_rulebase_link(deviceConfig, global_ordered_layer_uids, orderedLayerUids, nativeConfigGlobalDomain, global_policy_rulebases_uid_list) return global_ordered_layer_count + + return 0 -def define_global_rulebase_link(deviceConfig, globalOrderedLayerUids, orderedLayerUids, nativeConfigGlobalDomain, global_policy_rulebases_uid_list): +def define_global_rulebase_link(deviceConfig: dict[str, Any], globalOrderedLayerUids: list[str], orderedLayerUids: list[str], nativeConfigGlobalDomain: dict[str, Any], global_policy_rulebases_uid_list: list[str]): """Links initial and placeholder rule for global rulebases """ @@ -346,7 +358,7 @@ def define_global_rulebase_link(deviceConfig, globalOrderedLayerUids, orderedLay placeholder_link_index += 1 -def define_initial_rulebase(deviceConfig, orderedLayerUids, is_global): +def define_initial_rulebase(deviceConfig: dict[str, Any], orderedLayerUids: list[str], is_global: bool): deviceConfig['rulebase_links'].append({ 'from_rulebase_uid': None, 'from_rule_uid': None, @@ -358,7 +370,7 @@ def define_initial_rulebase(deviceConfig, orderedLayerUids, is_global): }) -def get_rules_params(importState): +def get_rules_params(importState: ImportStateController) -> dict[str, Any]: return { 'limit': importState.FwoConfig.ApiFetchSize, 'use-object-dictionary': cp_const.use_object_dictionary, @@ -367,10 +379,10 @@ def get_rules_params(importState): } -def handle_nat_rules(device, nativeConfigDomain, sid, importState): +def handle_nat_rules(device: dict[str, Any], nativeConfigDomain: dict[str, Any], sid: str, importState: ImportStateController): logger = getFwoLogger() if 'package_name' in device and device['package_name']: - show_params_rules = { + show_params_rules: dict[str, Any] = { 'limit': importState.FwoConfig.ApiFetchSize, 'use-object-dictionary': cp_const.use_object_dictionary, 'details-level': 'standard', @@ -390,9 +402,9 @@ def handle_nat_rules(device, nativeConfigDomain, sid, importState): nativeConfigDomain['nat_rulebases'].append({"nat_rule_chunks": []}) -def add_ordered_layers_to_native_config(orderedLayerUids, show_params_rules, - cpManagerApiBaseUrl, sid, nativeConfigDomain, - deviceConfig, is_global, global_ordered_layer_count): +def add_ordered_layers_to_native_config(orderedLayerUids: list[str], show_params_rules: dict[str, Any], + cpManagerApiBaseUrl: str, sid: str | None, nativeConfigDomain: dict[str, Any], + deviceConfig: dict[str, Any], is_global: bool, global_ordered_layer_count: int) -> list[str]: """Fetches ordered layers and links them """ orderedLayerIndex = 0 @@ -426,11 +438,11 @@ def add_ordered_layers_to_native_config(orderedLayerUids, show_params_rules, return policy_rulebases_uid_list -def get_ordered_layer_uids(policy_structure, deviceConfig, domain) -> list[str]: +def get_ordered_layer_uids(policy_structure: list[dict[str, Any]], deviceConfig: dict[str, Any], domain: str | None) -> list[str]: """Get UIDs of ordered layers for policy of device """ - orderedLayerUids = [] + orderedLayerUids: list[str] = [] for policy in policy_structure: foundTargetInPolciy = False for target in policy['targets']: @@ -442,7 +454,7 @@ def get_ordered_layer_uids(policy_structure, deviceConfig, domain) -> list[str]: return orderedLayerUids -def append_access_layer_uid(policy, domain, orderedLayerUids): +def append_access_layer_uid(policy: dict[str, Any], domain: str | None, orderedLayerUids: list[str]) -> None: for accessLayer in policy['access-layers']: if accessLayer['domain'] == domain or domain == '': orderedLayerUids.append(accessLayer['uid']) @@ -483,7 +495,7 @@ def get_objects(native_config_dict: dict[str,Any], importState: ImportStateContr return 0 -def get_objects_per_domain(manager_details, native_domain, obj_type_array, show_params_objs, is_stand_alone_manager=True): +def get_objects_per_domain(manager_details: ManagementController, native_domain: dict[str, Any], obj_type_array: list[str], show_params_objs: dict[str, Any], is_stand_alone_manager: bool=True) -> None: sid = cp_getter.login(manager_details) cp_url = manager_details.buildFwApiString() for obj_type in obj_type_array: @@ -494,7 +506,7 @@ def get_objects_per_domain(manager_details, native_domain, obj_type_array, show_ native_domain['objects'].append(object_table) -def remove_predefined_objects_for_domains(object_table): +def remove_predefined_objects_for_domains(object_table: dict[str, Any]) -> None: if 'chunks' in object_table and 'type' in object_table and \ object_table['type'] in cp_const.types_to_remove_globals_from: return @@ -507,7 +519,7 @@ def remove_predefined_objects_for_domains(object_table): chunk['objects'].remove(obj) -def get_objects_per_type(obj_type, show_params_objs, sid, cpManagerApiBaseUrl): +def get_objects_per_type(obj_type: str, show_params_objs: dict[str, Any], sid: str, cpManagerApiBaseUrl: str) -> dict[str, Any]: logger = getFwoLogger() if fwo_globals.shutdown_requested: @@ -516,7 +528,7 @@ def get_objects_per_type(obj_type, show_params_objs, sid, cpManagerApiBaseUrl): show_params_objs.update({'details-level': cp_const.details_level_group_objects}) else: show_params_objs.update({'details-level': cp_const.details_level_objects}) - object_table = { "type": obj_type, "chunks": [] } + object_table: dict[str, Any] = { "type": obj_type, "chunks": [] } current=0 total=current+1 show_cmd = 'show-' + obj_type @@ -541,7 +553,7 @@ def get_objects_per_type(obj_type, show_params_objs, sid, cpManagerApiBaseUrl): return object_table -def add_special_objects_to_global_domain(object_table, obj_type, sid, cp_api_url): +def add_special_objects_to_global_domain(object_table: dict[str, Any], obj_type: str, sid: str, cp_api_url: str) -> None: """Appends special objects Original, Any, None and Internet to global domain """ # getting Original (NAT) object (both for networks and services) diff --git a/roles/importer/files/importer/ciscoasa9/asa_maps.py b/roles/importer/files/importer/ciscoasa9/asa_maps.py index 54444ecddb..1bae0c3434 100644 --- a/roles/importer/files/importer/ciscoasa9/asa_maps.py +++ b/roles/importer/files/importer/ciscoasa9/asa_maps.py @@ -1,4 +1,7 @@ -name_to_port = { +from typing import Any + + +name_to_port: dict[str, dict[str, Any]] = { "aol": { "port": 5190, "protocols": ["TCP"], diff --git a/roles/importer/files/importer/ciscoasa9/asa_models.py b/roles/importer/files/importer/ciscoasa9/asa_models.py index 3fbe9c526a..9fb5488bb8 100644 --- a/roles/importer/files/importer/ciscoasa9/asa_models.py +++ b/roles/importer/files/importer/ciscoasa9/asa_models.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Union, Optional, Literal, Tuple +from typing import Literal from pydantic import BaseModel @@ -26,7 +26,7 @@ class Interface(BaseModel): security_level: int ip_address: str | None = None subnet_mask: str | None = None - additional_settings: List[str] + additional_settings: list[str] description: str | None = None class AsaNetworkObject(BaseModel): @@ -39,7 +39,7 @@ class AsaNetworkObject(BaseModel): class AsaNetworkObjectGroup(BaseModel): name: str - objects: List[AsaNetworkObjectGroupMember] + objects: list[AsaNetworkObjectGroupMember] description: str | None = None class AsaNetworkObjectGroupMember(BaseModel): @@ -51,21 +51,21 @@ class AsaServiceObject(BaseModel): name: str protocol: Literal["tcp", "udp", "ip", "tcp-udp", "icmp", "gre"] dst_port_eq: str | None = None - dst_port_range: Tuple[str, str] | None = None + dst_port_range: tuple[str, str] | None = None description: str | None = None class AsaServiceObjectGroup(BaseModel): name: str proto_mode: Literal["tcp", "udp", "tcp-udp"] | None ports_eq: dict[str, list[str]] # protocol -> list of ports - ports_range: dict[str, list[Tuple[str, str]]] # protocol -> list of (start_port, end_port) - nested_refs: List[str] - protocols: List[str] + ports_range: dict[str, list[tuple[str, str]]] # protocol -> list of (start_port, end_port) + nested_refs: list[str] + protocols: list[str] description: str | None class AsaProtocolGroup(BaseModel): name: str - protocols: List[str] + protocols: list[str] description: str | None = None class EndpointKind(BaseModel): @@ -85,7 +85,7 @@ class AccessListEntry(BaseModel): class AccessList(BaseModel): name: str - entries: List[AccessListEntry] + entries: list[AccessListEntry] class AccessGroupBinding(BaseModel): acl_name: str @@ -115,7 +115,7 @@ class MgmtAccessRule(BaseModel): class ClassMap(BaseModel): name: str - matches: List[str] = [] # e.g., ["default-inspection-traffic"] + matches: list[str] = [] # e.g., ["default-inspection-traffic"] class DnsInspectParameters(BaseModel): message_length_max_client: Literal["auto", "default"] | int | None = None @@ -128,13 +128,13 @@ class InspectionAction(BaseModel): class PolicyClass(BaseModel): class_name: str # e.g., "inspection_default" - inspections: List[InspectionAction] = [] + inspections: list[InspectionAction] = [] class PolicyMap(BaseModel): name: str # e.g., "global_policy" or "preset_dns_map" type_str: str | None = None # e.g., "inspect dns" for typed maps parameters_dns: DnsInspectParameters | None = None - classes: List[PolicyClass] = [] + classes: list[PolicyClass] = [] class ServicePolicyBinding(BaseModel): policy_map: str # e.g., "global_policy" @@ -145,20 +145,20 @@ class Config(BaseModel): asa_version: str hostname: str enable_password: AsaEnablePassword - service_modules: List[AsaServiceModule] - additional_settings: List[str] - interfaces: List[Interface] - objects: List[AsaNetworkObject] - object_groups: List[AsaNetworkObjectGroup] - service_objects: List[AsaServiceObject] = [] - service_object_groups: List[AsaServiceObjectGroup] = [] - access_lists: List[AccessList] = [] - access_group_bindings: List[AccessGroupBinding] = [] - nat_rules: List[NatRule] = [] - routes: List[Route] = [] - mgmt_access: List[MgmtAccessRule] = [] - names: List[Names] = [] - class_maps: List[ClassMap] = [] - policy_maps: List[PolicyMap] = [] - service_policies: List[ServicePolicyBinding] = [] - protocol_groups: List[AsaProtocolGroup] = [] \ No newline at end of file + service_modules: list[AsaServiceModule] + additional_settings: list[str] + interfaces: list[Interface] + objects: list[AsaNetworkObject] + object_groups: list[AsaNetworkObjectGroup] + service_objects: list[AsaServiceObject] = [] + service_object_groups: list[AsaServiceObjectGroup] = [] + access_lists: list[AccessList] = [] + access_group_bindings: list[AccessGroupBinding] = [] + nat_rules: list[NatRule] = [] + routes: list[Route] = [] + mgmt_access: list[MgmtAccessRule] = [] + names: list[Names] = [] + class_maps: list[ClassMap] = [] + policy_maps: list[PolicyMap] = [] + service_policies: list[ServicePolicyBinding] = [] + protocol_groups: list[AsaProtocolGroup] = [] \ No newline at end of file diff --git a/roles/importer/files/importer/ciscoasa9/asa_network.py b/roles/importer/files/importer/ciscoasa9/asa_network.py index 50a30ddc58..3ffc5303a7 100644 --- a/roles/importer/files/importer/ciscoasa9/asa_network.py +++ b/roles/importer/files/importer/ciscoasa9/asa_network.py @@ -6,7 +6,7 @@ inline ACL or group definitions. """ -from typing import Dict, List, Optional +from logging import Logger from netaddr import IPAddress, IPNetwork from ciscoasa9.asa_models import AsaNetworkObject, AsaNetworkObjectGroup, AsaNetworkObjectGroupMember, EndpointKind, Names from models.networkobject import NetworkObject @@ -14,7 +14,7 @@ import fwo_base -def create_network_host(name: str, ip_address: str, comment: Optional[str], ip_version: int) -> NetworkObject: +def create_network_host(name: str, ip_address: str, comment: str | None, ip_version: int) -> NetworkObject: """Create a normalized host network object. Args: @@ -41,7 +41,7 @@ def create_network_host(name: str, ip_address: str, comment: Optional[str], ip_v ) -def create_network_subnet(name: str, ip_address: str, subnet_mask: Optional[str], comment: Optional[str], ip_version: int) -> NetworkObject: +def create_network_subnet(name: str, ip_address: str, subnet_mask: str | None, comment: str | None, ip_version: int) -> NetworkObject: """Create a normalized network object. Args: @@ -77,7 +77,7 @@ def create_network_subnet(name: str, ip_address: str, subnet_mask: Optional[str] ) -def create_network_range(name: str, ip_start: str, ip_end: str, comment: Optional[str]) -> NetworkObject: +def create_network_range(name: str, ip_start: str, ip_end: str, comment: str | None) -> NetworkObject: """Create a normalized range network object. Args: @@ -100,7 +100,7 @@ def create_network_range(name: str, ip_start: str, ip_end: str, comment: Optiona ) -def create_network_group_object(name: str, member_refs: List[str], comment: Optional[str] = None) -> NetworkObject: +def create_network_group_object(name: str, member_refs: list[str], comment: str | None = None) -> NetworkObject: """Create a network group object. Args: @@ -141,7 +141,7 @@ def create_any_network_object() -> NetworkObject: ) -def normalize_names(names: List[Names]) -> Dict[str, NetworkObject]: +def normalize_names(names: list[Names]) -> dict[str, NetworkObject]: """Normalize 'names' entries (simple IP-to-name mappings). Args: @@ -150,7 +150,7 @@ def normalize_names(names: List[Names]) -> Dict[str, NetworkObject]: Returns: Dictionary of normalized network objects keyed by obj_uid """ - network_objects = {} + network_objects: dict[str, NetworkObject] = {} for name in names: obj = create_network_host(name.name, name.ip_address, name.description, ip_version=4) @@ -159,7 +159,7 @@ def normalize_names(names: List[Names]) -> Dict[str, NetworkObject]: return network_objects -def normalize_network_objects(network_objects_list: List[AsaNetworkObject]) -> Dict[str, NetworkObject]: +def normalize_network_objects(network_objects_list: list[AsaNetworkObject]) -> dict[str, NetworkObject]: """Normalize network objects from ASA configuration. Args: @@ -168,7 +168,7 @@ def normalize_network_objects(network_objects_list: List[AsaNetworkObject]) -> D Returns: Dictionary of normalized network objects keyed by obj_uid """ - network_objects = {} + network_objects: dict[str, NetworkObject] = {} for obj in network_objects_list: if obj.fqdn is not None: @@ -190,9 +190,9 @@ def normalize_network_objects(network_objects_list: List[AsaNetworkObject]) -> D return network_objects -def normalize_network_object_groups(object_groups: List[AsaNetworkObjectGroup], - network_objects: Dict[str, NetworkObject], - logger) -> Dict[str, NetworkObject]: +def normalize_network_object_groups(object_groups: list[AsaNetworkObjectGroup], + network_objects: dict[str, NetworkObject], + logger: Logger) -> dict[str, NetworkObject]: """Normalize network object groups from ASA configuration. Args: @@ -204,7 +204,7 @@ def normalize_network_object_groups(object_groups: List[AsaNetworkObjectGroup], Updated network objects dictionary including groups """ for group in object_groups: - member_refs = [] + member_refs: list[str] = [] for member in group.objects: try: @@ -223,7 +223,7 @@ def normalize_network_object_groups(object_groups: List[AsaNetworkObjectGroup], return network_objects -def _get_network_group_member_host(member: AsaNetworkObjectGroupMember) -> NetworkObject: +def get_network_group_member_host(member: AsaNetworkObjectGroupMember) -> NetworkObject: """Create a host network object for a network object group member. Args: @@ -274,7 +274,7 @@ def create_network_group_member(ref: str, member: AsaNetworkObjectGroupMember) - raise ValueError(f"Unsupported member kind '{member.kind}' in network object group.") -def get_network_group_member(member: AsaNetworkObjectGroupMember, network_objects: Dict[str, NetworkObject]) -> NetworkObject: +def get_network_group_member(member: AsaNetworkObjectGroupMember, network_objects: dict[str, NetworkObject]) -> NetworkObject: """Get network object for a network object group member reference. If it does not exist, create it. Args: @@ -295,7 +295,7 @@ def get_network_group_member(member: AsaNetworkObjectGroupMember, network_object return network_object -def get_network_rule_endpoint(endpoint: EndpointKind, network_objects: Dict[str, NetworkObject]) -> NetworkObject: +def get_network_rule_endpoint(endpoint: EndpointKind, network_objects: dict[str, NetworkObject]) -> NetworkObject: """Get network object for a rule endpoint. If it does not exist, create it. Args: diff --git a/roles/importer/files/importer/ciscoasa9/asa_normalize.py b/roles/importer/files/importer/ciscoasa9/asa_normalize.py index 23f24ab40a..bf45b76ac0 100644 --- a/roles/importer/files/importer/ciscoasa9/asa_normalize.py +++ b/roles/importer/files/importer/ciscoasa9/asa_normalize.py @@ -5,12 +5,18 @@ format used by the firewall orchestrator. """ +from logging import Logger from fwo_log import getFwoLogger from models.fwconfig_normalized import FwConfigNormalized from ciscoasa9.asa_models import Config from fwo_enums import ConfigAction from models.gateway import Gateway from models.rulebase_link import RulebaseLinkUidBased +from ciscoasa9.asa_rule import build_rulebases_from_access_lists +from models.networkobject import NetworkObject +from models.serviceobject import ServiceObject +from model_controllers.fwconfigmanagerlist_controller import FwConfigManagerListController +from model_controllers.import_state_controller import ImportStateController # Import the new modular functions from ciscoasa9.asa_network import ( @@ -23,10 +29,12 @@ create_protocol_any_service_objects, normalize_service_object_groups ) -from ciscoasa9.asa_rule import build_rulebases_from_access_lists -def normalize_all_network_objects(native_config: Config, logger) -> dict: + + + +def normalize_all_network_objects(native_config: Config, logger: Logger) -> dict[str, NetworkObject]: """ Normalize all network objects from the native ASA configuration. @@ -54,7 +62,7 @@ def normalize_all_network_objects(native_config: Config, logger) -> dict: return network_objects -def normalize_all_service_objects(native_config: Config) -> dict: +def normalize_all_service_objects(native_config: Config) -> dict[str, ServiceObject]: """ Normalize all service objects from the native ASA configuration. @@ -81,7 +89,7 @@ def normalize_all_service_objects(native_config: Config) -> dict: return service_objects -def normalize_config(config_in, import_state): +def normalize_config(config_in: FwConfigManagerListController, import_state: ImportStateController) -> FwConfigManagerListController: """ Normalize the ASA configuration into a structured format for the database. @@ -125,7 +133,7 @@ def normalize_config(config_in, import_state): ) # Step 4: Create rulebase links (ordered chain of rulebases) - rulebase_links = [] + rulebase_links: list[RulebaseLinkUidBased] = [] if len(rulebases) > 0: # First rulebase is the initial entry point rulebase_links.append(RulebaseLinkUidBased( diff --git a/roles/importer/files/importer/ciscoasa9/asa_parser.py b/roles/importer/files/importer/ciscoasa9/asa_parser.py index 41ad4972ec..57fcbcf655 100644 --- a/roles/importer/files/importer/ciscoasa9/asa_parser.py +++ b/roles/importer/files/importer/ciscoasa9/asa_parser.py @@ -1,26 +1,26 @@ import json import re from pathlib import Path -from typing import List, Optional +from typing import Callable from ciscoasa9.asa_models import AccessGroupBinding, AccessList, AccessListEntry, AsaEnablePassword,\ AsaNetworkObject, AsaNetworkObjectGroup, AsaProtocolGroup, AsaServiceModule, AsaServiceObject, AsaServiceObjectGroup,\ - ClassMap, Config, DnsInspectParameters, EndpointKind, InspectionAction, Interface, MgmtAccessRule,\ - Names, NatRule, PolicyClass, PolicyMap, Route, ServicePolicyBinding -from ciscoasa9.asa_parser_functions import _clean_lines, _consume_block, _parse_class_map_block, \ - _parse_dns_inspect_policy_map_block, _parse_icmp_object_group_block, _parse_interface_block, _parse_network_object_block, \ - _parse_network_object_group_block, _parse_policy_map_block, _parse_service_object_block, \ - _parse_service_object_group_block, _parse_endpoint, _parse_protocol_object_group_block, \ - _parse_access_list_entry + ClassMap, Config, Interface, MgmtAccessRule,\ + Names, NatRule, PolicyMap, Route, ServicePolicyBinding +from ciscoasa9.asa_parser_functions import clean_lines, consume_block, parse_class_map_block, \ + parse_dns_inspect_policy_map_block, parse_icmp_object_group_block, parse_interface_block, parse_network_object_block, \ + parse_network_object_group_block, parse_policy_map_block, parse_service_object_block, \ + parse_service_object_group_block, parse_protocol_object_group_block, \ + parse_access_list_entry def parse_asa_config(raw_config: str) -> Config: - lines = _clean_lines(raw_config) + lines = clean_lines(raw_config) # Initialize state state = _ParserState() # Handler registry: (pattern, handler_function) - handlers = [ + handlers: list[tuple[re.Pattern[str], Callable[[re.Match[str], str, list[str], int, _ParserState], int]]] = [ (re.compile(r"^ASA Version\s+(\S+)$", re.I), _handle_asa_version), (re.compile(r"^hostname\s+(\S+)$", re.I), _handle_hostname), (re.compile(r"^enable password\s+(\S+)\s+(\S+)$", re.I), _handle_enable_password), @@ -71,42 +71,42 @@ class _ParserState: def __init__(self): self.asa_version = "" self.hostname = "" - self.enable_password: Optional[AsaEnablePassword] = None - self.service_modules: List[AsaServiceModule] = [] - self.names: List[Names] = [] - self.interfaces: List[Interface] = [] - self.net_objects: List[AsaNetworkObject] = [] - self.net_obj_groups: List[AsaNetworkObjectGroup] = [] - self.svc_objects: List[AsaServiceObject] = [] - self.svc_obj_groups: List[AsaServiceObjectGroup] = [] - self.access_lists_map: dict[str, List[AccessListEntry]] = {} - self.access_groups: List[AccessGroupBinding] = [] - self.nat_rules: List[NatRule] = [] - self.routes: List[Route] = [] - self.mgmt_access: List[MgmtAccessRule] = [] - self.additional_settings: List[str] = [] - self.class_maps: List[ClassMap] = [] + self.enable_password: AsaEnablePassword | None = None + self.service_modules: list[AsaServiceModule] = [] + self.names: list[Names] = [] + self.interfaces: list[Interface] = [] + self.net_objects: list[AsaNetworkObject] = [] + self.net_obj_groups: list[AsaNetworkObjectGroup] = [] + self.svc_objects: list[AsaServiceObject] = [] + self.svc_obj_groups: list[AsaServiceObjectGroup] = [] + self.access_lists_map: dict[str, list[AccessListEntry]] = {} + self.access_groups: list[AccessGroupBinding] = [] + self.nat_rules: list[NatRule] = [] + self.routes: list[Route] = [] + self.mgmt_access: list[MgmtAccessRule] = [] + self.additional_settings: list[str] = [] + self.class_maps: list[ClassMap] = [] self.policy_maps: dict[str, PolicyMap] = {} - self.service_policies: List[ServicePolicyBinding] = [] - self.protocol_groups: List[AsaProtocolGroup] = [] + self.service_policies: list[ServicePolicyBinding] = [] + self.protocol_groups: list[AsaProtocolGroup] = [] -def _handle_asa_version(match, line, lines, i, state): +def _handle_asa_version(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: state.asa_version = match.group(1).strip() return i + 1 -def _handle_hostname(match, line, lines, i, state): +def _handle_hostname(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: state.hostname = match.group(1) return i + 1 -def _handle_enable_password(match, line, lines, i, state): +def _handle_enable_password(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: state.enable_password = AsaEnablePassword(password=match.group(1), encryption_function=match.group(2)) return i + 1 -def _handle_service_module_timeout(match, line, lines, i, state): +def _handle_service_module_timeout(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: name = match.group(1) timeout = int(match.group(2)) keepalive_counter = _find_keepalive_counter(lines, i, name) @@ -114,7 +114,7 @@ def _handle_service_module_timeout(match, line, lines, i, state): return i + 1 -def _find_keepalive_counter(lines, i, name): +def _find_keepalive_counter(lines: list[str], i: int, name: str) -> int: for j in range(i + 1, min(i + 5, len(lines))): m = re.match(rf"^service-module\s+{re.escape(name)}\s+keepalive-counter\s+(\d+)$", lines[j].strip(), re.I) if m: @@ -122,26 +122,26 @@ def _find_keepalive_counter(lines, i, name): return 0 -def _handle_service_module_counter(match, line, lines, i, state): +def _handle_service_module_counter(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: return i + 1 -def _handle_name(match, line, lines, i, state): +def _handle_name(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: ip, alias = match.group(1), match.group(2) desc = line[match.end():].strip() or None state.names.append(Names(name=alias, ip_address=ip, description=desc)) return i + 1 -def _handle_interface_block(match, line, lines, i, state): - block, new_i = _consume_block(lines, i) - state.interfaces.append(_parse_interface_block(block)) +def _handle_interface_block(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: + block, new_i = consume_block(lines, i) + state.interfaces.append(parse_interface_block(block)) return new_i -def _handle_network_object_block(match, line, lines, i, state): - block, new_i = _consume_block(lines, i) - net_obj, pending_nat = _parse_network_object_block(block) +def _handle_network_object_block(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: + block, new_i = consume_block(lines, i) + net_obj, pending_nat = parse_network_object_block(block) if net_obj: state.net_objects.append(net_obj) if pending_nat: @@ -149,48 +149,48 @@ def _handle_network_object_block(match, line, lines, i, state): return new_i -def _handle_network_object_group_block(match, line, lines, i, state): - block, new_i = _consume_block(lines, i) - state.net_obj_groups.append(_parse_network_object_group_block(block)) +def _handle_network_object_group_block(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: + block, new_i = consume_block(lines, i) + state.net_obj_groups.append(parse_network_object_group_block(block)) return new_i -def _handle_service_object_block(match, line, lines, i, state): - block, new_i = _consume_block(lines, i) - svc_obj = _parse_service_object_block(block) +def _handle_service_object_block(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: + block, new_i = consume_block(lines, i) + svc_obj = parse_service_object_block(block) if svc_obj: state.svc_objects.append(svc_obj) return new_i -def _handle_service_object_group(match, line, lines, i, state): - block, new_i = _consume_block(lines, i) - state.svc_obj_groups.append(_parse_service_object_group_block(block)) +def _handle_service_object_group(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: + block, new_i = consume_block(lines, i) + state.svc_obj_groups.append(parse_service_object_group_block(block)) return new_i -def _handle_icmp_object_group_block(match, line, lines, i, state): - block, new_i = _consume_block(lines, i) - state.svc_obj_groups.append(_parse_icmp_object_group_block(block)) +def _handle_icmp_object_group_block(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: + block, new_i = consume_block(lines, i) + state.svc_obj_groups.append(parse_icmp_object_group_block(block)) return new_i -def _handle_protocol_object_group_block(match, line, lines, i, state): - block, new_i = _consume_block(lines, i) - state.protocol_groups.append(_parse_protocol_object_group_block(block)) +def _handle_protocol_object_group_block(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: + block, new_i = consume_block(lines, i) + state.protocol_groups.append(parse_protocol_object_group_block(block)) return new_i -def _handle_access_list_entry(match, line, lines, i, state): +def _handle_access_list_entry(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: try: - entry = _parse_access_list_entry(line, state.protocol_groups, state.svc_objects, state.svc_obj_groups) + entry = parse_access_list_entry(line, state.protocol_groups, state.svc_objects, state.svc_obj_groups) state.access_lists_map.setdefault(entry.acl_name, []).append(entry) except Exception as e: print(f"Warning: Failed to parse access-list line: {line}. Error: {e}") return i + 1 -def _handle_access_group(match, line, lines, i, state): +def _handle_access_group(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: direction = match.group(2) if direction not in ("in", "out"): raise ValueError(f"Invalid direction value: {direction}") @@ -198,7 +198,7 @@ def _handle_access_group(match, line, lines, i, state): return i + 1 -def _handle_route(match, line, lines, i, state): +def _handle_route(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: state.routes.append(Route( interface=match.group(1), destination=match.group(2), @@ -209,7 +209,7 @@ def _handle_route(match, line, lines, i, state): return i + 1 -def _handle_mgmt_access(match, line, lines, i, state): +def _handle_mgmt_access(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: protocol_str = match.group(1).lower() if protocol_str not in ("http", "ssh", "telnet"): raise ValueError(f"Invalid protocol for MgmtAccessRule: {protocol_str}") @@ -217,29 +217,29 @@ def _handle_mgmt_access(match, line, lines, i, state): return i + 1 -def _handle_class_map_block(match, line, lines, i, state): - block, new_i = _consume_block(lines, i) - state.class_maps.append(_parse_class_map_block(block)) +def _handle_class_map_block(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: + block, new_i = consume_block(lines, i) + state.class_maps.append(parse_class_map_block(block)) return new_i -def _handle_dns_inspect_policy_map_block(match, line, lines, i, state): - block, new_i = _consume_block(lines, i) +def _handle_dns_inspect_policy_map_block(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: + block, new_i = consume_block(lines, i) pm_name = match.group(1) - pm = _parse_dns_inspect_policy_map_block(block, pm_name) + pm = parse_dns_inspect_policy_map_block(block, pm_name) state.policy_maps[pm_name] = pm return new_i -def _handle_policy_map_block(match, line, lines, i, state): - block, new_i = _consume_block(lines, i) +def _handle_policy_map_block(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: + block, new_i = consume_block(lines, i) pm_name = match.group(1) - pm = _parse_policy_map_block(block, pm_name) + pm = parse_policy_map_block(block, pm_name) state.policy_maps[pm_name] = pm return new_i -def _handle_service_policy(match, line, lines, i, state): +def _handle_service_policy(match: re.Match[str], line: str, lines: list[str], i: int, state: _ParserState) -> int: pm_name = match.group(1) scope_part = match.group(2).lower() if scope_part == "global": @@ -250,7 +250,7 @@ def _handle_service_policy(match, line, lines, i, state): return i + 1 -def _handle_additional_settings(line, state): +def _handle_additional_settings(line: str, state: _ParserState) -> None: interesting_prefixes = ( "ftp mode", "same-security-traffic", "dynamic-access-policy-record", "service-policy", "user-identity", "aaa ", "icmp ", "arp ", "ssh version", diff --git a/roles/importer/files/importer/ciscoasa9/asa_parser_functions.py b/roles/importer/files/importer/ciscoasa9/asa_parser_functions.py index b952fca45c..4131e5edad 100644 --- a/roles/importer/files/importer/ciscoasa9/asa_parser_functions.py +++ b/roles/importer/files/importer/ciscoasa9/asa_parser_functions.py @@ -1,16 +1,13 @@ -import json import re -from pathlib import Path -from typing import Dict, List, Optional, Union, Literal, Tuple -from ciscoasa9.asa_models import AccessGroupBinding, AccessList, AccessListEntry, AsaEnablePassword,\ - AsaNetworkObject, AsaNetworkObjectGroup, AsaNetworkObjectGroupMember, AsaServiceModule, AsaServiceObject, AsaServiceObjectGroup,\ - ClassMap, Config, DnsInspectParameters, EndpointKind, InspectionAction, Interface, MgmtAccessRule,\ - Names, NatRule, PolicyClass, PolicyMap, Route, ServicePolicyBinding, AsaProtocolGroup +from ciscoasa9.asa_models import AccessListEntry,\ + AsaNetworkObject, AsaNetworkObjectGroup, AsaNetworkObjectGroupMember, AsaServiceObject, AsaServiceObjectGroup,\ + ClassMap, DnsInspectParameters, EndpointKind, InspectionAction, Interface,\ + NatRule, PolicyClass, PolicyMap, AsaProtocolGroup from fwo_log import getFwoLogger -def _clean_lines(text: str) -> List[str]: - lines = [] +def clean_lines(text: str) -> list[str]: + lines: list[str] = [] for raw in text.splitlines(): line = raw.rstrip() # Skip leading metadata/comment lines starting with ':' (as in "show run") @@ -19,7 +16,7 @@ def _clean_lines(text: str) -> List[str]: lines.append(line) return lines -def _consume_block(lines: List[str], start_idx: int) -> Tuple[List[str], int]: +def consume_block(lines: list[str], start_idx: int) -> tuple[list[str], int]: """ Consume a block that starts at start_idx (matching start_re) and continues until next top-level directive (blank line or line not starting with space) @@ -40,7 +37,7 @@ def _consume_block(lines: List[str], start_idx: int) -> Tuple[List[str], int]: break return block, i -def _parse_endpoint(tokens: List[str]) -> Tuple[EndpointKind, int]: +def parse_endpoint(tokens: list[str]) -> tuple[EndpointKind, int]: """ Parse an ACL endpoint from tokens; returns (EndpointKind, tokens_consumed). Supported: @@ -70,12 +67,12 @@ def _parse_endpoint(tokens: List[str]) -> Tuple[EndpointKind, int]: return EndpointKind(kind="any", value="any"), 1 -def _find_description(blocks: List[str]) -> Optional[str]: +def _find_description(blocks: list[str]) -> str | None: """Helper to find description line in a block.""" return _find_line_with_prefix(list(blocks), "description ") -def _find_line_with_prefix(block: List[str], prefix: str, only_first: bool = False) -> Optional[str]: +def _find_line_with_prefix(block: list[str], prefix: str, only_first: bool = False) -> str | None: """Helper to find a single value in an interface block by prefix.""" v = None for b in list(block): @@ -89,7 +86,7 @@ def _find_line_with_prefix(block: List[str], prefix: str, only_first: bool = Fal return v -def _parse_interface_block_find_ip_address(block: List[str], prefix: str) -> Tuple[Optional[str], Optional[str]]: +def _parse_interface_block_find_ip_address(block: list[str], prefix: str) -> tuple[str | None, str | None]: """Helper to find IP address and mask in an interface block.""" ip = None mask = None @@ -105,7 +102,7 @@ def _parse_interface_block_find_ip_address(block: List[str], prefix: str) -> Tup return ip, mask -def _parse_interface_block(block: List[str]) -> Interface: +def parse_interface_block(block: list[str]) -> Interface: """Parse an interface block and return an Interface object.""" if_name = block[0].split()[1] blocks = list(block)[1:] @@ -130,12 +127,12 @@ def _parse_interface_block(block: List[str]) -> Interface: def _create_network_object_from_parts( name: str, - host: Optional[str], - subnet: Optional[str], - mask: Optional[str], - ip_range: Optional[Tuple[str, str]], - fqdn: Optional[str], - description: Optional[str] + host: str | None, + subnet: str | None, + mask: str | None, + ip_range: tuple[str, str] | None, + fqdn: str | None, + description: str | None ) -> AsaNetworkObject|None: """Helper to create AsaNetworkObject from parts.""" if host and not subnet: @@ -152,7 +149,7 @@ def _create_network_object_from_parts( return None -def _parse_network_object_block(block: List[str]) -> Tuple[Optional[AsaNetworkObject], Optional[NatRule]]: +def parse_network_object_block(block: list[str]) -> tuple[AsaNetworkObject | None, NatRule | None]: """Parse an object network block. Returns (network_object, nat_rule).""" obj_name = block[0].split()[2] host = None @@ -209,11 +206,11 @@ def _parse_network_object_block(block: List[str]) -> Tuple[Optional[AsaNetworkOb return net_obj, pending_nat -def _parse_network_object_group_block(block: List[str]) -> AsaNetworkObjectGroup: +def parse_network_object_group_block(block: list[str]) -> AsaNetworkObjectGroup: """Parse an object-group network block.""" grp_name = block[0].split()[2] desc = _find_description(block[1:]) - members: List[AsaNetworkObjectGroupMember] = [] + members: list[AsaNetworkObjectGroupMember] = [] for b in block[1:]: s = b.strip() @@ -247,7 +244,7 @@ def _parse_network_object_group_block(block: List[str]) -> AsaNetworkObjectGroup return AsaNetworkObjectGroup(name=grp_name, objects=members, description=desc) -def _parse_service_object_block(block: List[str]) -> AsaServiceObject | None: +def parse_service_object_block(block: list[str]) -> AsaServiceObject | None: """Parse an object service block.""" name = block[0].split()[2] protocol = None @@ -284,20 +281,20 @@ def _parse_service_object_block(block: List[str]) -> AsaServiceObject | None: def _convert_ports_to_dicts( - ports_eq: List[Tuple[str, str]], - ports_range: List[Tuple[str, Tuple[str, str]]] -) -> Tuple[Dict[str, List[str]], Dict[str, List[Tuple[str, str]]]]: + ports_eq: list[tuple[str, str]], + ports_range: list[tuple[str, tuple[str, str]]] +) -> tuple[dict[str, list[str]], dict[str, list[tuple[str, str]]]]: """ Convert port lists to dictionaries grouped by protocol. Returns (ports_eq_dict, ports_range_dict). """ - ports_eq_dict: Dict[str, List[str]] = {} + ports_eq_dict: dict[str, list[str]] = {} for proto, port in ports_eq: if proto not in ports_eq_dict: ports_eq_dict[proto] = [] ports_eq_dict[proto].append(port) - ports_range_dict: Dict[str, List[Tuple[str, str]]] = {} + ports_range_dict: dict[str, list[tuple[str, str]]] = {} for proto, prange in ports_range: if proto not in ports_range_dict: ports_range_dict[proto] = [] @@ -305,10 +302,10 @@ def _convert_ports_to_dicts( return ports_eq_dict, ports_range_dict -def _consume_port_objects(service_group_block: List[str], proto_mode: str) -> Tuple[List[Tuple[str, str]], List[Tuple[str, Tuple[str, str]]]]: +def _consume_port_objects(service_group_block: list[str], proto_mode: str) -> tuple[list[tuple[str, str]], list[tuple[str, tuple[str, str]]]]: """Helper to consume port-object lines from a service object group block.""" - ports_eq: List[Tuple[str, str]] = [] - ports_range: List[Tuple[str, Tuple[str, str]]] = [] + ports_eq: list[tuple[str, str]] = [] + ports_range: list[tuple[str, tuple[str, str]]] = [] for b in list(service_group_block): s = b.strip() @@ -325,11 +322,11 @@ def _consume_port_objects(service_group_block: List[str], proto_mode: str) -> Tu return ports_eq, ports_range -def _consume_service_definitions(service_group_block: List[str]) -> Tuple[List[Tuple[str, str]], List[Tuple[str, Tuple[str, str]]], List[str]]: +def _consume_service_definitions(service_group_block: list[str]) -> tuple[list[tuple[str, str]], list[tuple[str, tuple[str, str]]], list[str]]: """Helper to consume service-object definitions from a service object group block.""" - ports_eq: List[Tuple[str, str]] = [] - ports_range: List[Tuple[str, Tuple[str, str]]] = [] - protocols: List[str] = [] # list of fully enabled protocols + ports_eq: list[tuple[str, str]] = [] + ports_range: list[tuple[str, tuple[str, str]]] = [] + protocols: list[str] = [] # list of fully enabled protocols for b in list(service_group_block): s = b.strip() @@ -348,9 +345,9 @@ def _consume_service_definitions(service_group_block: List[str]) -> Tuple[List[T return ports_eq, ports_range, protocols -def _consume_service_references(service_group_block: List[str]) -> List[str]: +def _consume_service_references(service_group_block: list[str]) -> list[str]: """Helper to consume service-object and group-object lines from a service object group block.""" - nested_refs: List[str] = [] + nested_refs: list[str] = [] for b in service_group_block: s = b.strip() @@ -367,7 +364,7 @@ def _consume_service_references(service_group_block: List[str]) -> List[str]: return nested_refs -def _parse_service_object_group_block(block: List[str]) -> AsaServiceObjectGroup: +def parse_service_object_group_block(block: list[str]) -> AsaServiceObjectGroup: """Parse an object-group service block.""" hdr = block[0].split() name = hdr[2] @@ -384,10 +381,10 @@ def _parse_service_object_group_block(block: List[str]) -> AsaServiceObjectGroup desc = _find_description(block[1:]) - ports_eq: List[Tuple[str, str]] = [] - ports_range: List[Tuple[str, Tuple[str, str]]] = [] - nested_refs: List[str] = [] - protocols: List[str] = [] + ports_eq: list[tuple[str, str]] = [] + ports_range: list[tuple[str, tuple[str, str]]] = [] + nested_refs: list[str] = [] + protocols: list[str] = [] if proto_mode: ports_eq, ports_range = _consume_port_objects(block[1:], proto_mode) @@ -409,10 +406,10 @@ def _parse_service_object_group_block(block: List[str]) -> AsaServiceObjectGroup ) -def _parse_class_map_block(block: List[str]) -> ClassMap: +def parse_class_map_block(block: list[str]) -> ClassMap: """Parse a class-map block.""" name = block[0].split()[1] - matches: List[str] = [] + matches: list[str] = [] for b in block[1:]: s = b.strip() @@ -423,7 +420,7 @@ def _parse_class_map_block(block: List[str]) -> ClassMap: return ClassMap(name=name, matches=matches) -def _parse_dns_parameters_block(block: List[str], start_idx: int) -> Tuple[DnsInspectParameters, int]: +def _parse_dns_parameters_block(block: list[str], start_idx: int) -> tuple[DnsInspectParameters, int]: """ Parse a 'parameters' sub-block within a DNS inspect policy-map. Returns (DnsInspectParameters, next_index). @@ -447,7 +444,7 @@ def _parse_dns_parameters_block(block: List[str], start_idx: int) -> Tuple[DnsIn return params, k -def _parse_dns_inspect_policy_map_block(block: List[str], pm_name: str) -> PolicyMap: +def parse_dns_inspect_policy_map_block(block: list[str], pm_name: str) -> PolicyMap: """Parse a policy-map type inspect dns block.""" pm = PolicyMap(name=pm_name, type_str="inspect dns") params = DnsInspectParameters() @@ -464,7 +461,7 @@ def _parse_dns_inspect_policy_map_block(block: List[str], pm_name: str) -> Polic return pm -def _parse_policy_class_block(block: List[str], start_idx: int) -> Tuple[Optional[PolicyClass], int]: +def _parse_policy_class_block(block: list[str], start_idx: int) -> tuple[PolicyClass | None, int]: """ Parse a 'class ' sub-block starting at start_idx. Returns (PolicyClass or None, next_index). @@ -478,7 +475,7 @@ def _parse_policy_class_block(block: List[str], start_idx: int) -> Tuple[Optiona return None, start_idx + 1 class_name = mc.group(1) - inspections: List[InspectionAction] = [] + inspections: list[InspectionAction] = [] idx = start_idx + 1 # collect lines under this class (1 indent) while idx < len(block) and block[idx].startswith(" "): @@ -497,7 +494,7 @@ def _parse_policy_class_block(block: List[str], start_idx: int) -> Tuple[Optiona return PolicyClass(class_name=class_name, inspections=inspections), idx -def _parse_policy_map_block(block: List[str], pm_name: str) -> PolicyMap: +def parse_policy_map_block(block: list[str], pm_name: str) -> PolicyMap: """Parse a regular policy-map block.""" pm = PolicyMap(name=pm_name) @@ -511,10 +508,10 @@ def _parse_policy_map_block(block: List[str], pm_name: str) -> PolicyMap: return pm -def _parse_access_list_entry_protocol(parts: List[str], protocol_groups: List[AsaProtocolGroup], svc_objects: List[AsaServiceObject], svc_obj_groups: List[AsaServiceObjectGroup]) -> Tuple[EndpointKind, List[str]]: +def _parse_access_list_entry_protocol(parts: list[str], protocol_groups: list[AsaProtocolGroup], svc_objects: list[AsaServiceObject], svc_obj_groups: list[AsaServiceObjectGroup]) -> tuple[EndpointKind, list[str]]: """ Parse the protocol part of an access-list entry. - Returns (protocol EndpointKind, remaining tokens List[str]). + Returns (protocol EndpointKind, remaining tokens list[str]). """ # Determine protocol protocol = None @@ -542,10 +539,10 @@ def _parse_access_list_entry_protocol(parts: List[str], protocol_groups: List[As return protocol, tokens -def _parse_access_list_entry_dest_port(tokens: List[str], protocol: EndpointKind) -> Tuple[EndpointKind, List[str]]: +def _parse_access_list_entry_dest_port(tokens: list[str], protocol: EndpointKind) -> tuple[EndpointKind, list[str]]: """ Parse the destination port part of an access-list entry. - Returns (dst_port EndpointKind, remaining tokens List[str]). + Returns (dst_port EndpointKind, remaining tokens list[str]). """ dst_port = EndpointKind(kind="any", value="any") # Default value if len(tokens) >= 2 and tokens[0] == "eq": @@ -570,7 +567,7 @@ def _parse_access_list_entry_dest_port(tokens: List[str], protocol: EndpointKind return dst_port, tokens -def _parse_access_list_entry(line: str, protocol_groups: List[AsaProtocolGroup], svc_objects: List[AsaServiceObject], svc_obj_groups: List[AsaServiceObjectGroup]) -> AccessListEntry: +def parse_access_list_entry(line: str, protocol_groups: list[AsaProtocolGroup], svc_objects: list[AsaServiceObject], svc_obj_groups: list[AsaServiceObjectGroup]) -> AccessListEntry: """ Parse an access-list entry line and return an AccessListEntry object. Handles various formats as specified in the requirements. @@ -584,11 +581,11 @@ def _parse_access_list_entry(line: str, protocol_groups: List[AsaProtocolGroup], protocol, tokens = _parse_access_list_entry_protocol(parts, protocol_groups, svc_objects, svc_obj_groups) # Parse source endpoint - src, consumed = _parse_endpoint(tokens) + src, consumed = parse_endpoint(tokens) tokens = tokens[consumed:] # Parse destination endpoint - dst, consumed = _parse_endpoint(tokens) + dst, consumed = parse_endpoint(tokens) tokens = tokens[consumed:] # Parse destination port @@ -610,11 +607,11 @@ def _parse_access_list_entry(line: str, protocol_groups: List[AsaProtocolGroup], inactive=inactive ) -def _parse_protocol_object_group_block(block: List[str]) -> AsaProtocolGroup: +def parse_protocol_object_group_block(block: list[str]) -> AsaProtocolGroup: """Parse an object-group protocol block.""" name = block[0].split()[2] desc = _find_description(block[1:]) - protocols: List[str] = [] + protocols: list[str] = [] for b in block[1:]: s = b.strip() @@ -630,11 +627,11 @@ def _parse_protocol_object_group_block(block: List[str]) -> AsaProtocolGroup: ) -def _parse_icmp_object_group_block(block: List[str]) -> AsaServiceObjectGroup: +def parse_icmp_object_group_block(block: list[str]) -> AsaServiceObjectGroup: """Parse an object-group icmp-type block.""" grp_name = block[0].split()[2] desc = _find_description(block[1:]) - objects: List[str] = [] + objects: list[str] = [] for b in block[1:]: s = b.strip() mobj = re.match(r"^icmp-object\s+(\S+)$", s, re.I) diff --git a/roles/importer/files/importer/ciscoasa9/asa_rule.py b/roles/importer/files/importer/ciscoasa9/asa_rule.py index e8b922acca..debce4f7bb 100644 --- a/roles/importer/files/importer/ciscoasa9/asa_rule.py +++ b/roles/importer/files/importer/ciscoasa9/asa_rule.py @@ -5,18 +5,19 @@ service, source, and destination references. """ -from typing import List, Dict from netaddr import IPNetwork from models.rule import RuleNormalized, RuleAction, RuleTrack, RuleType from models.rulebase import Rulebase -from ciscoasa9.asa_models import AccessList, AccessListEntry, AsaProtocolGroup +from ciscoasa9.asa_models import AccessList, AccessListEntry, AsaProtocolGroup, EndpointKind from ciscoasa9.asa_service import create_service_for_acl_entry, create_any_protocol_service from ciscoasa9.asa_network import get_network_rule_endpoint from fwo_log import getFwoLogger import fwo_base +from models.networkobject import NetworkObject +from models.serviceobject import ServiceObject -def create_service_for_protocol_group_entry(protocol_group_name: str, protocol_groups: List[AsaProtocolGroup], service_objects: Dict) -> str: +def create_service_for_protocol_group_entry(protocol_group_name: str, protocol_groups: list[AsaProtocolGroup], service_objects: dict[str, ServiceObject]) -> str: """Resolve service reference for a protocol group. Args: @@ -34,7 +35,7 @@ def create_service_for_protocol_group_entry(protocol_group_name: str, protocol_g break if allowed_protocols: - svc_refs = [] + svc_refs: list[str] = [] for proto in allowed_protocols: svc_ref = create_any_protocol_service(proto, service_objects) svc_refs.append(svc_ref) @@ -48,7 +49,7 @@ def create_service_for_protocol_group_entry(protocol_group_name: str, protocol_g return fwo_base.sort_and_join(svc_refs) -def resolve_service_reference_for_rule(entry: AccessListEntry, protocol_groups: List[AsaProtocolGroup], service_objects: Dict) -> str: +def resolve_service_reference_for_rule(entry: AccessListEntry, protocol_groups: list[AsaProtocolGroup], service_objects: dict[str, ServiceObject]) -> str: """Resolve service reference for a rule entry. Args: @@ -67,7 +68,7 @@ def resolve_service_reference_for_rule(entry: AccessListEntry, protocol_groups: return create_service_for_acl_entry(entry, service_objects) -def resolve_network_reference_for_rule(endpoint, network_objects: Dict) -> str: +def resolve_network_reference_for_rule(endpoint: EndpointKind, network_objects: dict[str, NetworkObject]) -> str: """Resolve network reference for a rule endpoint. Args: @@ -88,8 +89,8 @@ def resolve_network_reference_for_rule(endpoint, network_objects: Dict) -> str: def create_rule_from_acl_entry(access_list_name: str, entry: AccessListEntry, - protocol_groups: List[AsaProtocolGroup], - network_objects: Dict, service_objects: Dict, + protocol_groups: list[AsaProtocolGroup], + network_objects: dict[str, NetworkObject], service_objects: dict[str, ServiceObject], gateway_uid: str) -> RuleNormalized: """Create a normalized rule from an ACL entry. @@ -150,10 +151,10 @@ def create_rule_from_acl_entry(access_list_name: str, entry: AccessListEntry, return rule -def build_rulebases_from_access_lists(access_lists: List[AccessList], mgm_uid: str, - protocol_groups: List[AsaProtocolGroup], - network_objects: Dict, service_objects: Dict, - gateway_uid: str) -> List[Rulebase]: +def build_rulebases_from_access_lists(access_lists: list[AccessList], mgm_uid: str, + protocol_groups: list[AsaProtocolGroup], + network_objects: dict[str, NetworkObject], service_objects: dict[str, ServiceObject], + gateway_uid: str) -> list[Rulebase]: """Build rulebases from ASA access lists. Each access list becomes a separate rulebase containing normalized rules. @@ -170,10 +171,10 @@ def build_rulebases_from_access_lists(access_lists: List[AccessList], mgm_uid: s Returns: List of normalized rulebases """ - rulebases = [] + rulebases: list[Rulebase] = [] for access_list in access_lists: - rules = {} + rules: dict[str, RuleNormalized] = {} for entry in access_list.entries: rule = create_rule_from_acl_entry( @@ -184,6 +185,10 @@ def build_rulebases_from_access_lists(access_lists: List[AccessList], mgm_uid: s service_objects, gateway_uid ) + if rule.rule_uid is None: + logger = getFwoLogger() + logger.error(f"Failed to create rule UID for ACL entry: {entry}") + raise ValueError("Rule UID generation failed.") rules[rule.rule_uid] = rule # Create rulebase for this access list diff --git a/roles/importer/files/importer/ciscoasa9/asa_service.py b/roles/importer/files/importer/ciscoasa9/asa_service.py index 137d804e8d..074b614a19 100644 --- a/roles/importer/files/importer/ciscoasa9/asa_service.py +++ b/roles/importer/files/importer/ciscoasa9/asa_service.py @@ -5,7 +5,6 @@ inline ACL definitions. """ -from typing import Dict, List, Optional, Tuple from models.serviceobject import ServiceObject from ciscoasa9.asa_models import AsaServiceObject, AsaServiceObjectGroup, AccessListEntry from ciscoasa9.asa_maps import name_to_port, protocol_map @@ -14,7 +13,7 @@ from fwo_log import getFwoLogger -def create_service_object(name: str, port: int, port_end: int, protocol: str, comment: Optional[str] = None) -> ServiceObject: +def create_service_object(name: str, port: int, port_end: int, protocol: str, comment: str | None = None) -> ServiceObject: """Create a normalized service object. Args: @@ -39,7 +38,7 @@ def create_service_object(name: str, port: int, port_end: int, protocol: str, co ) -def create_protocol_service_object(name: str, protocol: str, comment: Optional[str] = None) -> ServiceObject: +def create_protocol_service_object(name: str, protocol: str, comment: str | None = None) -> ServiceObject: """Create a service object for a protocol without specific ports. Args: @@ -60,7 +59,7 @@ def create_protocol_service_object(name: str, protocol: str, comment: Optional[s ) -def create_service_group_object(name: str, member_refs: List[str], comment: Optional[str] = None) -> ServiceObject: +def create_service_group_object(name: str, member_refs: list[str], comment: str | None = None) -> ServiceObject: """Create a service group object. Args: @@ -82,7 +81,7 @@ def create_service_group_object(name: str, member_refs: List[str], comment: Opti ) -def normalize_service_objects(service_objects: List[AsaServiceObject]) -> Dict[str, ServiceObject]: +def normalize_service_objects(service_objects: list[AsaServiceObject]) -> dict[str, ServiceObject]: """Normalize individual service objects from ASA configuration. Args: @@ -91,7 +90,7 @@ def normalize_service_objects(service_objects: List[AsaServiceObject]) -> Dict[s Returns: Dictionary of normalized service objects keyed by svc_uid """ - normalized = {} + normalized: dict[str, ServiceObject] = {} for svc in service_objects: if svc.dst_port_eq: @@ -121,13 +120,13 @@ def normalize_service_objects(service_objects: List[AsaServiceObject]) -> Dict[s return normalized -def create_protocol_any_service_objects() -> Dict[str, ServiceObject]: +def create_protocol_any_service_objects() -> dict[str, ServiceObject]: """Create default 'any' service objects for common protocols. Returns: Dictionary of protocol-any service objects """ - service_objects = {} + service_objects: dict[str, ServiceObject] = {} for proto in ("tcp", "udp", "icmp", "ip"): obj_name = f"any-{proto}" @@ -146,7 +145,7 @@ def create_protocol_any_service_objects() -> Dict[str, ServiceObject]: return service_objects -def create_service_for_port(port: str, proto: str, service_objects: Dict[str, ServiceObject]) -> str: +def create_service_for_port(port: str, proto: str, service_objects: dict[str, ServiceObject]) -> str: """Create a service object for a single port and protocol if it doesn't exist. Args: @@ -172,7 +171,7 @@ def create_service_for_port(port: str, proto: str, service_objects: Dict[str, Se return obj_name -def create_service_for_port_range(port_range: Tuple[str, str], proto: str, service_objects: Dict[str, ServiceObject]) -> str: +def create_service_for_port_range(port_range: tuple[str, str], proto: str, service_objects: dict[str, ServiceObject]) -> str: """Create a service object for a port range and protocol if it doesn't exist. Args: @@ -201,7 +200,7 @@ def create_service_for_port_range(port_range: Tuple[str, str], proto: str, servi return obj_name -def create_any_protocol_service(proto: str, service_objects: Dict[str, ServiceObject]) -> str: +def create_any_protocol_service(proto: str, service_objects: dict[str, ServiceObject]) -> str: """Create an 'any' service object for a protocol if it doesn't exist. Args: @@ -229,7 +228,7 @@ def create_any_protocol_service(proto: str, service_objects: Dict[str, ServiceOb -def create_service_for_protocol_entry_with_single_protocol(entry: AccessListEntry, service_objects: Dict[str, ServiceObject]) -> str: +def create_service_for_protocol_entry_with_single_protocol(entry: AccessListEntry, service_objects: dict[str, ServiceObject]) -> str: """Create service reference for a protocol entry with set protocol. Args: entry: Access list entry with protocol @@ -258,7 +257,7 @@ def create_service_for_protocol_entry_with_single_protocol(entry: AccessListEntr return create_any_protocol_service(entry.protocol.value, service_objects) -def create_service_for_protocol_entry(entry: AccessListEntry, service_objects: Dict[str, ServiceObject]) -> str: +def create_service_for_protocol_entry(entry: AccessListEntry, service_objects: dict[str, ServiceObject]) -> str: """Create service reference for a protocol group entry. Args: entry: Access list entry with protocol group @@ -271,7 +270,7 @@ def create_service_for_protocol_entry(entry: AccessListEntry, service_objects: D return create_service_for_protocol_entry_with_single_protocol(entry, service_objects) elif entry.protocol.value == "ip": - svc_refs = [] + svc_refs: list[str] = [] for proto in protocol_map.keys(): svc_refs.append(create_any_protocol_service(proto, service_objects)) @@ -292,7 +291,7 @@ def create_service_for_protocol_entry(entry: AccessListEntry, service_objects: D return create_any_protocol_service(entry.protocol.value, service_objects) -def create_service_for_acl_entry(entry: AccessListEntry, service_objects: Dict[str, ServiceObject]) -> str: +def create_service_for_acl_entry(entry: AccessListEntry, service_objects: dict[str, ServiceObject]) -> str: """Create service object(s) for an ACL entry and return the service reference. Args: @@ -315,7 +314,7 @@ def create_service_for_acl_entry(entry: AccessListEntry, service_objects: Dict[s else: # Default to all common protocols - svc_refs = [] + svc_refs: list[str] = [] for proto in ("tcp", "udp", "icmp"): svc_refs.append(create_any_protocol_service(proto, service_objects)) return fwo_base.sort_and_join(svc_refs) @@ -326,9 +325,9 @@ def create_service_for_acl_entry(entry: AccessListEntry, service_objects: Dict[s -def process_mixed_protocol_eq_ports(group: AsaServiceObjectGroup, service_objects: Dict[str, ServiceObject]) -> List[str]: +def process_mixed_protocol_eq_ports(group: AsaServiceObjectGroup, service_objects: dict[str, ServiceObject]) -> list[str]: """Process equal ports for mixed protocol groups.""" - obj_names = [] + obj_names: list[str] = [] for protos, eq_ports in group.ports_eq.items(): for proto in protos.split("-"): # handles "tcp-udp" for port in eq_ports: @@ -336,26 +335,26 @@ def process_mixed_protocol_eq_ports(group: AsaServiceObjectGroup, service_object obj_names.append(obj_name) return obj_names -def process_mixed_protocol_range_ports(group: AsaServiceObjectGroup, service_objects: Dict[str, ServiceObject]) -> List[str]: +def process_mixed_protocol_range_ports(group: AsaServiceObjectGroup, service_objects: dict[str, ServiceObject]) -> list[str]: """Process port ranges for mixed protocol groups.""" - obj_names = [] + obj_names: list[str] = [] for proto, ranges in group.ports_range.items(): for pr in ranges: obj_name = create_service_for_port_range(pr, proto, service_objects) obj_names.append(obj_name) return obj_names -def process_fully_enabled_protocols(group: AsaServiceObjectGroup, service_objects: Dict[str, ServiceObject]) -> List[str]: +def process_fully_enabled_protocols(group: AsaServiceObjectGroup, service_objects: dict[str, ServiceObject]) -> list[str]: """Process protocols that allow all ports.""" - obj_names = [] + obj_names: list[str] = [] for proto in group.protocols: obj_name = create_any_protocol_service(proto, service_objects) obj_names.append(obj_name) return obj_names -def process_mixed_protocol_group(group: AsaServiceObjectGroup, service_objects: Dict[str, ServiceObject]) -> List[str]: +def process_mixed_protocol_group(group: AsaServiceObjectGroup, service_objects: dict[str, ServiceObject]) -> list[str]: """Process a mixed protocol service group.""" - obj_names = [] + obj_names: list[str] = [] # Process ports_eq (single port values) obj_names.extend(process_mixed_protocol_eq_ports(group, service_objects)) @@ -373,7 +372,7 @@ def process_mixed_protocol_group(group: AsaServiceObjectGroup, service_objects: def process_single_protocol_eq_ports(protocol: str, ports: list[str], service_objects: dict[str, ServiceObject]) -> list[str]: """Process equal ports for single protocol groups.""" - obj_names = [] + obj_names: list[str] = [] for port in ports: obj_name = create_service_for_port(port, protocol, service_objects) obj_names.append(obj_name) @@ -381,15 +380,15 @@ def process_single_protocol_eq_ports(protocol: str, ports: list[str], service_ob def process_single_protocol_range_ports(protocol: str, ranges: list[tuple[str, str]], service_objects: dict[str, ServiceObject]) -> list[str]: """Process port ranges for single protocol groups.""" - obj_names = [] + obj_names: list[str] = [] for range in ranges: obj_name = create_service_for_port_range(range, protocol, service_objects) obj_names.append(obj_name) return obj_names -def process_single_protocol_group(group: AsaServiceObjectGroup, service_objects: Dict[str, ServiceObject]) -> List[str]: +def process_single_protocol_group(group: AsaServiceObjectGroup, service_objects: dict[str, ServiceObject]) -> list[str]: """Process a single-protocol service group.""" - obj_names = [] + obj_names: list[str] = [] if not group.proto_mode: raise ValueError(f"Service object group {group.name} missing proto_mode") @@ -411,7 +410,7 @@ def process_single_protocol_group(group: AsaServiceObjectGroup, service_objects: -def normalize_service_object_groups(service_groups: List[AsaServiceObjectGroup], service_objects: Dict[str, ServiceObject]) -> Dict[str, ServiceObject]: +def normalize_service_object_groups(service_groups: list[AsaServiceObjectGroup], service_objects: dict[str, ServiceObject]) -> dict[str, ServiceObject]: """Normalize service object groups from ASA configuration. Args: diff --git a/roles/importer/files/importer/ciscoasa9/fwcommon.py b/roles/importer/files/importer/ciscoasa9/fwcommon.py index 3f94009786..44d16efc5e 100644 --- a/roles/importer/files/importer/ciscoasa9/fwcommon.py +++ b/roles/importer/files/importer/ciscoasa9/fwcommon.py @@ -7,7 +7,7 @@ """ from pathlib import Path -from typing import Optional +from typing import Any, Optional from scrapli.driver import GenericDriver import time @@ -22,7 +22,7 @@ from fwo_exceptions import FwoImporterError -def has_config_changed(full_config, mgm_details, force=False): +def has_config_changed(full_config: FwConfigManagerListController, mgm_details: ManagementController, force: bool=False): # We don't get this info from ASA, so we always return True return True @@ -36,7 +36,7 @@ def _connect_to_device(mgm_details: ManagementController) -> GenericDriver: Returns: Connected GenericDriver instance. """ - device = { + device: dict[str, Any] = { "host": mgm_details.Hostname, "port": mgm_details.Port, "auth_username": mgm_details.ImportUser, @@ -285,7 +285,7 @@ def load_config_from_management(mgm_details: ManagementController, is_virtual_as try: return _attempt_connection(mgm_details, is_virtual_asa, attempt, max_retries) - except FwoImporterError as e: + except FwoImporterError as _: if attempt >= max_retries - 1: raise raise FwoImporterError(f"Failed to connect to device {mgm_details.Hostname} after {max_retries} attempts") @@ -306,11 +306,11 @@ def get_config(config_in: FwConfigManagerListController, import_state: ImportSta logger.debug ( "starting checkpointAsa9/get_config" ) - is_virtual_asa = import_state.MgmDetails.DeviceTypeName == "Cisco Asa on FirePower" + _ = import_state.MgmDetails.DeviceTypeName == "Cisco Asa on FirePower" - if config_in.native_config_is_empty: - raw_config = load_config_from_management(import_state.MgmDetails, is_virtual_asa) - # raw_config = load_config_from_file("test_asa.conf") + if config_in.native_config_is_empty: # type: ignore + # raw_config = load_config_from_management(import_state.MgmDetails, is_virtual_asa) + raw_config = load_config_from_file("test_asa.conf") config2import = parse_asa_config(raw_config) config_in.native_config = config2import.model_dump() diff --git a/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_getter.py b/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_getter.py index 8de8553887..e46f042d92 100644 --- a/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_getter.py +++ b/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_getter.py @@ -1,7 +1,7 @@ # library for API get functions import base64 +from typing import Any from fwo_log import getFwoLogger -import requests.packages import requests import json import fwo_globals @@ -9,7 +9,7 @@ auth_token = "" -def api_call(url, params = {}, headers = {}, json_payload = {}, auth_token = '', show_progress=False, method='get'): +def api_call(url: str, params: dict[str, Any] = {}, headers: dict[str, Any] = {}, json_payload: dict[str, Any] = {}, auth_token: str = '', show_progress: bool = False, method: str = 'get') -> tuple[dict[str, Any], dict[str, Any]]: logger = getFwoLogger() request_headers = {'Content-Type': 'application/json'} for header_key in headers: @@ -25,19 +25,11 @@ def api_call(url, params = {}, headers = {}, json_payload = {}, auth_token = '', verify=fwo_globals.verify_certs) else: raise Exception("unknown HTTP method found in cifp_getter") - - if response is None: - if 'pass' in json.dumps(json_payload): - exception_text = "error while sending api_call containing credential information to url '" + \ - str(url) - else: - exception_text = "error while sending api_call to url '" + str(url) + "' with payload '" + json.dumps( - json_payload, indent=2) + "' and headers: '" + json.dumps(request_headers, indent=2) - raise Exception(exception_text) - if (len(response.content) > 0): - body_json = response.json() + + if (len(response.content) > 0): + body_json: dict[str, Any] = response.json() else: - body_json = {} + body_json: dict[str, Any] = {} if fwo_globals.debug_level > 2: if 'pass' in json.dumps(json_payload): @@ -47,32 +39,33 @@ def api_call(url, params = {}, headers = {}, json_payload = {}, auth_token = '', logger.debug("api_call to url '" + str(url) + "' with payload '" + json.dumps( json_payload, indent=2) + "' and headers: '" + json.dumps(request_headers, indent=2)) - return response.headers, body_json + return dict(response.headers), body_json -def login(user, password, api_host, api_port): +def login(user: str, password: str, api_host: str, api_port: int) -> tuple[str, str]: base_url = 'https://' + api_host + ':' + str(api_port) + '/api/' try: headers, _ = api_call(base_url + "fmc_platform/v1/auth/generatetoken", method="post", headers={"Authorization" : "Basic " + str(base64.b64encode((user + ":" + password).encode('utf-8')), 'utf-8')}) except Exception as e: raise FwLoginFailed( "Cisco Firepower login ERROR: host=" + str(api_host) + ":" + str(api_port) + " Message: " + str(e)) from None - if headers.get("X-auth-access-token") == None: # leaving out payload as it contains pwd + access_token = headers.get("X-auth-access-token") + if access_token is None: # leaving out payload as it contains pwd raise FwLoginFailed( "Cisco Firepower login ERROR: host=" + str(api_host) + ":" + str(api_port)) from None if fwo_globals.debug_level > 2: logger = getFwoLogger() logger.debug("Login successful. Received auth token: " + headers["X-auth-access-token"]) - return headers.get("X-auth-access-token"), headers.get("DOMAINS") + return access_token, headers.get("DOMAINS") or "" # TODO Is there a logout? -def logout(v_url, sid, method='exec'): +def logout(v_url: str, sid: str, method: str = 'exec') -> None: return -def update_config_with_cisco_api_call(session_id, api_base_url, api_path, parameters={}, payload={}, show_progress=False, limit: int=1000, method="get"): +def update_config_with_cisco_api_call(session_id: str, api_base_url: str, api_path: str, parameters: dict[str, Any] = {}, payload: dict[str, Any] = {}, show_progress: bool = False, limit: int = 1000, method: str = "get") -> list[dict[str, Any]]: offset = 0 limit = 1000 returned_new_data = True - full_result = [] + full_result: list[dict[str, Any]] = [] while returned_new_data: parameters["offset"] = offset parameters["limit"] = limit diff --git a/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_network.py b/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_network.py index a78b47d87f..3d8a0244b4 100644 --- a/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_network.py +++ b/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_network.py @@ -1,13 +1,12 @@ -from asyncio.log import logger import random +from typing import Any from fwo_log import getFwoLogger from fwo_const import list_delimiter from netaddr import IPAddress -def normalize_nwobjects(full_config, config2import, import_id, jwt=None, mgm_id=None): - logger = getFwoLogger() - nw_objects = [] +def normalize_nwobjects(full_config: dict[str, Any], config2import: dict[str, Any], import_id: str, jwt: str | None = None, mgm_id: str | None = None): + nw_objects: list[dict[str, Any]] = [] for obj_orig in full_config["networkObjects"]: nw_objects.append(parse_object(obj_orig, import_id)) for obj_grp_orig in full_config["networkObjectGroups"]: @@ -17,12 +16,12 @@ def normalize_nwobjects(full_config, config2import, import_id, jwt=None, mgm_id= nw_objects.append(obj_grp) config2import['network_objects'] = nw_objects -def parse_obj_group(orig_grp, import_id, nw_objects, id = None): - refs = [] - names = [] +def parse_obj_group(orig_grp: dict[str, Any], import_id: str, nw_objects: list[dict[str, Any]], id: str | None = None): + refs: list[str] = [] + names: list[str] = [] if "literals" in orig_grp: - if id == None: - id = orig_grp["id"] if "id" in orig_grp else random.random() + if id is None: + id = orig_grp["id"] if "id" in orig_grp else str(random.random()) for orig_literal in orig_grp["literals"]: literal = parse_object(orig_literal, import_id) literal["obj_uid"] += "_" + str(id) @@ -36,16 +35,16 @@ def parse_obj_group(orig_grp, import_id, nw_objects, id = None): orig_obj["type"] != "Network" and orig_obj["type"] != "Range" and orig_obj["type"] != "FQDN"): logger = getFwoLogger() - logger.warn("Unknown network object type found: \"" + orig_obj["type"] + "\". Skipping.") + logger.warning("Unknown network object type found: \"" + orig_obj["type"] + "\". Skipping.") break names.append(orig_obj["name"]) refs.append(orig_obj["id"]) return list_delimiter.join(refs), list_delimiter.join(names) -def extract_base_object_infos(obj_orig, import_id): +def extract_base_object_infos(obj_orig: dict[str, Any], import_id: str) -> dict[str, Any]: logger = getFwoLogger() - obj = {} + obj: dict[str, Any] = {} if "id" in obj_orig: obj["obj_uid"] = obj_orig['id'] else: @@ -62,7 +61,9 @@ def extract_base_object_infos(obj_orig, import_id): obj['control_id'] = import_id return obj -def parse_object(obj_orig, import_id): + +def parse_object(obj_orig: dict[str, Any], import_id: str) -> dict[str, Any]: + logger = getFwoLogger() obj = extract_base_object_infos(obj_orig, import_id) if obj_orig["type"] == "Network": # network obj["obj_typ"] = "network" @@ -73,7 +74,7 @@ def parse_object(obj_orig, import_id): else: # not real cidr (netmask after /) obj['obj_ip'] = cidr[0] + "/" + str(IPAddress(cidr[1]).netmask_bits()) else: - logger.warn("missing value field in object - skipping: " + str(obj_orig)) + logger.warning("missing value field in object - skipping: " + str(obj_orig)) obj['obj_ip'] = "0.0.0.0" elif obj_orig["type"] == "Host": # host obj["obj_typ"] = "host" @@ -86,7 +87,7 @@ def parse_object(obj_orig, import_id): if obj_orig["value"].find("/") == -1: obj["obj_ip"] += "/32" else: - logger.warn("missing value field in object - skipping: " + str(obj_orig)) + logger.warning("missing value field in object - skipping: " + str(obj_orig)) obj['obj_ip'] = "0.0.0.0/0" elif obj_orig["type"] == "Range": # ip range obj['obj_typ'] = 'ip_range' diff --git a/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_rule.py b/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_rule.py index f2e64a4909..a25b18adb0 100644 --- a/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_rule.py +++ b/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_rule.py @@ -1,8 +1,8 @@ +from typing import Any from cifp_service import parse_svc_group from cifp_network import parse_obj_group import cifp_getter -from fwo_log import getFwoLogger rule_access_scope_v4 = ['rules_global_header_v4', 'rules_adom_v4', 'rules_global_footer_v4'] @@ -12,30 +12,29 @@ rule_nat_scope = ['rules_global_nat', 'rules_adom_nat'] rule_scope = rule_access_scope + rule_nat_scope -def getAccessPolicy(sessionId, api_url, config, device, limit): +def getAccessPolicy(sessionId: str, api_url: str, config: dict[str, Any], device: dict[str, Any], limit: int) -> None: access_policy = device["accessPolicy"]["id"] domain = device["domain"] - logger = getFwoLogger() device["rules"] = cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, "fmc_config/v1/domain/" + domain + "/policy/accesspolicies/" + access_policy + "/accessrules", parameters={"expanded": True}, limit=limit) return -def normalize_access_rules(full_config, config2import, import_id, mgm_details={}, jwt=None): - any_nw_svc = {"svc_uid": "any_svc_placeholder", "svc_name": "Any", "svc_comment": "Placeholder service.", +def normalize_access_rules(full_config: dict[str, Any], config2import: dict[str, Any], import_id: str, mgm_details: dict[str, Any] = {}, jwt: str | None = None) -> None: + any_nw_svc: dict[str, Any] = {"svc_uid": "any_svc_placeholder", "svc_name": "Any", "svc_comment": "Placeholder service.", "svc_typ": "simple", "ip_proto": -1, "svc_port": 0, "svc_port_end": 65535, "control_id": import_id} - any_nw_object = {"obj_uid": "any_obj_placeholder", "obj_name": "Any", "obj_comment": "Placeholder object.", + any_nw_object: dict[str, Any] = {"obj_uid": "any_obj_placeholder", "obj_name": "Any", "obj_comment": "Placeholder object.", "obj_typ": "network", "obj_ip": "0.0.0.0/0", "control_id": import_id} config2import["service_objects"].append(any_nw_svc) config2import["network_objects"].append(any_nw_object) - rules = [] + rules: list[dict[str, Any]] = [] for device in full_config["devices"]: access_policy = device["accessPolicy"] rule_number = 0 for rule_orig in device["rules"]: - rule = {'rule_src': 'any', 'rule_dst': 'any', 'rule_svc': 'any', + rule: dict[str, Any] = {'rule_src': 'any', 'rule_dst': 'any', 'rule_svc': 'any', 'rule_src_refs': 'any_obj_placeholder', 'rule_dst_refs': 'any_obj_placeholder', 'rule_svc_refs': 'any_svc_placeholder'} rule['control_id'] = import_id diff --git a/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_service.py b/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_service.py index 6245ae2241..8b5e764a96 100644 --- a/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_service.py +++ b/roles/importer/files/importer/ciscofirepowerdomain7ff/cifp_service.py @@ -1,9 +1,10 @@ import random +from typing import Any from fwo_const import list_delimiter -def normalize_svcobjects(full_config, config2import, import_id): - svc_objects = [] +def normalize_svcobjects(full_config: dict[str, Any], config2import: dict[str, Any], import_id: str) -> None: + svc_objects: list[dict[str, Any]] = [] for svc_orig in full_config["serviceObjects"]: svc_objects.append(parse_svc(svc_orig, import_id)) for svc_grp_orig in full_config["serviceObjectGroups"]: @@ -13,8 +14,8 @@ def normalize_svcobjects(full_config, config2import, import_id): svc_objects.append(svc_grp) config2import['service_objects'] = svc_objects -def extract_base_svc_infos(svc_orig, import_id): - svc = {} +def extract_base_svc_infos(svc_orig: dict[str, Any], import_id: str) -> dict[str, Any]: + svc: dict[str, Any] = {} if "id" in svc_orig: svc["svc_uid"] = svc_orig["id"] else: @@ -34,7 +35,7 @@ def extract_base_svc_infos(svc_orig, import_id): svc["control_id"] = import_id return svc -def parse_svc(orig_svc, import_id): +def parse_svc(orig_svc: dict[str, Any], import_id: str) -> dict[str, Any]: svc = extract_base_svc_infos(orig_svc, import_id) svc["svc_typ"] = "simple" parse_port(orig_svc, svc) @@ -55,7 +56,7 @@ def parse_svc(orig_svc, import_id): svc["svc_name"] += " [Not supported]" return svc -def parse_port(orig_svc, svc): +def parse_port(orig_svc: dict[str, Any], svc: dict[str, Any]) -> None: if "port" in orig_svc: if orig_svc["port"].find("-") != -1: # port range port_range = orig_svc["port"].split("-") @@ -65,16 +66,16 @@ def parse_port(orig_svc, svc): svc["svc_port"] = orig_svc["port"] svc["svc_port_end"] = None -def parse_svc_group(orig_svc_grp, import_id, svc_objects, id = None): - refs = [] - names = [] +def parse_svc_group(orig_svc_grp: dict[str, Any], import_id: str, svc_objects: list[dict[str, Any]], id: str | None = None) -> tuple[str, str]: + refs: list[str] = [] + names: list[str] = [] if "literals" in orig_svc_grp: - if id == None: - id = orig_svc_grp["id"] if "id" in orig_svc_grp else random.random() + if id is None: + id = orig_svc_grp["id"] if "id" in orig_svc_grp else str(random.random()) for orig_literal in orig_svc_grp["literals"]: literal = parse_svc(orig_literal, import_id) - literal["svc_uid"] += "_" + id + literal["svc_uid"] += "_" + str(id) svc_objects.append(literal) names.append(literal["svc_name"]) refs.append(literal["svc_uid"]) diff --git a/roles/importer/files/importer/ciscofirepowerdomain7ff/fwcommon.py b/roles/importer/files/importer/ciscofirepowerdomain7ff/fwcommon.py index b5030cfca0..ac48a5969e 100644 --- a/roles/importer/files/importer/ciscofirepowerdomain7ff/fwcommon.py +++ b/roles/importer/files/importer/ciscofirepowerdomain7ff/fwcommon.py @@ -1,4 +1,5 @@ import json +from typing import Any # import sys # from common import importer_base_dir @@ -10,12 +11,12 @@ from fwo_log import getFwoLogger -def has_config_changed(full_config, mgm_details, force=False): +def has_config_changed(full_config: dict[str, Any], mgm_details: dict[str, Any], force: bool = False) -> bool: # dummy - may be filled with real check later on return True -def get_config(config2import, full_config, current_import_id, mgm_details, limit=1000, force=False, jwt=''): +def get_config(config2import: dict[str, Any], full_config: dict[str, Any], current_import_id: str, mgm_details: dict[str, Any], limit: int = 1000, force: bool = False, jwt: str = '') -> int: logger = getFwoLogger() if full_config == {}: # no native config was passed in, so getting it from Cisco Management parsing_config_only = False @@ -29,11 +30,11 @@ def get_config(config2import, full_config, current_import_id, mgm_details, limit sessionId, domains = cifp_getter.login(mgm_details["import_credential"]['user'], mgm_details["import_credential"]['secret'], mgm_details['hostname'], mgm_details['port']) domain = mgm_details["configPath"] - if sessionId == None or sessionId == "": + if sessionId == "": logger.error( 'Did not succeed in logging in to Cisco Firepower API, no sid returned.') return 1 - if domain == None or domain == "": + if domain is None or domain == "": logger.error( 'Configured domain is null or empty.') return 1 @@ -78,7 +79,7 @@ def get_config(config2import, full_config, current_import_id, mgm_details, limit # cifp_network.remove_nat_ip_entries(config2import) return 0 -def getAllAccessRules(sessionId, api_url, domains): +def getAllAccessRules(sessionId: str, api_url: str, domains: list[dict[str, Any]]) -> list[dict[str, Any]]: for domain in domains: domain["access_policies"] = cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, "fmc_config/v1/domain/" + domain["uuid"] + "/policy/accesspolicies" , parameters={"expanded": True}, limit=1000) @@ -88,14 +89,14 @@ def getAllAccessRules(sessionId, api_url, domains): "fmc_config/v1/domain/" + domain["uuid"] + "/policy/accesspolicies/" + access_policy["id"] + "/accessrules", parameters={"expanded": True}, limit=1000) return domains -def getScopes(searchDomain, domains): - scopes = [] +def getScopes(searchDomain: str, domains: list[dict[str, Any]]) -> list[str]: + scopes: list[str] = [] for domain in domains: - if domain == domain["uuid"] or domain["name"].endswith(searchDomain): + if domain == domain["uuid"] or domain["name"].endswith(searchDomain): # TODO: is the check supposed to be searchDomain == domain["uuid"] ? scopes.append(domain["uuid"]) return scopes -def getDevices(sessionId, api_url, config, limit, scopes, devices): +def getDevices(sessionId: str, api_url: str, config: dict[str, Any], limit: int, scopes: list[str], devices: list[dict[str, Any]]) -> None: logger = getFwoLogger() # get all devices for scope in scopes: @@ -116,33 +117,45 @@ def getDevices(sessionId, api_url, config, limit, scopes, devices): config["devices"].remove(cisco_api_device) logger.info("Device \"" + cisco_api_device["name"] + "\" was found but it is not registered in FWO. Ignoring it.") -def getObjects(sessionId, api_url, config, limit, scopes): +def getObjects(sessionId: str, api_url: str, config: dict[str, Any], limit: int, scopes: list[str]) -> None: # network objects: - config["networkObjects"] = [] - config["networkObjectGroups"] = [] + networkObjects: list[dict[str, Any]] = [] + networkObjectGroups: list[dict[str, Any]] = [] # service objects: - config["serviceObjects"] = [] - config["serviceObjectGroups"] = [] + serviceObjects: list[dict[str, Any]] = [] + serviceObjectGroups: list[dict[str, Any]] = [] # user objects: - config["userObjects"] = [] - config["userObjectGroups"] = [] + userObjects: list[dict[str, Any]] = [] + userObjectGroups: list[dict[str, Any]] = [] + # get those objects that exist globally and on domain level for scope in scopes: # get network objects (groups): # for object_type in nw_obj_types: - config["networkObjects"].extend(cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, + networkObjects.extend(cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, "fmc_config/v1/domain/" + scope + "/object/networkaddresses", parameters={"expanded": True}, limit=limit)) - config["networkObjectGroups"].extend(cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, + networkObjectGroups.extend(cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, "fmc_config/v1/domain/" + scope + "/object/networkgroups", parameters={"expanded": True}, limit=limit)) # get service objects: # for object_type in svc_obj_types: - config["serviceObjects"].extend(cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, + serviceObjects.extend(cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, "fmc_config/v1/domain/" + scope + "/object/ports", parameters={"expanded": True}, limit=limit)) - config["serviceObjectGroups"].extend(cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, + serviceObjectGroups.extend(cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, "fmc_config/v1/domain/" + scope + "/object/portobjectgroups", parameters={"expanded": True}, limit=limit)) # get user objects: - config["userObjects"].extend(cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, + userObjects.extend(cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, "fmc_config/v1/domain/" + scope + "/object/realmusers", parameters={"expanded": True}, limit=limit)) - config["userObjectGroups"].extend(cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, + userObjectGroups.extend(cifp_getter.update_config_with_cisco_api_call(sessionId, api_url, "fmc_config/v1/domain/" + scope + "/object/realmusergroups", parameters={"expanded": True}, limit=limit)) + + + # network objects: + config["networkObjects"] = networkObjects + config["networkObjectGroups"] = networkObjectGroups + # service objects: + config["serviceObjects"] = serviceObjects + config["serviceObjectGroups"] = serviceObjectGroups + # user objects: + config["userObjects"] = userObjects + config["userObjectGroups"] = userObjectGroups diff --git a/roles/importer/files/importer/common.py b/roles/importer/files/importer/common.py index 481540cdd8..e605f01f5c 100644 --- a/roles/importer/files/importer/common.py +++ b/roles/importer/files/importer/common.py @@ -42,9 +42,9 @@ expects service_provider to be initialized """ -def import_management(mgmId=None, ssl_verification=None, debug_level_in=0, - limit=150, force=False, clearManagementData=False, suppress_cert_warnings_in=None, - in_file=None, version=8) -> int: +def import_management(mgmId: int | None = None, ssl_verification: bool | None = None, debug_level_in: int = 0, + limit: int = 150, force: bool = False, clearManagementData: bool = False, suppress_cert_warnings_in: bool | None = None, + in_file: str | None = None, version: int = 8) -> int: fwo_signalling.registerSignallingHandlers() logger = getFwoLogger(debug_level=debug_level_in) @@ -113,8 +113,9 @@ def import_management(mgmId=None, ssl_verification=None, debug_level_in=0, return 1 -def _import_management(service_provider, importState: ImportStateController, config_importer=None, mgmId=None, ssl_verification=None, debug_level_in=0, - limit=150, clearManagementData=False, suppress_cert_warnings_in=None, in_file=None) -> None: +def _import_management(service_provider: ServiceProvider, importState: ImportStateController, config_importer: FwConfigImport | None = None, + mgmId: int | None = None, ssl_verification: bool | None = None, debug_level_in: int =0, + limit: int =150, clearManagementData: bool = False, suppress_cert_warnings_in: bool | None = None, in_file: str | None = None) -> None: config_normalized : FwConfigManagerListController @@ -166,14 +167,14 @@ def _import_management(service_provider, importState: ImportStateController, con -def handle_unexpected_exception(importState=None, config_importer=None, e=None): +def handle_unexpected_exception(importState: ImportStateController | None = None, config_importer: FwConfigImport | None = None, e: Exception | None = None): if 'importState' in locals() and importState is not None: importState.addError("Unexpected exception in import process - aborting " + traceback.format_exc()) if 'configImporter' in locals() and config_importer is not None: rollBackExceptionHandler(importState, configImporter=config_importer, exc=e) -def rollBackExceptionHandler(importState, configImporter=None, exc=None, errorText=""): +def rollBackExceptionHandler(importState: ImportStateController, configImporter: FwConfigImport | None = None, exc: BaseException | None = None, errorText: str = ""): logger = getFwoLogger() try: if fwo_globals.shutdown_requested: @@ -225,7 +226,7 @@ def import_from_file(importState: ImportStateController, fileName: str = "", gat return config_changed_since_last_import, configFromFile -def get_config_from_api(importState: ImportStateController, config_in) -> tuple[bool, FwConfigManagerListController]: +def get_config_from_api(importState: ImportStateController, config_in: FwConfigManagerListController) -> tuple[bool, FwConfigManagerListController]: logger = getFwoLogger(debug_level=importState.DebugLevel) try: # pick product-specific importer: @@ -238,8 +239,7 @@ def get_config_from_api(importState: ImportStateController, config_in) -> tuple[ raise # check for changes from product-specific FW API, if we are importing from file we assume config changes - config_changed_since_last_import = importState.ImportFileName is not None or \ - fw_module.has_config_changed(config_in, importState, force=importState.ForceImport) + config_changed_since_last_import = fw_module.has_config_changed(config_in, importState, force=importState.ForceImport) if config_changed_since_last_import: logger.info ( "has_config_changed: changes found or forced mode -> go ahead with getting config, Force = " + str(importState.ForceImport)) else: @@ -251,6 +251,9 @@ def get_config_from_api(importState: ImportStateController, config_in) -> tuple[ else: native_config = FwConfigManagerListController.generate_empty_config(importState.MgmDetails.IsSuperManager) + if config_in.native_config is None: + raise FwoImporterError("import_management: get_config returned no config") + write_native_config_to_file(importState, config_in.native_config) logger.debug("import_management: get_config completed (including normalization), duration: " @@ -277,7 +280,7 @@ def get_module_package_name(import_state: ImportStateController): def set_filename(import_state: ImportStateController, file_name: str = ''): # set file name in importState - if file_name == '' or file_name is None: + if file_name == '': # if the host name is an URI, do not connect to an API but simply read the config from this URI if stringIsUri(import_state.MgmDetails.Hostname): import_state.setImportFileName(import_state.MgmDetails.Hostname) diff --git a/roles/importer/files/importer/dummyroutermanagement1/fwcommon.py b/roles/importer/files/importer/dummyroutermanagement1/fwcommon.py index 98ae6cf6fe..f7252ce92d 100644 --- a/roles/importer/files/importer/dummyroutermanagement1/fwcommon.py +++ b/roles/importer/files/importer/dummyroutermanagement1/fwcommon.py @@ -1,18 +1,18 @@ -from common import complete_import -from curses import raw +from typing import Any from fwo_log import getFwoLogger import fwo_globals from model_controllers.interface_controller import Interface -from model_controllers.route_controller import Route -import json, requests, requests.packages +from model_controllers.route_controller import Route, getRouteDestination +import json, requests from datetime import datetime from fwo_exceptions import ConfigFileNotFound -def has_config_changed(_, __, force=False): +def has_config_changed(full_config: dict[str, Any], mgm_details: dict[str, Any], force: bool=False): + # dummy - may be filled with real check later on return True -def get_config(config2import, _, current_import_id, mgm_details, limit=100, force=False, jwt=''): +def get_config(config2import: dict[str, Any], full_config: dict[str, Any], current_import_id: str, mgm_details: dict[str, Any], limit: int=1000, force: bool=False, jwt: str=''): router_file_url = mgm_details['configPath'] error_count = 0 change_count = 0 @@ -33,27 +33,27 @@ def get_config(config2import, _, current_import_id, mgm_details, limit=100, forc r.raise_for_status() cfg = json.loads(r.content) - except requests.exceptions.RequestException: - error_string = "got HTTP status code" + str(r.status_code) + " while trying to read config file from URL " + router_file_url + except requests.exceptions.RequestException as e: + error_string = "got HTTP status code" + str(e.response.status_code if e.response else None) + " while trying to read config file from URL " + router_file_url error_count += 1 - error_count = complete_import(current_import_id, error_string, start_time, mgm_details, change_count, error_count, jwt) + error_count = complete_import(current_import_id, error_string, start_time, mgm_details, change_count, error_count, jwt) # type: ignore # TODO: function does not exist raise ConfigFileNotFound(error_string) from None except Exception: error_string = "Could not read config file " + router_file_url error_count += 1 - error_count = complete_import(current_import_id, error_string, start_time, mgm_details, change_count, error_count, jwt) + error_count = complete_import(current_import_id, error_string, start_time, mgm_details, change_count, error_count, jwt) # type: ignore # TODO: function does not exist raise ConfigFileNotFound(error_string) from None # deserialize network info from json into objects # device_id, name, ip, netmask_bits, state_up=True, ip_version=4 - ifaces = [] + ifaces: list[Interface] = [] for iface in cfg['interfaces']: ifaces.append(Interface(dev_id, iface['name'], iface['ip'], iface['netmask_bits'], state_up=iface['state_up'], ip_version=iface['ip_version'])) cfg['interfaces'] = ifaces # device_id, target_gateway, destination, static=True, source=None, interface=None, metric=None, distance=None, ip_version=4 - routes = [] + routes: list[Route] = [] for route in cfg['routing']: routes.append(Route(dev_id, route['target_gateway'], route['destination'], static=route['static'], interface=route['interface'], metric=route['metric'], distance=route['distance'], ip_version=route['ip_version'])) cfg['routing'] = routes diff --git a/roles/importer/files/importer/fortiadom5ff/fmgr_base.py b/roles/importer/files/importer/fortiadom5ff/fmgr_base.py index 11ffca6a64..80ec874100 100644 --- a/roles/importer/files/importer/fortiadom5ff/fmgr_base.py +++ b/roles/importer/files/importer/fortiadom5ff/fmgr_base.py @@ -1,8 +1,7 @@ +from typing import Any from services.service_provider import ServiceProvider from services.enums import Services from fwo_api_call import FwoApiCall, FwoApi -from fwo_config import readConfig -from fwo_const import fwo_config_filename from fwo_log import getFwoLogger # def resolve_objects (obj_name_string_list, delimiter, obj_dict, name_key, uid_key, rule_type=None, jwt=None, import_id=None, mgm_id=None): @@ -53,9 +52,8 @@ # return object_tables -def set_alerts_for_missing_objects(objects_not_found, jwt, import_id, rule_uid, object_type, mgm_id): +def set_alerts_for_missing_objects(objects_not_found: list[str], jwt: str, import_id: int, rule_uid: str | None, object_type: str | None, mgm_id: int): logger = getFwoLogger() - fwo_config = readConfig(fwo_config_filename) for obj in objects_not_found: if obj == 'all' or obj == 'Original': continue @@ -65,7 +63,7 @@ def set_alerts_for_missing_objects(objects_not_found, jwt, import_id, rule_uid, api_call = FwoApiCall(FwoApi(ApiUri=global_state.import_state.FwoConfig.FwoApiUri, Jwt=global_state.import_state.Jwt)) - if not api_call.create_data_issue(import_id=import_id, obj_name=obj, severity=1, + if not api_call.create_data_issue(importId=import_id, obj_name=obj, severity=1, rule_uid=rule_uid, mgm_id=mgm_id, object_type=object_type): logger.warning("resolve_raw_objects: encountered error while trying to log an import data issue using create_data_issue") @@ -77,7 +75,7 @@ def set_alerts_for_missing_objects(objects_not_found, jwt, import_id, rule_uid, description=desc, source='import', alertCode=16) -def lookup_obj_in_tables(el, object_tables, name_key, uid_key, ref_list): +def lookup_obj_in_tables(el: str, object_tables: list[list[dict[str, Any]]], name_key: str, uid_key: str, ref_list: list[str]) -> bool: logger = getFwoLogger() break_flag = False found = False diff --git a/roles/importer/files/importer/fortiadom5ff/fmgr_getter.py b/roles/importer/files/importer/fortiadom5ff/fmgr_getter.py index 5e4669008c..949ef87952 100644 --- a/roles/importer/files/importer/fortiadom5ff/fmgr_getter.py +++ b/roles/importer/files/importer/fortiadom5ff/fmgr_getter.py @@ -1,14 +1,14 @@ # library for API get functions from fwo_log import getFwoLogger -import requests.packages import requests import json from typing import Any import fwo_globals from fwo_exceptions import FwLoginFailed, FwoUnknownDeviceForManager, FwApiCallFailed, FwLogoutFailed +from models.management import Management -def api_call(url, command, json_payload, sid, show_progress=False, method=''): +def api_call(url: str, command: str, json_payload: dict[str, Any], sid: str, show_progress: bool=False, method: str='') -> dict[str, Any]: logger = getFwoLogger() request_headers = {'Content-Type': 'application/json'} if sid != '': @@ -22,8 +22,8 @@ def api_call(url, command, json_payload, sid, show_progress=False, method=''): r = requests.post(url, data=json.dumps(json_payload), headers=request_headers, verify=fwo_globals.verify_certs) try: - r - except Exception as e: + r # type: ignore #TYPING: This is always defined and does nothing + except Exception as _: if 'pass' in json.dumps(json_payload): exception_text = f'error while sending api_call containing credential information to url {str(url)}' else: @@ -56,8 +56,8 @@ def api_call(url, command, json_payload, sid, show_progress=False, method=''): return result_json -def login(user, password, base_url) -> str: - payload = { +def login(user: str, password: str, base_url: str) -> str | None: + payload: dict[str, Any] = { 'id': 1, 'params': [ { 'data': [ { 'user': user, 'passwd': password, } ] } ] } @@ -70,9 +70,9 @@ def login(user, password, base_url) -> str: return response['session'] -def logout(v_url, sid, method='exec'): +def logout(v_url: str, sid: str, method: str ='exec'): logger = getFwoLogger() - payload = {'params': [{}]} + payload: dict[str, Any] = {'params': [{}]} response = api_call(v_url, 'sys/logout', payload, sid, method=method) if 'result' in response and 'status' in response['result'][0] and 'code' in response['result'][0]['status'] and response['result'][0]['status']['code'] == 0: @@ -82,11 +82,11 @@ def logout(v_url, sid, method='exec'): 'api call: url: ' + str(v_url) + ', + payload: ' + str(payload)) -def update_config_with_fortinet_api_call(config_json, sid, api_base_url, api_path, result_name, payload={}, options=[], limit=150, method='get'): +def update_config_with_fortinet_api_call(config_json: list[dict[str, Any]], sid: str, api_base_url: str, api_path: str, result_name: str, payload: dict[str, Any] = {}, options: list[Any] = [], limit: int = 150, method: str = 'get'): offset = 0 limit = int(limit) returned_new_objects = True - full_result = [] + full_result: list[Any] = [] while returned_new_objects: range = [offset, limit] if payload == {}: @@ -111,7 +111,7 @@ def update_config_with_fortinet_api_call(config_json, sid, api_base_url, api_pat config_json.append({'type': result_name, 'data': full_result}) -def parse_special_fortinet_api_results(result_name, full_result): +def parse_special_fortinet_api_results(result_name: str, full_result: list[Any]) -> list[Any]: if result_name == 'nw_obj_global_firewall/internet-service-basic': if len(full_result)>0 and 'response' in full_result[0] and 'results' in full_result[0]['response']: full_result = full_result[0]['response']['results'] @@ -122,21 +122,21 @@ def parse_special_fortinet_api_results(result_name, full_result): return full_result -def fortinet_api_call(sid, api_base_url, api_path, payload={}, method='get'): +def fortinet_api_call(sid: str, api_base_url: str, api_path: str, payload: dict[str, Any] = {}, method: str = 'get') -> list[Any]: if payload == {}: payload = {'params': [{}]} - result = api_call(api_base_url, api_path, payload, sid, method=method) - plain_result = result['result'][0] + api_result = api_call(api_base_url, api_path, payload, sid, method=method) + plain_result: dict[str, Any] = api_result['result'][0] if 'data' in plain_result: result = plain_result['data'] if isinstance(result, dict): # code implicitly expects result to be a list, but some fmgr data results are dicts - result = [result] + result: list[Any] = [result] else: result = [] return result -def get_devices_from_manager(adom_mgm_details, sid, fm_api_url) -> dict[str, Any]: - device_vdom_dict = {} +def get_devices_from_manager(adom_mgm_details: Management, sid: str, fm_api_url: str) -> dict[str, Any]: + device_vdom_dict: dict[str, dict[str, str]] = {} device_results = fortinet_api_call(sid, fm_api_url, '/dvmdb/adom/' + adom_mgm_details.DomainName + '/device') for mgm_details_device in adom_mgm_details.Devices: @@ -149,7 +149,7 @@ def get_devices_from_manager(adom_mgm_details, sid, fm_api_url) -> dict[str, Any return device_vdom_dict -def parse_device_and_vdom(fmgr_device, mgm_details_device, device_vdom_dict, found_fmgr_device): +def parse_device_and_vdom(fmgr_device: dict[str, Any], mgm_details_device: dict[str, Any], device_vdom_dict: dict[str, dict[str, str]], found_fmgr_device: bool) -> bool: if 'vdom' in fmgr_device: for fmgr_vdom in fmgr_device['vdom']: if mgm_details_device['name'] == fmgr_device['name'] + '_' + fmgr_vdom['name']: @@ -161,7 +161,7 @@ def parse_device_and_vdom(fmgr_device, mgm_details_device, device_vdom_dict, fou return found_fmgr_device -def get_policy_packages_from_manager(sid, fm_api_url, adom=''): +def get_policy_packages_from_manager(sid: str, fm_api_url: str, adom: str = '') -> list[Any]: if adom == '': url = '/pm/pkg/global' else: diff --git a/roles/importer/files/importer/fortiadom5ff/fmgr_gw_networking.py b/roles/importer/files/importer/fortiadom5ff/fmgr_gw_networking.py index 74bd299835..89988cf581 100644 --- a/roles/importer/files/importer/fortiadom5ff/fmgr_gw_networking.py +++ b/roles/importer/files/importer/fortiadom5ff/fmgr_gw_networking.py @@ -1,4 +1,4 @@ -from asyncio.log import logger +from typing import Any from fwo_log import getFwoLogger from netaddr import IPAddress, IPNetwork from functools import cmp_to_key @@ -8,7 +8,7 @@ from model_controllers.interface_controller import Interface from model_controllers.route_controller import Route, getRouteDestination -def normalize_network_data(native_config, normalized_config, mgm_details): +def normalize_network_data(native_config: dict[str, Any], normalized_config: dict[str, Any], mgm_details: dict[str, Any]) -> None: # currently only a single IP (v4+v6) per interface ;-) # @@ -34,7 +34,7 @@ def normalize_network_data(native_config, normalized_config, mgm_details): normalized_config.update({'routing': {}, 'interfaces': {} }) - for dev_id, plain_dev_name, plain_vdom_name, full_vdom_name in get_all_dev_names(mgm_details['devices']): + for dev_id, _, _, full_vdom_name in get_all_dev_names(mgm_details['devices']): normalized_config.update({'routing': [], 'interfaces': []}) if 'routing-table-ipv4/' + full_vdom_name not in native_config: @@ -78,11 +78,11 @@ def normalize_network_data(native_config, normalized_config, mgm_details): # logger.warning('found devices without default route') -def get_matching_route(destination_ip, routing_table): +def get_matching_route(destination_ip: IPAddress, routing_table: list[dict[str, Any]]) -> dict[str, Any] | None: logger = getFwoLogger() - def route_matches(ip, destination): + def route_matches(ip: IPAddress, destination: str) -> bool: ip_n = IPNetwork(ip).cidr dest_n = IPNetwork(destination).cidr return ip_n in dest_n or dest_n in ip_n @@ -100,7 +100,7 @@ def route_matches(ip, destination): return None -def get_ip_of_interface(interface, interface_list=[]): +def get_ip_of_interface(interface: str, interface_list: list[dict[str, Any]] = []) -> str | None: interface_details = next((sub for sub in interface_list if sub['name'] == interface), None) @@ -110,9 +110,9 @@ def get_ip_of_interface(interface, interface_list=[]): return None -def sort_reverse(ar_in, key): +def sort_reverse(ar_in: list[dict[str, Any]], key: str) -> list[dict[str, Any]]: - def comp(left, right): + def comp(left: dict[str, Any], right: dict[str, Any]) -> int: l_submask = int(left[key].split("/")[1]) r_submask = int(right[key].split("/")[1]) return l_submask - r_submask @@ -121,7 +121,7 @@ def comp(left, right): # strip off last part of a string separated by separator -def strip_off_last_part(string_in, separator='_'): +def strip_off_last_part(string_in: str, separator: str = '_') -> str: string_out = string_in if separator in string_in: # strip off final _xxx part str_ar = string_in.split(separator) @@ -130,7 +130,7 @@ def strip_off_last_part(string_in, separator='_'): return string_out -def get_last_part(string_in, separator='_'): +def get_last_part(string_in: str, separator: str = '_') -> str: string_out = '' if separator in string_in: # strip off _vdom_name str_ar = string_in.split(separator) @@ -138,8 +138,8 @@ def get_last_part(string_in, separator='_'): return string_out -def get_plain_device_names_without_vdoms(devices): - device_array = [] +def get_plain_device_names_without_vdoms(devices: list[dict[str, Any]]) -> list[str]: + device_array: list[str] = [] for dev in devices: dev_name = strip_off_last_part(dev["name"]) if dev_name not in device_array: @@ -149,9 +149,9 @@ def get_plain_device_names_without_vdoms(devices): # only getting one vdom as currently assuming routing to be # the same for all vdoms on a device -def get_device_names_plus_one_vdom(devices): - device_array = [] - device_array_with_vdom = [] +def get_device_names_plus_one_vdom(devices: list[dict[str, Any]]) -> list[list[str]]: + device_array: list[str] = [] + device_array_with_vdom: list[list[str]] = [] for dev in devices: dev_name = strip_off_last_part(dev["name"]) vdom_name = get_last_part(dev["name"]) @@ -162,8 +162,8 @@ def get_device_names_plus_one_vdom(devices): # getting devices and their vdom names -def get_device_plus_full_vdom_names(devices): - device_array_with_vdom = [] +def get_device_plus_full_vdom_names(devices: list[dict[str, Any]]) -> list[list[str]]: + device_array_with_vdom: list[list[str]] = [] for dev in devices: dev_name = strip_off_last_part(dev["name"]) vdom_name = dev["name"] @@ -172,8 +172,8 @@ def get_device_plus_full_vdom_names(devices): # getting devices and their vdom names -def get_all_dev_names(devices): - device_array_with_vdom = [] +def get_all_dev_names(devices: list[dict[str, Any]]) -> list[list[Any]]: + device_array_with_vdom: list[list[Any]] = [] for dev in devices: dev_id = dev["id"] dev_name = strip_off_last_part(dev["name"]) @@ -184,17 +184,17 @@ def get_all_dev_names(devices): # get network information (currently only used for source nat) -def getInterfacesAndRouting(sid, fm_api_url, nativeConfig, adom_name, devices, limit): - +def getInterfacesAndRouting(sid: str, fm_api_url: str, nativeConfig: list[dict[str, Any]], adom_name: str, devices: list[dict[str, Any]], limit: int) -> None: + #TYPING: DICT OR LIST??? logger = getFwoLogger() # strip off vdom names, just deal with the plain device device_array = get_all_dev_names(devices) - for dev_id, plain_dev_name, plain_vdom_name, full_vdom_name in device_array: + for _, plain_dev_name, plain_vdom_name, full_vdom_name in device_array: logger.info("dev_name: " + plain_dev_name + ", full vdom_name: " + full_vdom_name) # getting interfaces of device - all_interfaces_payload = { + all_interfaces_payload: dict[str, Any] = { "id": 1, "params": [ { @@ -270,13 +270,13 @@ def getInterfacesAndRouting(sid, fm_api_url, nativeConfig, adom_name, devices, l # now getting routing information for ip_version in ["ipv4", "ipv6"]: - payload = { "params": [ { "data": { + payload: dict[str, Any] = { "params": [ { "data": { "target": ["adom/" + adom_name + "/device/" + plain_dev_name], "action": "get", "resource": "/api/v2/monitor/router/" + ip_version + "/select?&vdom="+ plain_vdom_name } } ] } try: # get routing table per vdom - routing_helper = {} - routing_table = [] + routing_helper: list[Any] = [] + routing_table: list[Any] = [] fmgr_getter.update_config_with_fortinet_api_call( routing_helper, sid, fm_api_url, "/sys/proxy/json", "routing-table-" + ip_version + '/' + full_vdom_name, @@ -294,10 +294,10 @@ def getInterfacesAndRouting(sid, fm_api_url, nativeConfig, adom_name, devices, l routing_table = [] # now storing the routing table: - nativeConfig.update({"routing-table-" + ip_version + '/' + full_vdom_name: routing_table}) + nativeConfig.update({"routing-table-" + ip_version + '/' + full_vdom_name: routing_table}) #type: ignore #TYPING: dict or list??? broo -def get_device_from_package(package_name, mgm_details): +def get_device_from_package(package_name: str, mgm_details: dict[str, Any]) -> str | None: logger = getFwoLogger() for dev in mgm_details['devices']: if dev['local_rulebase_name'] == package_name: diff --git a/roles/importer/files/importer/fortiadom5ff/fmgr_network.py b/roles/importer/files/importer/fortiadom5ff/fmgr_network.py index cdf8d5feaf..0154637923 100644 --- a/roles/importer/files/importer/fortiadom5ff/fmgr_network.py +++ b/roles/importer/files/importer/fortiadom5ff/fmgr_network.py @@ -1,16 +1,14 @@ from asyncio.log import logger import ipaddress +from typing import Any from fwo_log import getFwoLogger from fwo_const import list_delimiter, nat_postfix from fmgr_zone import find_zones_in_normalized_config -from fwo_config import readConfig -from model_controllers.import_state_controller import ImportStateController -from copy import deepcopy from fwo_exceptions import FwoImporterErrorInconsistencies -def normalize_network_objects(native_config, normalized_config_adom, normalized_config_global, nw_obj_types): - nw_objects = [] +def normalize_network_objects(native_config: dict[str, Any], normalized_config_adom: dict[str, Any], normalized_config_global: dict[str, Any], nw_obj_types: list[str]) -> None: + nw_objects: list[dict[str, Any]] = [] if 'objects' not in native_config: return # no objects to normalize @@ -30,8 +28,8 @@ def normalize_network_objects(native_config, normalized_config_adom, normalized_ normalized_config_adom.update({'network_objects': nw_objects}) -def get_obj_member_refs_list(obj_orig, native_config_objects, current_obj_type): - obj_member_refs_list = [] +def get_obj_member_refs_list(obj_orig: dict[str, Any], native_config_objects: dict[str, Any], current_obj_type: str) -> list[str]: + obj_member_refs_list: list[str] = [] for member_name in obj_orig['member']: for obj_type in native_config_objects: if exclude_object_types_in_member_ref_search(obj_type, current_obj_type): @@ -44,7 +42,7 @@ def get_obj_member_refs_list(obj_orig, native_config_objects, current_obj_type): f"Member inconsistent for object {obj_orig['name']}, found members={str(obj_orig['member'])} and member_refs={str(obj_member_refs_list)}") return obj_member_refs_list -def exclude_object_types_in_member_ref_search(obj_type, current_obj_type): +def exclude_object_types_in_member_ref_search(obj_type: str, current_obj_type: str) -> bool: #TODO expand for all kinds of missmatches in group and member skip_member_ref_loop = False if current_obj_type.endswith('firewall/addrgrp'): @@ -52,8 +50,8 @@ def exclude_object_types_in_member_ref_search(obj_type, current_obj_type): skip_member_ref_loop = True return skip_member_ref_loop -def normalize_network_object(obj_orig, nw_objects, normalized_config_adom, normalized_config_global, native_config_objects, current_obj_type): - obj = {} +def normalize_network_object(obj_orig: dict[str, Any], nw_objects: list[dict[str, Any]], normalized_config_adom: dict[str, Any], normalized_config_global: dict[str, Any], native_config_objects: dict[str, Any], current_obj_type: str) -> None: + obj: dict[str, Any] = {} obj.update({'obj_name': obj_orig['name']}) if 'subnet' in obj_orig: # ipv4 object _parse_subnet(obj, obj_orig) @@ -103,7 +101,7 @@ def normalize_network_object(obj_orig, nw_objects, normalized_config_adom, norma nw_objects.append(obj) -def _parse_subnet (obj, obj_orig): +def _parse_subnet (obj: dict[str, Any], obj_orig: dict[str, Any]) -> None: ipa = ipaddress.ip_network(str(obj_orig['subnet'][0]) + '/' + str(obj_orig['subnet'][1])) if ipa.num_addresses > 1: obj.update({ 'obj_typ': 'network' }) @@ -113,7 +111,7 @@ def _parse_subnet (obj, obj_orig): obj.update({ 'obj_ip_end': str(ipa.broadcast_address) }) -def normalize_network_object_ipv6(obj_orig, obj): +def normalize_network_object_ipv6(obj_orig: dict[str, Any], obj: dict[str, Any]) -> None: ipa = ipaddress.ip_network(obj_orig['ip6']) if ipa.num_addresses > 1: obj.update({ 'obj_typ': 'network' }) @@ -123,7 +121,7 @@ def normalize_network_object_ipv6(obj_orig, obj): obj.update({ 'obj_ip_end': str(ipa.broadcast_address) }) -def normalize_vip_object(obj_orig, obj, nw_objects): +def normalize_vip_object(obj_orig: dict[str, Any], obj: dict[str, Any], nw_objects: list[dict[str, Any]]) -> None: obj_zone = 'global' obj.update({ 'obj_typ': 'host' }) if 'extip' not in obj_orig or len(obj_orig['extip'])==0: @@ -132,7 +130,7 @@ def normalize_vip_object(obj_orig, obj, nw_objects): if len(obj_orig['extip'])>1: logger.warning("vip (extip): found more than one extip, just using the first one for " + obj_orig['name']) set_ip_in_obj(obj, obj_orig['extip'][0]) # resolving nat range if there is one - nat_obj = {} + nat_obj: dict[str, Any] = {} nat_obj.update({'obj_typ': 'host' }) nat_obj.update({'obj_color': 'black'}) nat_obj.update({'obj_comment': 'FWO-auto-generated nat object for VIP'}) @@ -152,7 +150,7 @@ def normalize_vip_object(obj_orig, obj, nw_objects): nw_objects.append(nat_obj) -def normalize_vip_object_nat_ip(obj_orig, obj, nat_obj): +def normalize_vip_object_nat_ip(obj_orig: dict[str, Any], obj: dict[str, Any], nat_obj: dict[str, Any]) -> None: # now dealing with the nat ip obj (mappedip) if 'mappedip' not in obj_orig or len(obj_orig['mappedip'])==0: logger.warning("vip (extip): found empty mappedip field for " + obj_orig['name']) @@ -173,7 +171,7 @@ def normalize_vip_object_nat_ip(obj_orig, obj, nat_obj): ###### range handling -def set_ip_in_obj(nw_obj, ip): # add start and end ip in nw_obj if it is a range, otherwise do nothing +def set_ip_in_obj(nw_obj: dict[str, Any], ip: str) -> None: # add start and end ip in nw_obj if it is a range, otherwise do nothing if '-' in ip: # dealing with range ip_start, ip_end = ip.split('-') nw_obj.update({'obj_ip': str(ip_start) }) @@ -184,7 +182,7 @@ def set_ip_in_obj(nw_obj, ip): # add start and end ip in nw_obj if it is a range # for members of groups, the name of the member obj needs to be fetched separately (starting from API v1.?) -def resolve_nw_uid_to_name(uid, nw_objects): +def resolve_nw_uid_to_name(uid: str, nw_objects: list[dict[str, Any]]) -> str: # return name of nw_objects element where obj_uid = uid for obj in nw_objects: if obj['obj_uid'] == uid: @@ -192,7 +190,7 @@ def resolve_nw_uid_to_name(uid, nw_objects): return 'ERROR: uid "' + uid + '" not found' -def add_member_names_for_nw_group(idx, nw_objects): +def add_member_names_for_nw_group(idx: int, nw_objects: list[dict[str, Any]]) -> None: group = nw_objects.pop(idx) if group['obj_member_refs'] == '' or group['obj_member_refs'] == None: #member_names = None @@ -209,7 +207,7 @@ def add_member_names_for_nw_group(idx, nw_objects): nw_objects.insert(idx, group) -def create_network_object(name, type, ip, ip_end, uid, color, comment, zone): +def create_network_object(name: str, type: str, ip: str, ip_end: str | None, uid: str, color: str, comment: str | None, zone: str | None) -> dict[str, Any]: # if zone is None or zone == '': # zone = 'global' return { @@ -224,7 +222,7 @@ def create_network_object(name, type, ip, ip_end, uid, color, comment, zone): } -def get_nw_obj(nat_obj_name, nwobjects): +def get_nw_obj(nat_obj_name: str, nwobjects: list[dict[str, Any]]) -> dict[str, Any] | None: for obj in nwobjects: if 'obj_name' in obj and obj['obj_name']==nat_obj_name: return obj @@ -233,13 +231,13 @@ def get_nw_obj(nat_obj_name, nwobjects): # this removes all obj_nat_ip entries from all network objects # these were used during import but might cause issues if imported into db -def remove_nat_ip_entries(config2import): +def remove_nat_ip_entries(config2import: dict[str, Any]) -> None: for obj in config2import['network_objects']: if 'obj_nat_ip' in obj: obj.pop('obj_nat_ip') -def get_first_ip_of_destination(obj_ref, config2import): +def get_first_ip_of_destination(obj_ref: str, config2import: dict[str, Any]) -> str | None: logger = getFwoLogger() if list_delimiter in obj_ref: obj_ref = obj_ref.split(list_delimiter)[0] diff --git a/roles/importer/files/importer/fortiadom5ff/fmgr_rule.py b/roles/importer/files/importer/fortiadom5ff/fmgr_rule.py index 6643485fcc..19ec7d40ed 100644 --- a/roles/importer/files/importer/fortiadom5ff/fmgr_rule.py +++ b/roles/importer/files/importer/fortiadom5ff/fmgr_rule.py @@ -1,9 +1,8 @@ import copy import ipaddress from time import strftime, localtime +from typing import Any from fwo_const import list_delimiter, nat_postfix, dummy_ip -from fwo_base import extend_string_list -from fmgr_service import create_svc_object from fmgr_network import create_network_object, get_first_ip_of_destination from fmgr_zone import find_zones_in_normalized_config from fmgr_consts import nat_types @@ -27,15 +26,15 @@ def normalize_rulebases( mgm_uid: str, - native_config: dict, - native_config_global: dict, - normalized_config_adom: dict, - normalized_config_global: dict, + native_config: dict[str, Any], + native_config_global: dict[str, Any], + normalized_config_adom: dict[str, Any], + normalized_config_global: dict[str, Any], is_global_loop_iteration: bool ) -> None: normalized_config_adom['policies'] = [] - fetched_rulebase_uids: list = [] + fetched_rulebase_uids: list[str] = [] if normalized_config_global != {}: for normalized_rulebase_global in normalized_config_global.get('policies', []): fetched_rulebase_uids.append(normalized_rulebase_global.uid) @@ -45,8 +44,8 @@ def normalize_rulebases( is_global_loop_iteration, normalized_config_adom, normalized_config_global) -def normalize_rulebases_for_each_link_destination(gateway, mgm_uid, fetched_rulebase_uids, native_config, - native_config_global, is_global_loop_iteration, normalized_config_adom, normalized_config_global): +def normalize_rulebases_for_each_link_destination(gateway: dict[str, Any], mgm_uid: str, fetched_rulebase_uids: list[str], native_config: dict[str, Any], + native_config_global: dict[str, Any], is_global_loop_iteration: bool, normalized_config_adom: dict[str, Any], normalized_config_global: dict[str, Any]): logger = getFwoLogger() for rulebase_link in gateway['rulebase_links']: if rulebase_link['to_rulebase_uid'] not in fetched_rulebase_uids and rulebase_link['to_rulebase_uid'] != '': @@ -74,14 +73,14 @@ def normalize_rulebases_for_each_link_destination(gateway, mgm_uid, fetched_rule # normalizing nat rulebases is work in progress #normalize_nat_rulebase(rulebase_link, native_config, normalized_config_adom, normalized_config_global) -def normalize_nat_rulebase(rulebase_link, native_config, normalized_config_adom, normalized_config_global): +def normalize_nat_rulebase(rulebase_link: dict[str, Any], native_config: dict[str, Any], normalized_config_adom: dict[str, Any], normalized_config_global: dict[str, Any]): if not rulebase_link['is_section']: for nat_type in nat_types: nat_type_string = nat_type + '_' + rulebase_link['to_rulebase_uid'] nat_rulebase = get_native_nat_rulebase(native_config, nat_type_string) parse_nat_rulebase(nat_rulebase, nat_type_string, normalized_config_adom, normalized_config_global) -def get_native_nat_rulebase(native_config, nat_type_string): +def get_native_nat_rulebase(native_config: dict[str, Any], nat_type_string: str) -> list[dict[str, Any]]: logger = getFwoLogger() for nat_rulebase in native_config['nat_rulebases']: if nat_type_string == nat_rulebase['type']: @@ -89,13 +88,13 @@ def get_native_nat_rulebase(native_config, nat_type_string): logger.warning('no nat data for '+ nat_type_string) return [] -def find_rulebase_to_parse(rulebase_list, rulebase_uid): +def find_rulebase_to_parse(rulebase_list: list[dict[str, Any]], rulebase_uid: str) -> dict[str, Any]: for rulebase in rulebase_list: if rulebase['uid'] == rulebase_uid: return rulebase return {} -def initialize_normalized_rulebase(rulebase_to_parse, mgm_uid): +def initialize_normalized_rulebase(rulebase_to_parse: dict[str, Any], mgm_uid: str) -> Rulebase: """ we use 'type' as uid/name since a rulebase may have a v4 and a v6 part """ @@ -104,14 +103,14 @@ def initialize_normalized_rulebase(rulebase_to_parse, mgm_uid): normalized_rulebase = Rulebase(uid=rulebaseUid, name=rulebaseName, mgm_uid=mgm_uid, rules={}) return normalized_rulebase -def parse_rulebase(normalized_config_adom, normalized_config_global, rulebase_to_parse, normalized_rulebase, found_rulebase_in_global): +def parse_rulebase(normalized_config_adom: dict[str, Any], normalized_config_global: dict[str, Any], rulebase_to_parse: dict[str, Any], normalized_rulebase: Rulebase, found_rulebase_in_global: bool): """Parses a native Fortinet rulebase into a normalized rulebase.""" for native_rule in rulebase_to_parse['data']: parse_single_rule(normalized_config_adom, normalized_config_global, native_rule, normalized_rulebase) if not found_rulebase_in_global: add_implicit_deny_rule(normalized_config_adom, normalized_config_global, normalized_rulebase) -def add_implicit_deny_rule(normalized_config_adom, normalized_config_global, rulebase: Rulebase): +def add_implicit_deny_rule(normalized_config_adom: dict[str, Any], normalized_config_global: dict[str, Any], rulebase: Rulebase): deny_rule = {'srcaddr': ['all'], 'srcaddr6': ['all'], 'dstaddr': ['all'], 'dstaddr6': ['all'], @@ -156,9 +155,12 @@ def add_implicit_deny_rule(normalized_config_adom, normalized_config_global, rul rule_dst_zone=list_delimiter.join(rule_dst_zones), rule_head_text=None ) + + if rule_normalized.rule_uid is None: + raise FwoImporterErrorInconsistencies("rule_normalized.rule_uid is None when adding implicit deny rule") rulebase.rules[rule_normalized.rule_uid] = rule_normalized -def parse_single_rule(normalized_config_adom, normalized_config_global, native_rule, rulebase: Rulebase): +def parse_single_rule(normalized_config_adom: dict[str, Any], normalized_config_global: dict[str, Any], native_rule: dict[str, Any], rulebase: Rulebase): """Parses a single native Fortinet rule into a normalized rule and adds it to the given rulebase.""" # Extract basic rule information rule_disabled = True # Default to disabled @@ -215,20 +217,22 @@ def parse_single_rule(normalized_config_adom, normalized_config_global, native_r rule_dst_zone=list_delimiter.join(rule_dst_zones), rule_head_text=None ) + if rule_normalized.rule_uid is None: + raise FwoImporterErrorInconsistencies("rule_normalized.rule_uid is None when parsing single rule") # Add the rule to the rulebase rulebase.rules[rule_normalized.rule_uid] = rule_normalized # TODO: handle combined NAT, see handle_combined_nat_rule -def rule_parse_action(native_rule): +def rule_parse_action(native_rule: dict[str, Any]) -> RuleAction: # Extract action - Fortinet uses 0 for deny/drop, 1 for accept if native_rule.get('action', 0) == 0: return RuleAction.DROP else: return RuleAction.ACCEPT -def rule_parse_tracking_info(native_rule): +def rule_parse_tracking_info(native_rule: dict[str, Any]) -> RuleTrack: # TODO: Implement more detailed logging level extraction (difference between 1/2/3?) logtraffic = native_rule.get('logtraffic', 0) if isinstance(logtraffic, int) and logtraffic > 0 or isinstance(logtraffic, str) and logtraffic != 'disable': @@ -236,9 +240,9 @@ def rule_parse_tracking_info(native_rule): else: return RuleTrack.NONE -def rule_parse_service(native_rule): - rule_svc_list = [] - rule_svc_refs_list = [] +def rule_parse_service(native_rule: dict[str, Any]) -> tuple[list[str], list[str]]: + rule_svc_list: list[str] = [] + rule_svc_refs_list: list[str] = [] for svc in native_rule.get('service', []): rule_svc_list.append(svc) rule_svc_refs_list.append(svc) @@ -251,11 +255,11 @@ def rule_parse_service(native_rule): return rule_svc_list, rule_svc_refs_list -def rule_parse_addresses(native_rule, target, normalized_config_adom, normalized_config_global, is_nat): +def rule_parse_addresses(native_rule: dict[str, Any], target: str, normalized_config_adom: dict[str, Any], normalized_config_global: dict[str, Any], is_nat: bool) -> tuple[list[str], list[str]]: if target not in ['src', 'dst']: raise FwoImporterErrorInconsistencies(f"target '{target}' must either be src or dst.") - addr_list = [] - addr_ref_list = [] + addr_list: list[str] = [] + addr_ref_list: list[str] = [] if not is_nat: build_addr_list(native_rule, True, target, normalized_config_adom, normalized_config_global, addr_list, addr_ref_list) build_addr_list(native_rule, False, target, normalized_config_adom, normalized_config_global, addr_list, addr_ref_list) @@ -263,7 +267,7 @@ def rule_parse_addresses(native_rule, target, normalized_config_adom, normalized build_nat_addr_list(native_rule, target, normalized_config_adom, normalized_config_global, addr_list, addr_ref_list) return addr_list, addr_ref_list -def build_addr_list(native_rule, is_v4, target, normalized_config_adom, normalized_config_global, addr_list, addr_ref_list): +def build_addr_list(native_rule: dict[str, Any], is_v4: bool, target: str, normalized_config_adom: dict[str, Any], normalized_config_global: dict[str, Any], addr_list: list[str], addr_ref_list: list[str]) -> None: if is_v4 and target == 'src': for addr in native_rule.get('srcaddr', []) + native_rule.get('internet-service-src-name', []): addr_list.append(addr) @@ -281,7 +285,7 @@ def build_addr_list(native_rule, is_v4, target, normalized_config_adom, normaliz addr_list.append(addr) addr_ref_list.append(find_addr_ref(addr, is_v4, normalized_config_adom, normalized_config_global)) -def build_nat_addr_list(native_rule, target, normalized_config_adom, normalized_config_global, addr_list, addr_ref_list): +def build_nat_addr_list(native_rule: dict[str, Any], target: str, normalized_config_adom: dict[str, Any], normalized_config_global: dict[str, Any], addr_list: list[str], addr_ref_list: list[str]) -> None: # so far only ip v4 expected if target == 'src': for addr in native_rule.get('orig-addr', []): @@ -292,14 +296,14 @@ def build_nat_addr_list(native_rule, target, normalized_config_adom, normalized_ addr_list.append(addr) addr_ref_list.append(find_addr_ref(addr, True, normalized_config_adom, normalized_config_global)) -def find_addr_ref(addr, is_v4, normalized_config_adom, normalized_config_global): +def find_addr_ref(addr: str, is_v4: bool, normalized_config_adom: dict[str, Any], normalized_config_global: dict[str, Any]) -> str: for nw_obj in normalized_config_adom['network_objects'] + normalized_config_global.get('network_objects', []): if addr == nw_obj['obj_name']: if (is_v4 and ip_type(nw_obj) == 4) or (not is_v4 and ip_type(nw_obj) == 6): return nw_obj['obj_uid'] raise FwoImporterErrorInconsistencies(f"No ref found for '{addr}'.") -def ip_type(nw_obj): +def ip_type(nw_obj: dict[str, Any]) -> int: # default to v4 first_ip = nw_obj.get('obj_ip', '0.0.0.0/32') if first_ip == '': @@ -307,7 +311,7 @@ def ip_type(nw_obj): net=ipaddress.ip_network(str(first_ip)) return net.version -def rule_parse_negation_flags(native_rule): +def rule_parse_negation_flags(native_rule: dict[str, Any]) -> tuple[bool, bool, bool]: # if customer decides to mix internet-service and "normal" addr obj in src/dst and mix negates this will prob. not work correctly if 'srcaddr-negate' in native_rule: rule_src_neg = native_rule['srcaddr-negate'] == 1 or native_rule['srcaddr-negate'] == 'disable' @@ -319,22 +323,22 @@ def rule_parse_negation_flags(native_rule): rule_svc_neg = 'service-negate' in native_rule and (native_rule['service-negate'] == 1 or native_rule['service-negate'] == 'disable') return rule_src_neg, rule_dst_neg, rule_svc_neg -def rule_parse_installon(native_rule) -> str|None: +def rule_parse_installon(native_rule: dict[str, Any]) -> str|None: rule_installon = None if 'scope_member' in native_rule and native_rule['scope_member']: rule_installon = list_delimiter.join(sorted({vdom['name'] + '_' + vdom['vdom'] for vdom in native_rule['scope_member']})) return rule_installon -def rule_parse_last_hit(native_rule): +def rule_parse_last_hit(native_rule: dict[str, Any]) -> str|None: last_hit = native_rule.get('_last_hit', None) if last_hit != None: last_hit = strftime("%Y-%m-%d %H:%M:%S", localtime(last_hit)) return last_hit -def get_access_policy(sid, fm_api_url, native_config_adom, native_config_global, adom_device_vdom_policy_package_structure, adom_name, mgm_details_device, device_config, limit): +def get_access_policy(sid: str, fm_api_url: str, native_config_adom: dict[str, Any], native_config_global: dict[str, Any], adom_device_vdom_policy_package_structure: dict[str, Any], adom_name: str, mgm_details_device: dict[str, Any], device_config: dict[str, Any], limit: int): previous_rulebase = None - link_list = [] + link_list: list[Any] = [] local_pkg_name, global_pkg_name = find_packages(adom_device_vdom_policy_package_structure, adom_name, mgm_details_device) options = ['extra info', 'scope member', 'get meta'] @@ -349,7 +353,7 @@ def get_access_policy(sid, fm_api_url, native_config_adom, native_config_global, device_config['rulebase_links'].extend(link_list) -def get_and_link_global_rulebase(header_or_footer, previous_rulebase, global_pkg_name, native_config_global, sid, fm_api_url, options, limit, link_list): +def get_and_link_global_rulebase(header_or_footer: str, previous_rulebase: str | None, global_pkg_name: str, native_config_global: dict[str, Any], sid: str, fm_api_url: str, options: list[str], limit: int, link_list: list[Any]) -> Any: rulebase_type_prefix = 'rules_global_' + header_or_footer if global_pkg_name != '': if not is_rulebase_already_fetched(native_config_global['rulebases'], rulebase_type_prefix + '_v4_' + global_pkg_name): @@ -370,7 +374,7 @@ def get_and_link_global_rulebase(header_or_footer, previous_rulebase, global_pkg previous_rulebase = link_rulebase(link_list, native_config_global['rulebases'], global_pkg_name, rulebase_type_prefix, previous_rulebase, True) return previous_rulebase -def get_and_link_local_rulebase(rulebase_type_prefix, previous_rulebase, adom_name, local_pkg_name, native_config_adom, sid, fm_api_url, options, limit, link_list): +def get_and_link_local_rulebase(rulebase_type_prefix: str, previous_rulebase: str | None, adom_name: str, local_pkg_name: str, native_config_adom: dict[str, Any], sid: str, fm_api_url: str, options: list[str], limit: int, link_list: list[Any]) -> Any: if not is_rulebase_already_fetched(native_config_adom['rulebases'], rulebase_type_prefix + '_v4_' + local_pkg_name): fmgr_getter.update_config_with_fortinet_api_call( native_config_adom['rulebases'], @@ -388,7 +392,7 @@ def get_and_link_local_rulebase(rulebase_type_prefix, previous_rulebase, adom_na previous_rulebase = link_rulebase(link_list, native_config_adom['rulebases'], local_pkg_name, rulebase_type_prefix, previous_rulebase, False) return previous_rulebase -def find_packages(adom_device_vdom_policy_package_structure, adom_name, mgm_details_device): +def find_packages(adom_device_vdom_policy_package_structure: dict[str, Any], adom_name: str, mgm_details_device: dict[str, Any]) -> tuple[str, str]: for device in adom_device_vdom_policy_package_structure[adom_name]: for vdom in adom_device_vdom_policy_package_structure[adom_name][device]: if mgm_details_device['name'] == device + '_' + vdom: @@ -399,13 +403,13 @@ def find_packages(adom_device_vdom_policy_package_structure, adom_name, mgm_deta return '', '' raise FwoDeviceWithoutLocalPackage('Could not find local package for ' + mgm_details_device['name'] + ' in Fortimanager Config') from None -def is_rulebase_already_fetched(rulebases, type): +def is_rulebase_already_fetched(rulebases: list[dict[str, Any]], type: str) -> bool: for rulebase in rulebases: if rulebase['type'] == type: return True return False -def link_rulebase(link_list, rulebases, pkg_name, rulebase_type_prefix, previous_rulebase, is_global): +def link_rulebase(link_list: list[Any], rulebases: list[dict[str, Any]], pkg_name: str, rulebase_type_prefix: str, previous_rulebase: str | None, is_global: bool) -> str|None: for version in ['v4', 'v6']: full_pkg_name = rulebase_type_prefix + '_' + version + '_' + pkg_name has_data = has_rulebase_data(rulebases, full_pkg_name, is_global, version, pkg_name) @@ -415,7 +419,7 @@ def link_rulebase(link_list, rulebases, pkg_name, rulebase_type_prefix, previous return previous_rulebase -def build_link(previous_rulebase, full_pkg_name, is_global): +def build_link(previous_rulebase: str | None, full_pkg_name: str, is_global: bool) -> dict[str, Any]: if previous_rulebase is None: is_initial = True previous_rulebase = None @@ -432,7 +436,7 @@ def build_link(previous_rulebase, full_pkg_name, is_global): 'is_section': False } -def has_rulebase_data(rulebases, full_pkg_name, is_global, version, pkg_name): +def has_rulebase_data(rulebases: list[dict[str, Any]], full_pkg_name: str, is_global: bool, version: str, pkg_name: str) -> bool: """adds name and uid to rulebase and removes empty global rulebases""" has_data = False if version == 'v4': @@ -452,7 +456,7 @@ def has_rulebase_data(rulebases, full_pkg_name, is_global, version, pkg_name): rulebases.remove(rulebase) return has_data -def get_nat_policy(sid, fm_api_url, native_config, adom_device_vdom_policy_package_structure, adom_name, mgm_details_device, limit): +def get_nat_policy(sid: str, fm_api_url: str, native_config: dict[str, Any], adom_device_vdom_policy_package_structure: dict[str, Any], adom_name: str, mgm_details_device: dict[str, Any], limit: int): local_pkg_name, global_pkg_name = find_packages(adom_device_vdom_policy_package_structure, adom_name, mgm_details_device) if adom_name == '': for nat_type in nat_types: @@ -470,7 +474,7 @@ def get_nat_policy(sid, fm_api_url, native_config, adom_device_vdom_policy_packa # delete_v: ab hier kann sehr viel weg, ich lasses vorerst zB für die nat # pure nat rules -def parse_nat_rulebase(nat_rulebase, nat_type_string, normalized_config_adom, normalized_config_global): +def parse_nat_rulebase(nat_rulebase: list[dict[str, Any]], nat_type_string: str, normalized_config_adom: dict[str, Any], normalized_config_global: dict[str, Any]) -> None: # this function is not called until it is ready return # the following is a first draft and is not yet functional @@ -599,7 +603,7 @@ def parse_nat_rulebase(nat_rulebase, nat_type_string, normalized_config_adom, no # nat_rules.append(xlate_rule) # normalized_config_adom['rules'].extend(nat_rules) -def create_xlate_rule(rule): +def create_xlate_rule(rule: dict[str, Any]) -> dict[str, Any]: xlate_rule = copy.deepcopy(rule) rule['rule_type'] = 'combined' xlate_rule['rule_type'] = 'xlate' @@ -614,7 +618,7 @@ def create_xlate_rule(rule): return xlate_rule -def handle_combined_nat_rule(rule, rule_orig, config2import, nat_rule_number, import_id, localPkgName, dev_id): +def handle_combined_nat_rule(rule: dict[str, Any], rule_orig: dict[str, Any], config2import: dict[str, Any], nat_rule_number: int, import_id: str, localPkgName: str, dev_id: int) -> dict[str, Any] | None: # now dealing with VIPs (dst NAT part) of combined rules logger = getFwoLogger() xlate_rule = None @@ -626,10 +630,9 @@ def handle_combined_nat_rule(rule, rule_orig, config2import, nat_rule_number, im xlate_rule = create_xlate_rule(rule) if 'ippool' in rule_orig: if rule_orig['ippool']==0: # hiding behind outbound interface - interface_name = 'unknownIF' - destination_interface_ip = '0.0.0.0' + destination_interface_ip: str | None = '0.0.0.0' destination_ip = get_first_ip_of_destination(rule['rule_dst_refs'], config2import) # get an ip of destination - hideInterface = 'undefined_interface' + hideInterface: str | None = 'undefined_interface' if destination_ip is None: logger.warning('src nat behind interface: found no valid destination ip in rule with UID ' + rule['rule_uid']) else: @@ -640,7 +643,7 @@ def handle_combined_nat_rule(rule, rule_orig, config2import, nat_rule_number, im + rule['rule_uid'] + ', dest_ip: ' + destination_ip) else: destination_interface_ip = get_ip_of_interface_obj(matching_route.interface, dev_id, config2import['interfaces']) - interface_name = matching_route.interface + interface_name: str | None = matching_route.interface hideInterface=interface_name if hideInterface is None: logger.warning('src nat behind interface: found route with undefined interface ') #+ str(jsonpickle.dumps(matching_route, unpicklable=True))) @@ -694,8 +697,8 @@ def handle_combined_nat_rule(rule, rule_orig, config2import, nat_rule_number, im if len(nat_object_list)>0: if xlate_rule is None: # no source nat, so we create the necessary nat rule here xlate_rule = create_xlate_rule(rule) - xlate_dst = [] - xlate_dst_refs = [] + xlate_dst: list[str] = [] + xlate_dst_refs: list[str] = [] for nat_obj in nat_object_list: if 'obj_ip_end' in nat_obj: # this nat obj is a range - include the end ip in name and uid as well to avoid akey conflicts xlate_dst.append(nat_obj['obj_nat_ip'] + '-' + nat_obj['obj_ip_end'] + nat_postfix) @@ -713,8 +716,8 @@ def handle_combined_nat_rule(rule, rule_orig, config2import, nat_rule_number, im return xlate_rule -def extract_nat_objects(nwobj_list, all_nwobjects): - nat_obj_list = [] +def extract_nat_objects(nwobj_list: list[str], all_nwobjects: list[dict[str, str]]) -> list[dict[str, str]]: + nat_obj_list: list[dict[str, str]] = [] for obj in nwobj_list: for obj2 in all_nwobjects: if obj2['obj_name']==obj: @@ -726,22 +729,22 @@ def extract_nat_objects(nwobj_list, all_nwobjects): return nat_obj_list -def add_users_to_rule(rule_orig, rule): +def add_users_to_rule(rule_orig: dict[str, Any], rule: dict[str, Any]) -> None: if 'groups' in rule_orig: add_users(rule_orig['groups'], rule) if 'users' in rule_orig: add_users(rule_orig['users'], rule) -def add_users(users, rule): +def add_users(users: list[str], rule: dict[str, Any]) -> None: for user in users: - rule_src_with_users = [] + rule_src_with_users: list[str] = [] for src in rule['rule_src'].split(list_delimiter): rule_src_with_users.append(user + '@' + src) rule['rule_src'] = list_delimiter.join(rule_src_with_users) # here user ref is the user name itself - rule_src_refs_with_users = [] + rule_src_refs_with_users: list[str] = [] for src in rule['rule_src_refs'].split(list_delimiter): rule_src_refs_with_users.append(user + '@' + src) rule['rule_src_refs'] = list_delimiter.join(rule_src_refs_with_users) diff --git a/roles/importer/files/importer/fortiadom5ff/fmgr_service.py b/roles/importer/files/importer/fortiadom5ff/fmgr_service.py index c3215ab69e..fb49bca9af 100644 --- a/roles/importer/files/importer/fortiadom5ff/fmgr_service.py +++ b/roles/importer/files/importer/fortiadom5ff/fmgr_service.py @@ -1,11 +1,9 @@ import re from fwo_const import list_delimiter -from model_controllers.import_state_controller import ImportStateController -from fwo_log import getFwoLogger from typing import Any -def normalize_service_objects(native_config, normalized_config_adom, svc_obj_types): - svc_objects = [] +def normalize_service_objects(native_config: dict[str, Any], normalized_config_adom: dict[str, Any], svc_obj_types: list[str]) -> None: + svc_objects: list[dict[str, Any]] = [] if 'objects' not in native_config: return # no objects to normalize @@ -23,7 +21,7 @@ def normalize_service_objects(native_config, normalized_config_adom, svc_obj_typ normalized_config_adom.update({'service_objects': svc_objects}) -def normalize_service_object(obj_orig, svc_objects): +def normalize_service_object(obj_orig: dict[str, Any], svc_objects: list[dict[str, Any]]) -> None: member_names = '' if 'member' in obj_orig: svc_type = 'group' @@ -37,6 +35,9 @@ def normalize_service_object(obj_orig, svc_objects): if 'name' in obj_orig: name = str(obj_orig['name']) + if name is None: + raise ValueError("Service object without name encountered") + color = 'foreground' #TODO: color mapping. what is color: 0? (nativeconfig entwickler_fortimanager_stand_2025-07-27, service object 'gALL') session_timeout = None # todo: find the right timer @@ -50,7 +51,7 @@ def normalize_service_object(obj_orig, svc_objects): add_object(svc_objects, svc_type, name, color, 0, None, None, session_timeout) -def handle_svc_protocol(obj_orig, svc_objects, svc_type, name, color, session_timeout): +def handle_svc_protocol(obj_orig: dict[str, Any], svc_objects: list[dict[str, Any]], svc_type: str, name: str, color: str, session_timeout: Any) -> None: proto = 0 range_names = '' added_svc_obj = 0 @@ -74,7 +75,7 @@ def handle_svc_protocol(obj_orig, svc_objects, svc_type, name, color, session_ti pass # not doing anything for other protocols, e.g. GRE, ESP, ... -def parse_standard_protocols_with_ports(obj_orig, svc_objects, svc_type, name, color, session_timeout, range_names, added_svc_obj): +def parse_standard_protocols_with_ports(obj_orig: dict[str, Any], svc_objects: list[dict[str, Any]], svc_type: str, name: str, color: str, session_timeout: Any, range_names: str, added_svc_obj: int) -> None: split = check_split(obj_orig) if "tcp-portrange" in obj_orig and len(obj_orig['tcp-portrange']) > 0: tcpname = name @@ -106,7 +107,7 @@ def parse_standard_protocols_with_ports(obj_orig, svc_objects, svc_type, name, c added_svc_obj += 1 -def check_split(obj_orig) -> bool: +def check_split(obj_orig: dict[str, Any]) -> bool: count = 0 if "tcp-portrange" in obj_orig and len(obj_orig['tcp-portrange']) > 0: count += 1 @@ -117,9 +118,9 @@ def check_split(obj_orig) -> bool: return (count > 1) -def extract_ports(port_ranges) -> 'tuple[list[Any], list[Any]]': - ports = [] - port_ends = [] +def extract_ports(port_ranges: list[str] | None) -> 'tuple[list[Any], list[Any]]': + ports: list[Any] = [] + port_ends: list[Any] = [] if port_ranges is not None and len(port_ranges) > 0: for port_range in port_ranges: # remove src-ports @@ -148,7 +149,7 @@ def extract_ports(port_ranges) -> 'tuple[list[Any], list[Any]]': return ports, port_ends -def create_svc_object(name, proto, color, port, comment) -> 'dict[str, Any]': +def create_svc_object(name: str, proto: int, color: str, port: Any, comment: str) -> 'dict[str, Any]': return { 'svc_name': name, 'svc_typ': 'simple', @@ -160,10 +161,10 @@ def create_svc_object(name, proto, color, port, comment) -> 'dict[str, Any]': } -def add_object(svc_objects, type, name, color, proto, port_ranges, member_names, session_timeout): +def add_object(svc_objects: list[dict[str, Any]], type: str, name: str, color: str, proto: int, port_ranges: list[str] | None, member_names: str | None, session_timeout: Any) -> None: if port_ranges is None: svc_objects.extend([{'svc_typ': type, - 'svc_name': name, + 'svc_name': name, 'svc_color': color, 'svc_uid': name, # ? 'svc_comment': None, # ? diff --git a/roles/importer/files/importer/fortiadom5ff/fmgr_user.py b/roles/importer/files/importer/fortiadom5ff/fmgr_user.py index 33ad3be8e3..a90f341476 100644 --- a/roles/importer/files/importer/fortiadom5ff/fmgr_user.py +++ b/roles/importer/files/importer/fortiadom5ff/fmgr_user.py @@ -1,7 +1,8 @@ +from typing import Any from fwo_const import list_delimiter -def normalize_users(full_config, config2import, import_id, user_scope): - users = [] +def normalize_users(full_config: dict[str, list[dict[str, Any]]], config2import: dict[str, list[dict[str, Any]]], import_id: int, user_scope: list[str]) -> None: + users: list[dict[str, Any]] = [] for scope in user_scope: for user_orig in full_config[scope]: user_normalized = _parse_user(user_orig) @@ -10,13 +11,13 @@ def normalize_users(full_config, config2import, import_id, user_scope): config2import.update({'user_objects': users}) -def _parse_user(user_orig) -> dict: +def _parse_user(user_orig: dict[str, Any]) -> dict[str, Any]: name = None svc_type = 'simple' color = None member_names = None comment = None - user = {} + user: dict[str, Any] = {} if 'member' in user_orig: svc_type = 'group' member_names = '' @@ -30,7 +31,7 @@ def _parse_user(user_orig) -> dict: if 'color' in user_orig and str(user_orig['color']) != "0": color = str(user_orig['color']) - user.update ({ 'user_typ': svc_type, + user.update({ 'user_typ': svc_type, 'user_name': name, 'user_color': color, 'user_uid': name, diff --git a/roles/importer/files/importer/fortiadom5ff/fmgr_zone.py b/roles/importer/files/importer/fortiadom5ff/fmgr_zone.py index 1677cbfdf3..cf8c20c3ab 100644 --- a/roles/importer/files/importer/fortiadom5ff/fmgr_zone.py +++ b/roles/importer/files/importer/fortiadom5ff/fmgr_zone.py @@ -3,7 +3,7 @@ from typing import Any -def get_zones(sid, fm_api_url, native_config, adom_name, limit): +def get_zones(sid: str, fm_api_url: str, native_config: dict[str, Any], adom_name: str, limit: int): if adom_name == '': fmgr_getter.update_config_with_fortinet_api_call( @@ -13,8 +13,8 @@ def get_zones(sid, fm_api_url, native_config, adom_name, limit): native_config['zones'], sid, fm_api_url, '/pm/config/adom/' + adom_name + '/obj/dynamic/interface', 'interface_' + adom_name, limit=limit) -def normalize_zones(native_config, normalized_config_adom, is_global_loop_iteration): - zones: list[dict] = [] +def normalize_zones(native_config: dict[str, Any], normalized_config_adom: dict[str, Any], is_global_loop_iteration: bool): + zones: list[dict[str, Any]] = [] fetched_zones: list[str] = [] if is_global_loop_iteration: # can not find the following zones in api return statically_add_missing_global_zones(fetched_zones) @@ -39,7 +39,7 @@ def statically_add_missing_global_zones(fetched_zones: list[str]) -> None: # double check, if these zones cannot be parsed from api results -def fetch_dynamic_mapping(mapping, fetched_zones): +def fetch_dynamic_mapping(mapping: dict[str, Any], fetched_zones: list[str]) -> None: for dyn_mapping in mapping['dynamic_mapping']: if 'name' in dyn_mapping and not dyn_mapping['name'] in fetched_zones: fetched_zones.append(dyn_mapping['name']) @@ -48,14 +48,14 @@ def fetch_dynamic_mapping(mapping, fetched_zones): if local_interface not in fetched_zones: fetched_zones.append(local_interface) -def fetch_platform_mapping(mapping, fetched_zones): +def fetch_platform_mapping(mapping: dict[str, Any], fetched_zones: list[str]) -> None: for dyn_mapping in mapping['platform_mapping']: if 'intf-zone' in dyn_mapping and not dyn_mapping['intf-zone'] in fetched_zones: fetched_zones.append(dyn_mapping['intf-zone']) -def find_zones_in_normalized_config(native_zone_list : list, normalized_config_adom, normalized_config_global): +def find_zones_in_normalized_config(native_zone_list: list[str], normalized_config_adom: dict[str, Any], normalized_config_global: dict[str, Any]) -> list[str]: """Verifies that input zones exist in normalized config""" - zone_out_list = [] + zone_out_list: list[str] = [] for nativ_zone in native_zone_list: was_zone_found = False for normalized_zone in normalized_config_adom['zone_objects'] + normalized_config_global['zone_objects']: diff --git a/roles/importer/files/importer/fortiadom5ff/fwcommon.py b/roles/importer/files/importer/fortiadom5ff/fwcommon.py index 835556b01d..2469028fd9 100644 --- a/roles/importer/files/importer/fortiadom5ff/fwcommon.py +++ b/roles/importer/files/importer/fortiadom5ff/fwcommon.py @@ -6,11 +6,9 @@ from fwo_base import write_native_config_to_file import fmgr_getter from fwo_log import getFwoLogger -from fmgr_gw_networking import getInterfacesAndRouting, normalize_network_data from model_controllers.fwconfigmanagerlist_controller import FwConfigManagerListController from model_controllers.fwconfig_normalized_controller import FwConfigNormalizedController from models.fwconfigmanager import FwConfigManager -from model_controllers.management_controller import ManagementController from fmgr_network import normalize_network_objects from fmgr_service import normalize_service_objects from fmgr_rule import normalize_rulebases, get_access_policy, get_nat_policy @@ -18,17 +16,18 @@ from fwo_base import ConfigAction from fmgr_zone import get_zones, normalize_zones from models.fwconfig_normalized import FwConfigNormalized +from models.management import Management -def has_config_changed(full_config, mgm_details, force=False): +def has_config_changed(full_config: dict[str, Any], mgm_details: Management, force: bool = False): # dummy - may be filled with real check later on return True def get_config(config_in: FwConfigManagerListController, importState: ImportStateController): logger = getFwoLogger() - if config_in.has_empty_config(): # no native config was passed in, so getting it from FW-Manager - config_in.native_config.update({'domains': []}) + if config_in.has_empty_config(): # no native config was passed in, so getting it from FW-Manager + config_in.native_config.update({'domains': []}) # type: ignore #TYPING: What is this? None or not None this is the question parsing_config_only = False else: parsing_config_only = True @@ -38,7 +37,7 @@ def get_config(config_in: FwConfigManagerListController, importState: ImportStat limit = importState.FwoConfig.ApiFetchSize fm_api_url = importState.MgmDetails.buildFwApiString() native_config_global = initialize_native_config_domain(importState.MgmDetails) - config_in.native_config['domains'].append(native_config_global) + config_in.native_config['domains'].append(native_config_global) # type: ignore #TYPING: None or not None this is the question adom_list = build_adom_list(importState) adom_device_vdom_structure = build_adom_device_vdom_structure(adom_list, sid, fm_api_url) # delete_v: das geht schief für unschöne adoms @@ -52,7 +51,7 @@ def get_config(config_in: FwConfigManagerListController, importState: ImportStat for adom in adom_list: adom_name = adom.DomainName native_config_adom = initialize_native_config_domain(adom) - config_in.native_config['domains'].append(native_config_adom) + config_in.native_config['domains'].append(native_config_adom) # type: ignore #TYPING: None or not None this is the question adom_scope = 'adom/'+adom_name get_objects(sid, fm_api_url, native_config_adom, native_config_global, adom_name, limit, nw_obj_types, svc_obj_types, adom_scope, arbitrary_vdom_for_updateable_objects) @@ -79,12 +78,12 @@ def get_config(config_in: FwConfigManagerListController, importState: ImportStat write_native_config_to_file(importState, config_in.native_config) - normalized_managers = normalize_config(importState, config_in.native_config) + normalized_managers = normalize_config(importState, config_in.native_config) # type: ignore #TYPING: None or not None this is the question logger.info("completed getting config") return 0, normalized_managers -def initialize_native_config_domain(mgm_details : ManagementController): +def initialize_native_config_domain(mgm_details: Management) -> dict[str, Any]: return { 'domain_name': mgm_details.DomainName, 'domain_uid': mgm_details.DomainUid, @@ -97,14 +96,14 @@ def initialize_native_config_domain(mgm_details : ManagementController): 'zones': [], 'gateways': []} -def get_arbitrary_vdom(adom_device_vdom_structure): +def get_arbitrary_vdom(adom_device_vdom_structure: dict[str, dict[str, dict[str, Any]]]) -> dict[str, str] | None: for adom in adom_device_vdom_structure: for device in adom_device_vdom_structure[adom]: for vdom in adom_device_vdom_structure[adom][device]: return {'adom': adom, 'device': device, 'vdom': vdom} -def normalize_config(import_state, native_config: 'dict[str,Any]') -> FwConfigManagerListController: +def normalize_config(import_state: ImportStateController, native_config: dict[str,Any]) -> FwConfigManagerListController: manager_list = FwConfigManagerListController() @@ -113,7 +112,7 @@ def normalize_config(import_state, native_config: 'dict[str,Any]') -> FwConfigMa rewrite_native_config_obj_type_as_key(native_config) # for easier accessability of objects in normalization process - native_config_global = {} + native_config_global: dict[str, Any] = {} normalized_config_global = {} for native_conf in native_config['domains']: @@ -152,14 +151,14 @@ def normalize_config(import_state, native_config: 'dict[str,Any]') -> FwConfigMa return manager_list -def rewrite_native_config_obj_type_as_key(native_config): +def rewrite_native_config_obj_type_as_key(native_config: dict[str, Any]): # rewrite native config objects to have the object type as key # this is needed for the normalization process for domain in native_config['domains']: if 'objects' not in domain: continue - obj_dict = {} + obj_dict: dict[str, Any] = {} for obj_chunk in domain['objects']: if 'type' not in obj_chunk: continue @@ -168,8 +167,8 @@ def rewrite_native_config_obj_type_as_key(native_config): domain['objects'] = obj_dict -def normalize_single_manager_config(native_config: 'dict[str, Any]', native_config_global: 'dict[str, Any]', normalized_config_adom: dict, - normalized_config_global: dict, import_state: ImportStateController, +def normalize_single_manager_config(native_config: 'dict[str, Any]', native_config_global: 'dict[str, Any]', normalized_config_adom: dict[str, Any], + normalized_config_global: dict[str, Any], import_state: ImportStateController, is_global_loop_iteration: bool): current_nw_obj_types = deepcopy(nw_obj_types) @@ -198,15 +197,15 @@ def normalize_single_manager_config(native_config: 'dict[str, Any]', native_conf normalize_gateways(native_config, normalized_config_adom) -def build_adom_list(importState : ImportStateController): - adom_list = [] +def build_adom_list(importState : ImportStateController) -> list[Management]: + adom_list: list[Management] = [] if importState.MgmDetails.IsSuperManager: for subManager in importState.MgmDetails.SubManagers: adom_list.append(deepcopy(subManager)) return adom_list -def build_adom_device_vdom_structure(adom_list, sid, fm_api_url) -> dict: - adom_device_vdom_structure = {} +def build_adom_device_vdom_structure(adom_list: list[Management], sid: str, fm_api_url: str) -> dict[str, dict[str, dict[str, Any]]]: + adom_device_vdom_structure: dict[str, dict[str, dict[str, Any]]] = {} for adom in adom_list: adom_device_vdom_structure.update({adom.DomainName: {}}) if len(adom.Devices) > 0: @@ -214,7 +213,7 @@ def build_adom_device_vdom_structure(adom_list, sid, fm_api_url) -> dict: adom_device_vdom_structure[adom.DomainName].update(device_vdom_dict) return adom_device_vdom_structure -def add_policy_package_to_vdoms(adom_device_vdom_structure, sid, fm_api_url): +def add_policy_package_to_vdoms(adom_device_vdom_structure: dict[str, dict[str, dict[str, str]]], sid: str, fm_api_url: str) -> dict[str, dict[str, dict[str, Any]]]: adom_device_vdom_policy_package_structure = deepcopy(adom_device_vdom_structure) for adom in adom_device_vdom_policy_package_structure: policy_packages_result = fmgr_getter.fortinet_api_call(sid, fm_api_url, '/pm/pkg/adom/' + adom) @@ -224,7 +223,7 @@ def add_policy_package_to_vdoms(adom_device_vdom_structure, sid, fm_api_url): add_global_policy_package_to_vdom(adom_device_vdom_policy_package_structure, sid, fm_api_url, adom) return adom_device_vdom_policy_package_structure -def parse_policy_package(policy_package, adom_device_vdom_policy_package_structure, adom): +def parse_policy_package(policy_package: dict[str, Any], adom_device_vdom_policy_package_structure: dict[str, dict[str, dict[str, Any]]], adom: str): for scope_member in policy_package['scope member']: for device in adom_device_vdom_policy_package_structure[adom]: if device == scope_member['name']: @@ -232,7 +231,7 @@ def parse_policy_package(policy_package, adom_device_vdom_policy_package_structu if vdom == scope_member['vdom']: adom_device_vdom_policy_package_structure[adom][device].update({vdom: {'local': policy_package['name'], 'global': ''}}) -def add_global_policy_package_to_vdom(adom_device_vdom_policy_package_structure, sid, fm_api_url, adom): +def add_global_policy_package_to_vdom(adom_device_vdom_policy_package_structure: dict[str, dict[str, dict[str, Any]]], sid: str, fm_api_url: str, adom: str): global_assignment_result = fmgr_getter.fortinet_api_call(sid, fm_api_url, '/pm/config/adom/' + adom + '/_adom/options') for global_assignment in global_assignment_result: if global_assignment['assign_excluded'] == 0 and global_assignment['specify_assign_pkg_list'] == 0: @@ -244,22 +243,22 @@ def add_global_policy_package_to_vdom(adom_device_vdom_policy_package_structure, else: raise ImportInterruption('Broken global assign format.') -def assign_case_all(adom_device_vdom_policy_package_structure, adom, global_assignment): +def assign_case_all(adom_device_vdom_policy_package_structure: dict[str, dict[str, dict[str, Any]]], adom: str, global_assignment: dict[str, Any]): for device in adom_device_vdom_policy_package_structure[adom]: for vdom in adom_device_vdom_policy_package_structure[adom][device]: adom_device_vdom_policy_package_structure[adom][device][vdom]['global'] = global_assignment['assign_name'] -def assign_case_include(adom_device_vdom_policy_package_structure, adom, global_assignment): +def assign_case_include(adom_device_vdom_policy_package_structure: dict[str, dict[str, dict[str, Any]]], adom: str, global_assignment: dict[str, Any]): for device in adom_device_vdom_policy_package_structure[adom]: for vdom in adom_device_vdom_policy_package_structure[adom][device]: match_assign_and_vdom_policy_package(global_assignment, adom_device_vdom_policy_package_structure[adom][device][vdom], True) -def assign_case_exclude(adom_device_vdom_policy_package_structure, adom, global_assignment): +def assign_case_exclude(adom_device_vdom_policy_package_structure: dict[str, dict[str, dict[str, Any]]], adom: str, global_assignment: dict[str, Any]): for device in adom_device_vdom_policy_package_structure[adom]: for vdom in adom_device_vdom_policy_package_structure[adom][device]: match_assign_and_vdom_policy_package(global_assignment, adom_device_vdom_policy_package_structure[adom][device][vdom], False) -def match_assign_and_vdom_policy_package(global_assignment, vdom_structure, is_include): +def match_assign_and_vdom_policy_package(global_assignment: dict[str, Any], vdom_structure: dict[str, Any], is_include: bool): for package in global_assignment['pkg list']: if is_include: if package['name'] == vdom_structure['local']: @@ -268,8 +267,8 @@ def match_assign_and_vdom_policy_package(global_assignment, vdom_structure, is_i if package['name'] != vdom_structure['local']: vdom_structure['global'] = global_assignment['assign_name'] -def initialize_device_config(mgm_details_device): - device_config = {'name': mgm_details_device['name'], +def initialize_device_config(mgm_details_device: dict[str, Any]) -> dict[str, Any]: + device_config: dict[str, Any] = {'name': mgm_details_device['name'], 'uid': mgm_details_device['uid'], 'rulebase_links': []} return device_config @@ -284,7 +283,7 @@ def get_sid(importState: ImportStateController): return sid -def get_objects(sid, fm_api_url, native_config_domain, native_config_global, adom_name, limit, nw_obj_types, svc_obj_types, adom_scope, arbitrary_vdom_for_updateable_objects): +def get_objects(sid: str, fm_api_url: str, native_config_domain: dict[str, Any], native_config_global: dict[str, Any], adom_name: str, limit: int, nw_obj_types: list[str], svc_obj_types: list[str], adom_scope: str, arbitrary_vdom_for_updateable_objects: dict[str, Any] | None): # get those objects that exist globally and on adom level # get network objects: @@ -312,7 +311,7 @@ def get_objects(sid, fm_api_url, native_config_domain, native_config_global, ado return if arbitrary_vdom_for_updateable_objects['adom'] == adom_name: # get dynamic objects - payload = { + payload: dict[str, Any] = { 'params': [ { 'data': { @@ -329,7 +328,7 @@ def get_objects(sid, fm_api_url, native_config_domain, native_config_global, ado native_config_global['objects'], sid, fm_api_url, "sys/proxy/json", "nw_obj_global_firewall/internet-service-basic", limit=limit, payload=payload, method='exec') -def normalize_gateways(native_config, normalized_config_adom): +def normalize_gateways(native_config: dict[str, Any], normalized_config_adom: dict[str, Any]): for gateway in native_config['gateways']: normalized_gateway = {} normalized_gateway['Uid'] = gateway['uid'] @@ -339,15 +338,15 @@ def normalize_gateways(native_config, normalized_config_adom): normalized_gateway['RulebaseLinks'] = normalize_links(gateway['rulebase_links']) normalized_config_adom['gateways'].append(normalized_gateway) -def normalize_interfaces(): +def normalize_interfaces() -> list[Any]: # TODO return [] -def normalize_routing(): +def normalize_routing() -> list[Any]: # TODO return [] -def normalize_links(rulebase_links : list): +def normalize_links(rulebase_links : list[dict[str, Any]]) -> list[dict[str, Any]]: for link in rulebase_links: link['link_type'] = link.pop('type') diff --git a/roles/importer/files/importer/fwconfig_base.py b/roles/importer/files/importer/fwconfig_base.py index 70d586a5bc..35431864d1 100644 --- a/roles/importer/files/importer/fwconfig_base.py +++ b/roles/importer/files/importer/fwconfig_base.py @@ -4,7 +4,7 @@ class FwoEncoder(json.JSONEncoder): - def default(self, obj): + def default(self, obj: object) -> object: # type: ignore if isinstance(obj, ConfigAction) or isinstance(obj, ConfFormat): return obj.name @@ -12,7 +12,7 @@ def default(self, obj): return json.JSONEncoder.default(self, obj) -def replaceNoneWithEmpty(s): +def replaceNoneWithEmpty(s: str | None) -> str: if s is None or s == '': return '' else: diff --git a/roles/importer/files/importer/fwo_alert.py b/roles/importer/files/importer/fwo_alert.py index eccb1b29e2..e8402bfeb5 100644 --- a/roles/importer/files/importer/fwo_alert.py +++ b/roles/importer/files/importer/fwo_alert.py @@ -2,8 +2,8 @@ import json import fwo_const import fwo_log - -def getFwoAlerter(): +# TODO delete this file +def getFwoAlerter() -> dict[str, str]: logger = fwo_log.getFwoLogger() try: with open(fwo_const.fwo_config_filename, "r") as fwo_config: @@ -21,7 +21,7 @@ def getFwoAlerter(): logger.error("getFwoAlerter - error while reading importer pwd file") raise - jwt = fwo_api_call.login(fwo_const.importer_user_name, importer_pwd, user_management_api_base_url) + jwt = fwo_api_call.login(fwo_const.importer_user_name, importer_pwd, user_management_api_base_url) # type: ignore return { "fwo_api_base_url": fwo_api_base_url, "jwt": jwt } diff --git a/roles/importer/files/importer/fwo_api.py b/roles/importer/files/importer/fwo_api.py index f04c8324fc..c999204704 100644 --- a/roles/importer/files/importer/fwo_api.py +++ b/roles/importer/files/importer/fwo_api.py @@ -1,11 +1,10 @@ -import requests.packages import requests import json import traceback import time from pprint import pformat import string -from typing import Any +from typing import Any, MutableMapping import fwo_globals from fwo_log import getFwoLogger @@ -23,18 +22,18 @@ class FwoApi(): FwoApiUrl: str FwoJwt: str - query_info: dict + query_info: dict[str, Any] query_analyzer: QueryAnalyzer - def __init__(self, ApiUri, Jwt): + def __init__(self, ApiUri: str, Jwt: str): self.FwoApiUrl = ApiUri self.FwoJwt = Jwt self.query_info = {} self.query_analyzer = QueryAnalyzer() - def call(self, query, query_variables={}, debug_level=0, analyze_payload=False) -> dict: + def call(self, query: str, query_variables: dict[str, list[Any] | Any] = {}, debug_level: int = 0, analyze_payload: bool = False) -> dict[str, Any]: """ The standard FWO API call. """ @@ -45,7 +44,7 @@ def call(self, query, query_variables={}, debug_level=0, analyze_payload=False) 'Authorization': f'Bearer {self.FwoJwt}', 'x-hasura-role': role } - full_query = {"query": query, "variables": query_variables} + full_query: dict[str, Any] = {"query": query, "variables": query_variables} logger = getFwoLogger(debug_level=debug_level) return_object = {} @@ -62,7 +61,7 @@ def call(self, query, query_variables={}, debug_level=0, analyze_payload=False) if analyze_payload and self.query_info["chunking_info"]["needs_chunking"]: started = time.time() - return_object = self._call_chunked(session, query, query_variables, fwo_globals.debug_level) + return_object: dict[str, Any] = self._call_chunked(session, query, query_variables, fwo_globals.debug_level) elapsed_time = time.time() - started affected_rows = 0 if 'data' in return_object.keys() and 'affected_rows' in return_object['data'].keys(): @@ -71,7 +70,7 @@ def call(self, query, query_variables={}, debug_level=0, analyze_payload=False) logger.debug(f"Chunked API call ({self.query_info['query_name']}) processed in {elapsed_time:.4f} s. Affected rows: {affected_rows}.") self.query_info = {} else: - return_object = self._post_query(session, full_query) + return_object: dict[str, Any] = self._post_query(session, full_query) self._try_show_api_call_info(full_query, request_headers, fwo_globals.debug_level) @@ -93,10 +92,14 @@ def call(self, query, query_variables={}, debug_level=0, analyze_payload=False) logger.error(f"Unexpected error during API call: {str(e)}") raise FwoImporterError(f"return_object not defined. Error during API call: {str(e)}") raise FwoImporterError(f"Unexpected error during API call: {str(e)}") + return return_object @staticmethod - def login(user, password, user_management_api_base_url, method='api/AuthenticationToken/Get'): - payload = {"Username": user, "Password": password} + def login(user: str, password: str | None, user_management_api_base_url: str | None, method: str = 'api/AuthenticationToken/Get'): + payload: dict[str, str | None] = {"Username": user, "Password": password} + + if user_management_api_base_url is None: + raise FwoApiLoginFailed("fwo_api: user_management_api_base_url is None during login") with requests.Session() as session: if fwo_globals.verify_certs is None: # only for first FWO API call (getting info on cert verification) @@ -110,7 +113,7 @@ def login(user, password, user_management_api_base_url, method='api/Authenticati except requests.exceptions.RequestException: raise FwoApiLoginFailed ("fwo_api: error during login to url: " + str(user_management_api_base_url) + " with user " + user) from None - if response.text is not None and response.status_code==200: + if response.status_code==200: return response.text else: error_txt = "fwo_api: ERROR: did not receive a JWT during login" + \ @@ -188,7 +191,7 @@ def call_endpoint(self, method: str, endpoint: str, params: Any = None) -> Any: logger.error(f"Middleware API request failed: {str(e)}") raise FwoImporterError(f"Middleware API request failed: {str(e)}") - def _handle_request_exception(self, exception, query_payload, headers): + def _handle_request_exception(self, exception: requests.exceptions.RequestException, query_payload: dict[str, Any], headers: dict[str, Any]) -> None: """ Error handling for the standard API call. """ @@ -205,7 +208,7 @@ def _handle_request_exception(self, exception, query_payload, headers): raise exception - def _call_chunked(self, session, query, query_variables: dict = {}, debug_level=0): + def _call_chunked(self, session: requests.Session, query: str, query_variables: dict[str, list[Any]] = {}, debug_level: int = 0) -> dict[str, Any]: """ Splits a defined query variable into chunks and posts the queries chunk by chunk. """ @@ -260,8 +263,8 @@ def _call_chunked(self, session, query, query_variables: dict = {}, debug_level= return return_object - def _update_query_variables_by_chunk(self, query_variables, chunkable_variables): - chunks = {} + def _update_query_variables_by_chunk(self, query_variables: dict[str, list[Any]], chunkable_variables: dict[str, list[Any]]) -> int: + chunks: dict[str, Any] = {} total_chunk_elements = 0 for variable, list_object in chunkable_variables.items(): @@ -275,7 +278,7 @@ def _update_query_variables_by_chunk(self, query_variables, chunkable_variables) return total_chunk_elements - def _handle_chunked_calls_response(self, return_object, response): + def _handle_chunked_calls_response(self, return_object: dict[str, Any], response: dict[str, Any]) -> dict[str, Any]: logger = getFwoLogger(debug_level=int(fwo_globals.debug_level)) if return_object == {}: @@ -304,10 +307,10 @@ def _handle_chunked_calls_response(self, return_object, response): return return_object - def _handle_chunked_calls_response_with_return_data(self, return_object, new_return_object_type, new_return_object): + def _handle_chunked_calls_response_with_return_data(self, return_object: dict[str, Any], new_return_object_type: str, new_return_object: dict[str, Any] | list[Any]) -> None: total_affected_rows = 0 - returning_data = [] + returning_data: list[dict[str, Any]] = [] self._try_write_extended_log(debug_level=9, message=f"Handling chunked calls response for type '{new_return_object_type}' with data: {pformat(new_return_object)}") @@ -336,7 +339,7 @@ def _handle_chunked_calls_response_with_return_data(self, return_object, new_ret return_object["data"][new_return_object_type]["returning"].extend(returning_data) - def _post_query(self, session, query_payload): + def _post_query(self, session: requests.Session, query_payload: dict[str, Any]) -> dict[str, Any]: """ Posts the given payload to the api endpoint. Returns the response as json or None if the response object is None. """ @@ -356,7 +359,7 @@ def _post_query(self, session, query_payload): return r.json() - def show_api_call_info(self, url, query, headers, type='debug'): + def show_api_call_info(self, url: str, query: dict[str, Any], headers: dict[str, Any], type: str='debug'): max_query_size_to_display = 1000 query_string = json.dumps(query, indent=2) header_string = json.dumps(headers, indent=2) @@ -375,7 +378,7 @@ def show_api_call_info(self, url, query, headers, type='debug'): result += "\n and headers: \n" + header_string return result - def _try_show_api_call_info(self, full_query, request_headers, debug_level): + def _try_show_api_call_info(self, full_query: dict[str, Any], request_headers: dict[str, Any], debug_level: int) -> None: """ Tries to show the API call info if the debug level is high enough. """ @@ -384,7 +387,7 @@ def _try_show_api_call_info(self, full_query, request_headers, debug_level): logger.debug(self.showImportApiCallInfo(self.FwoApiUrl, full_query, request_headers, typ='debug', show_query_info=True)) - def _try_write_extended_log(self, debug_level, message): + def _try_write_extended_log(self, debug_level: int, message: str) -> None: """ Writes an extended log message if the debug level is high enough. """ @@ -393,7 +396,7 @@ def _try_write_extended_log(self, debug_level, message): logger.debug(message) - def showImportApiCallInfo(self, api_url, query, headers, typ='debug', show_query_info=False): + def showImportApiCallInfo(self, api_url: str, query: dict[str, Any], headers: dict[str, Any] | MutableMapping[str, str | bytes], typ: str ='debug', show_query_info: bool = False): max_query_size_to_display = 1000 query_string = json.dumps(query, indent=2) header_string = json.dumps(dict(headers), indent=2) diff --git a/roles/importer/files/importer/fwo_api_call.py b/roles/importer/files/importer/fwo_api_call.py index 293ba4a834..4a55a84e09 100644 --- a/roles/importer/files/importer/fwo_api_call.py +++ b/roles/importer/files/importer/fwo_api_call.py @@ -3,6 +3,7 @@ import json import datetime import time +from typing import TYPE_CHECKING, Any import fwo_const import fwo_globals @@ -10,8 +11,12 @@ from fwo_api import FwoApi from fwo_exceptions import FwoApiFailedLockImport from query_analyzer import QueryAnalyzer -from model_controllers.import_statistics_controller import ImportStatisticsController -from models.management import Management +from model_controllers.management_controller import ManagementController +from models.fwconfig_normalized import FwConfigNormalized + +if TYPE_CHECKING: + from model_controllers.import_state_controller import ImportStateController + from model_controllers.import_statistics_controller import ImportStatisticsController # NOTE: we cannot import ImportState(Controller) here due to circular refs @@ -24,14 +29,17 @@ def __init__(self, api: FwoApi): self.query_analyzer = QueryAnalyzer() - def get_mgm_ids(self, query_variables): + def get_mgm_ids(self, query_variables: dict[str, list[Any]]) -> list[dict[str, Any]]: # TODO: confirm return type # from 9.0 do not import sub-managers separately mgm_query = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "device/getManagementWithSubs.graphql"]) - return self.call(mgm_query, query_variables=query_variables)['data']['management'] + result = self.call(mgm_query, query_variables=query_variables) + if 'data' in result and 'management' in result['data']: + return result['data']['management'] + return [] - def get_config_value(self, key='limit') -> str|None: - query_variables = {'key': key} + def get_config_value(self, key: str='limit') -> str|None: + query_variables: dict[str, str] = {'key': key} cfg_query = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "config/getConfigValue.graphql"]) try: @@ -51,8 +59,8 @@ def get_config_value(self, key='limit') -> str|None: return None - def get_config_values(self, keyFilter='limit'): - query_variables = {'keyFilter': keyFilter+"%"} + def get_config_values(self, keyFilter:str='limit') -> dict[str, str]|None: + query_variables: dict[str, str] = {'keyFilter': keyFilter+"%"} config_query = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "config/getConfigValuesByKeyFilter.graphql"]) try: @@ -64,21 +72,21 @@ def get_config_values(self, keyFilter='limit'): if 'data' in result and 'config' in result['data']: resultArray = result['data']['config'] - dict1 = {v['config_key']: v['config_value'] for k,v in enumerate(resultArray)} + dict1 = {v['config_key']: v['config_value'] for _,v in enumerate(resultArray)} return dict1 else: return None # this mgm field is used by mw dailycheck scheduler - def log_import_attempt(self, mgm_id, successful=False): + def log_import_attempt(self, mgm_id: int, successful: bool = False): now = datetime.datetime.now().isoformat() - query_variables = { "mgmId": mgm_id, "timeStamp": now, "success": successful } + query_variables: dict[str, Any] = { "mgmId": mgm_id, "timeStamp": now, "success": successful } mgm_mutation = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "import/updateManagementLastImportAttempt.graphql"]) return self.call(mgm_mutation, query_variables=query_variables) - def setImportLock(self, mgm_details: Management, is_full_import: int = False, is_initial_import: int = False, debug_level: int = 0) -> int: + def setImportLock(self, mgm_details: ManagementController, is_full_import: int = False, is_initial_import: int = False, debug_level: int = 0) -> int: logger = getFwoLogger(debug_level=debug_level) import_id = -1 mgm_id = mgm_details.Id @@ -94,14 +102,14 @@ def setImportLock(self, mgm_details: Management, is_full_import: int = False, is if import_id == -1: self.create_data_issue(mgm_id=int(mgm_id), severity=1, description="failed to get import lock for management id " + str(mgm_id)) - self.set_alert(import_id=import_id, title="import error", mgm_id=str(mgm_id), severity=1, \ + self.set_alert(import_id=import_id, title="import error", mgm_id=mgm_id, severity=1, \ description="fwo_api: failed to get import lock", source='import', alertCode=15, mgm_details=mgm_details) raise FwoApiFailedLockImport("fwo_api: failed to get import lock for management id " + str(mgm_id)) from None else: return import_id - def count_rule_changes_per_import(self, import_id): + def count_rule_changes_per_import(self, import_id: int): logger = getFwoLogger() change_count_query = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "import/getRuleChangesPerImport.graphql"]) try: @@ -113,7 +121,7 @@ def count_rule_changes_per_import(self, import_id): return rule_changes_in_import - def count_any_changes_per_import(self, import_id): + def count_any_changes_per_import(self, import_id: int): logger = getFwoLogger() change_count_query = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "import/getChangesPerImport.graphql"]) try: @@ -128,10 +136,10 @@ def count_any_changes_per_import(self, import_id): return changes_in_import - def unlock_import(self, import_id: int, mgm_id: int, import_stats: ImportStatisticsController) -> int: + def unlock_import(self, import_id: int, mgm_id: int, import_stats: 'ImportStatisticsController') -> int: logger = getFwoLogger() error_during_import_unlock = 0 - query_variables = {"stopTime": datetime.datetime.now().isoformat(), "importId": import_id, + query_variables: dict[str, Any] = {"stopTime": datetime.datetime.now().isoformat(), "importId": import_id, "success": import_stats.ErrorCount == 0, "anyChangesFound": import_stats.getTotalChangeNumber() > 0, "ruleChangesFound": import_stats.getRuleChangeNumber() > 0, "changeNumber": import_stats.getRuleChangeNumber()} @@ -141,21 +149,21 @@ def unlock_import(self, import_id: int, mgm_id: int, import_stats: ImportStatist unlock_result = self.call(unlock_mutation, query_variables=query_variables) if 'errors' in unlock_result: raise FwoApiFailedLockImport(unlock_result['errors']) - changes_in_import_control = unlock_result['data']['update_import_control']['affected_rows'] - except Exception as e: + _ = unlock_result['data']['update_import_control']['affected_rows'] + except Exception as _: logger.exception("failed to unlock import for management id " + str(mgm_id)) error_during_import_unlock = 1 return error_during_import_unlock # currently temporarily only working with single chunk - def import_json_config(self, importState, config, startImport=True): + def import_json_config(self, importState: 'ImportStateController', config: FwConfigNormalized, startImport: bool = True): logger = getFwoLogger(debug_level=importState.DebugLevel) import_mutation = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "import/addImportConfig.graphql"]) try: debug_mode = (fwo_globals.debug_level>0) - query_vars = { + query_vars: dict[str, Any] = { 'debug_mode': debug_mode, 'mgmId': importState.MgmDetails.Id, 'importId': importState.ImportId, @@ -180,7 +188,7 @@ def import_json_config(self, importState, config, startImport=True): return 1 - def delete_json_config_in_import_table(self, importState, query_variables): + def delete_json_config_in_import_table(self, importState: 'ImportStateController', query_variables: dict[str, Any]) -> int: logger = getFwoLogger(debug_level=importState.DebugLevel) delete_mutation = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "import/deleteImportConfig.graphql"]) try: @@ -192,20 +200,20 @@ def delete_json_config_in_import_table(self, importState, query_variables): return changes_in_delete_config - def get_error_string_from_imp_control(self, importState, query_variables): + def get_error_string_from_imp_control(self, _: 'ImportStateController', query_variables: dict[str, Any]) -> list[dict[str, Any]]: # TODO: confirm return type error_query = "query getErrors($importId:bigint) { import_control(where:{control_id:{_eq:$importId}}) { import_errors } }" return self.call(error_query, query_variables=query_variables)['data']['import_control'] - def create_data_issue(self, import_id=None, obj_name=None, mgm_id=None, dev_id=None, severity=1, - rule_uid=None, object_type=None, description=None, source='import'): + def create_data_issue(self, importId: int | None = None, obj_name: str | None = None, mgm_id: int | None = None, dev_id: int | None = None, severity: int = 1, + rule_uid: str | None = None, object_type: str | None = None, description: str | None = None, source: str = 'import') -> bool: logger = getFwoLogger() if obj_name=='all' or obj_name=='Original': return True # ignore resolve errors for enriched objects that are not in the native config create_data_issue_mutation = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "monitor/addLogEntry.graphql"]) - query_variables = {"source": source, "severity": severity } + query_variables: dict[str, Any] = {"source": source, "severity": severity } if dev_id is not None: query_variables.update({"devId": dev_id}) @@ -225,13 +233,12 @@ def create_data_issue(self, import_id=None, obj_name=None, mgm_id=None, dev_id=N changes = result['data']['insert_log_data_issue']['returning'] except Exception as e: logger.error(f"failed to create log_data_issue: {json.dumps(query_variables)}: {str(e)}") - raise - return False + raise # TODO: or return False? return len(changes)==1 - def set_alert(self, import_id=None, title=None, mgm_id=None, dev_id=None, severity=1, - jsonData=None, description=None, source='import', user_id=None, refAlert=None, alertCode=None, mgm_details = None): + def set_alert(self, import_id: int | None = None, title: str | None = None, mgm_id: int | None = None, dev_id: int | None = None, severity: int | None = 1, + jsonData: dict[str, Any] | None = None, description: str | None = None, source: str = 'import', user_id: int | None = None, refAlert: str | None = None, alertCode: int | None = None, mgm_details: ManagementController | None = None): logger = getFwoLogger() @@ -259,7 +266,7 @@ def set_alert(self, import_id=None, title=None, mgm_id=None, dev_id=None, severi if alertCode is None or mgm_id is not None: return True # Acknowledge older alert for same problem on same management - query_variables = { "mgmId": mgm_id, "alertCode": alertCode, "currentAlertId": newAlertId } + query_variables: dict[str, Any] = { "mgmId": mgm_id, "alertCode": alertCode, "currentAlertId": newAlertId } existingUnacknowledgedAlerts = self.call(getAlert_query, query_variables=query_variables) if 'data' not in existingUnacknowledgedAlerts or 'alert' not in existingUnacknowledgedAlerts['data']: return False @@ -267,13 +274,13 @@ def set_alert(self, import_id=None, title=None, mgm_id=None, dev_id=None, severi if 'alert_id' in alert: now = datetime.datetime.now().isoformat() query_variables = { "userId": 0, "alertId": alert['alert_id'], "ackTimeStamp": now } - updateResult = self.call(ackAlert_mutation, query_variables=query_variables) + _ = self.call(ackAlert_mutation, query_variables=query_variables) except Exception as e: logger.error(f"failed to create alert entry: {json.dumps(query_variables)}; exception: {str(e)}") raise return True - def _set_alert_build_query_vars(self, query_variables, dev_id, user_id, mgm_id, refAlert, title, description, alertCode): + def _set_alert_build_query_vars(self, query_variables: dict[str, Any], dev_id: int | None, user_id: int | None, mgm_id: int | None, refAlert: str | None, title: str | None, description: str | None, alertCode: int | None): if dev_id is not None: query_variables.update({"devId": dev_id}) if user_id is not None: @@ -290,7 +297,7 @@ def _set_alert_build_query_vars(self, query_variables, dev_id, user_id, mgm_id, query_variables.update({"alertCode": alertCode}) - def complete_import(self, importState: "ImportStateController"): + def complete_import(self, importState: 'ImportStateController'): logger = getFwoLogger(debug_level=importState.DebugLevel) if fwo_globals.shutdown_requested: @@ -322,7 +329,7 @@ def complete_import(self, importState: "ImportStateController"): if importState.Stats.getChangeDetails() != {} and importState.DebugLevel>3 and len(importState.getErrors()) == 0: import_result += ", change details: " + str(importState.Stats.getChangeDetails()) if importState.Stats.ErrorCount>0: - self.create_data_issue(import_id=importState.ImportId, severity=1, description=importState.getErrorString()) + self.create_data_issue(importId=importState.ImportId, severity=1, description=importState.getErrorString()) self.set_alert(import_id=importState.ImportId, title="import error", mgm_id=importState.MgmDetails.Id, severity=2, \ description=str(importState.getErrorString()), source='import', alertCode=14, mgm_details=importState.MgmDetails) if not importState.Stats.ErrorAlreadyLogged: @@ -330,7 +337,7 @@ def complete_import(self, importState: "ImportStateController"): importState.Stats.ErrorAlreadyLogged = True - def get_last_complete_import(self, query_vars, debug_level=0) -> tuple[int, str]: + def get_last_complete_import(self, query_vars: dict[str, Any], debug_level: int = 0) -> tuple[int, str]: mgm_query = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "import/getLastCompleteImport.graphql"]) lastFullImportDate: str = "" lastFullImportId: int = 0 @@ -339,7 +346,7 @@ def get_last_complete_import(self, query_vars, debug_level=0) -> tuple[int, str] if len(pastDetails['data']['import_control'])>0: lastFullImportDate = pastDetails['data']['import_control'][0]['start_time'] lastFullImportId = pastDetails['data']['import_control'][0]['control_id'] - except Exception as e: + except Exception as _: logger = getFwoLogger() logger.error(f"error while getting past import details for mgm {str(query_vars)}: {str(traceback.format_exc())}") raise diff --git a/roles/importer/files/importer/fwo_base.py b/roles/importer/files/importer/fwo_base.py index d042eefa24..d3958f770b 100644 --- a/roles/importer/files/importer/fwo_base.py +++ b/roles/importer/files/importer/fwo_base.py @@ -3,7 +3,7 @@ from copy import deepcopy import re from enum import Enum -from typing import Any, List, get_type_hints +from typing import TYPE_CHECKING, Any, get_type_hints import ipaddress import traceback import time @@ -16,6 +16,8 @@ from fwo_enums import ConfFormat, ConfigAction from fwo_log import getFwoLogger, getFwoAlertLogger from model_controllers.fwconfig_import_ruleorder import RuleOrderService +if TYPE_CHECKING: + from model_controllers.import_state_controller import ImportStateController from services.service_provider import ServiceProvider from services.global_state import GlobalState from services.enums import Services, Lifetime @@ -23,11 +25,11 @@ from services.group_flats_mapper import GroupFlatsMapper -def split_list(list_in, max_list_length): +def split_list(list_in: list[Any], max_list_length: int) -> list[list[Any]]: if len(list_in) str: if (content == None or content == '') and not no_csv_delimiter: # do not add apostrophes for empty fields field_result = csv_delimiter else: @@ -51,7 +53,7 @@ def csv_add_field(content, no_csv_delimiter=False): return field_result -def sanitize(content, lower: bool = False) -> None | str: +def sanitize(content: Any, lower: bool = False) -> None | str: if content is None: return None result = str(content) @@ -63,7 +65,7 @@ def sanitize(content, lower: bool = False) -> None | str: return result -def extend_string_list(list_string, src_dict, key, delimiter, jwt=None, import_id=None): +def extend_string_list(list_string: str | None, src_dict: dict[str, list[str]], key: str, delimiter: str, jwt: Any = None, import_id: Any = None) -> str: if list_string is None: list_string = '' if list_string == '': @@ -83,7 +85,7 @@ def extend_string_list(list_string, src_dict, key, delimiter, jwt=None, import_i return result -def jsonToLogFormat(jsonData): +def jsonToLogFormat(jsonData: dict[str, Any] | str) -> str: if type(jsonData) is dict: jsonString = json.dumps(jsonData) elif isinstance(jsonData, str): @@ -96,7 +98,7 @@ def jsonToLogFormat(jsonData): return jsonString -def writeAlertToLogFile(jsonData): +def writeAlertToLogFile(jsonData: dict[str, Any]) -> None: logger = getFwoAlertLogger() jsonDataCopy = deepcopy(jsonData) # make sure the original alert is not changed if type(jsonDataCopy) is dict and 'jsonData' in jsonDataCopy: @@ -106,7 +108,7 @@ def writeAlertToLogFile(jsonData): logger.info(alertText) -def set_ssl_verification(ssl_verification_mode): +def set_ssl_verification(ssl_verification_mode: str) -> bool | str: logger = getFwoLogger() if ssl_verification_mode == '' or ssl_verification_mode == 'off': ssl_verification = False @@ -119,16 +121,16 @@ def set_ssl_verification(ssl_verification_mode): return ssl_verification -def stringIsUri(s): - return re.match('http://.+', s) or re.match('https://.+', s) or re.match('file://.+', s) +def stringIsUri(s: str) -> re.Match[str] | None: # TODO: should return bool? + return re.match('http://.+', s) or re.match('https://.+', s) or re.match('file://.+', s) -def serializeDictToClass(data: dict, cls): +def serializeDictToClass(data: dict[str, Any], cls: Any) -> Any: # Unpack the dictionary into keyword arguments return cls(**data) -def serializeDictToClassRecursively(data: dict, cls: Any) -> Any: +def serializeDictToClassRecursively(data: dict[str, list[Any] | Any | Enum], cls: Any) -> Any: try: init_args = {} type_hints = get_type_hints(cls) @@ -146,19 +148,18 @@ def serializeDictToClassRecursively(data: dict, cls: Any) -> Any: inner_type = field_type.__args__[0] if isinstance(value, list): init_args[field] = [ - serializeDictToClassRecursively(item, inner_type) if isinstance(item, dict) else item - for item in value + serializeDictToClassRecursively(item, inner_type) if isinstance(item, dict) else item for item in value # type: ignore ] else: raise ValueError(f"Expected a list for field '{field}', but got {type(value).__name__}") # Handle dictionary (nested objects) elif isinstance(value, dict): - init_args[field] = serializeDictToClassRecursively(value, field_type) + init_args[field] = serializeDictToClassRecursively(value, field_type) # type: ignore # Handle Enum types elif isinstance(field_type, type) and issubclass(field_type, Enum): - init_args[field] = field_type[value] + init_args[field] = field_type[value] # type: ignore # Direct assignment for basic types else: @@ -167,12 +168,12 @@ def serializeDictToClassRecursively(data: dict, cls: Any) -> Any: # Create an instance of the class with the collected arguments return cls(**init_args) - except (TypeError, ValueError, KeyError) as e: + except (TypeError, ValueError, KeyError) as _: # If an error occurs, return the original dictionary as is return data -def oldSerializeDictToClassRecursively(data: dict, cls: Any) -> Any: +def oldSerializeDictToClassRecursively(data: dict[str, Any], cls: Any) -> Any: # Create an empty dictionary to store keyword arguments init_args = {} @@ -186,7 +187,7 @@ def oldSerializeDictToClassRecursively(data: dict, cls: Any) -> Any: # Handle list types inner_type = field_type.__args__[0] init_args[field] = [ - serializeDictToClassRecursively(item, inner_type) if isinstance(item, dict) else item + serializeDictToClassRecursively(item, inner_type) if isinstance(item, dict) else item # type: ignore for item in data[field] ] elif isinstance(data[field], dict): @@ -200,7 +201,7 @@ def oldSerializeDictToClassRecursively(data: dict, cls: Any) -> Any: return cls(**init_args) -def deserializeClassToDictRecursively(obj: Any, seen=None) -> Any: +def deserializeClassToDictRecursively(obj: Any, seen: set[int] | None = None) -> dict[str, Any] | list[Any] | Any | str | int | float | bool | None: if seen is None: seen = set() @@ -216,10 +217,10 @@ def deserializeClassToDictRecursively(obj: Any, seen=None) -> Any: if isinstance(obj, list): # If the object is a list, deserialize each item - return [deserializeClassToDictRecursively(item, seen) for item in obj] + return [deserializeClassToDictRecursively(item, seen) for item in obj] # type: ignore elif isinstance(obj, dict): # If the object is a dictionary, deserialize each key-value pair - return {key: deserializeClassToDictRecursively(value, seen) for key, value in obj.items()} + return {key: deserializeClassToDictRecursively(value, seen) for key, value in obj.items()} # type: ignore elif isinstance(obj, Enum): # If the object is an Enum, convert it to its value return obj.value @@ -235,10 +236,10 @@ def deserializeClassToDictRecursively(obj: Any, seen=None) -> Any: return obj -def cidrToRange(ip): +def cidrToRange(ip: str | None) -> list[str] | list[None]: # TODO: I have no idea what other than string it could be logger = getFwoLogger() - if isinstance(ip, str): + if isinstance(ip, str): # type: ignore # dealing with ranges: if '-' in ip: return '-'.split(ip) @@ -251,7 +252,7 @@ def cidrToRange(ip): net = ipaddress.IPv4Network(ip) elif ipVersion=='IPv6': net = ipaddress.IPv6Network(ip) - return [str(net.network_address), str(net.broadcast_address)] + return [str(net.network_address), str(net.broadcast_address)] # type: ignore return [ip] @@ -278,7 +279,7 @@ def validIPAddress(IP: str) -> str: return "Invalid" -def validate_ip_address(address): +def validate_ip_address(address: str) -> bool: try: # ipaddress.ip_address(address) ipaddress.ip_network(address) @@ -308,7 +309,7 @@ def lcs_dp(seq1: list[Any], seq2: list[Any]) -> tuple[list[list[int]], int]: return dp, dp[m][n] -def backtrack_lcs(seq1, seq2, dp) -> list[tuple[int, int]]: +def backtrack_lcs(seq1: list[Any], seq2: list[Any], dp: list[list[int]]) -> list[tuple[int, int]]: """ Backtracks the dynamic programming (DP) table to recover one longest common subsequence (LCS) (as a list of (i, j) index pairs). These index pairs indicate positions in seq1 and seq2 that match in the LCS. @@ -397,7 +398,7 @@ def compute_min_moves(source: list[Any], target: list[Any]) -> dict[str, Any]: } -def write_native_config_to_file(importState, configNative): +def write_native_config_to_file(importState: 'ImportStateController', configNative: dict[str, Any] | None) -> None: from fwo_const import import_tmp_path if importState.DebugLevel>6: logger = getFwoLogger(debug_level=importState.DebugLevel) @@ -425,10 +426,10 @@ def init_service_provider(): return service_provider -def find_all_diffs(a, b, strict=False, path="root"): - diffs = [] +def find_all_diffs(a: Any, b: Any, strict: bool = False, path: str = "root") -> list[str]: + diffs: list[str] = [] if isinstance(a, dict): - for k in a: + for k in a: # type: ignore if k not in b: diffs.append(f"Key '{k}' missing in second object at {path}") else: @@ -439,27 +440,25 @@ def find_all_diffs(a, b, strict=False, path="root"): if k not in a: diffs.append(f"Key '{k}' missing in first object at {path}") elif isinstance(a, list): - for i, (x, y) in enumerate(zip(a, b)): + for i, (x, y) in enumerate(zip(a, b)): # type: ignore res = find_all_diffs(x, y, strict, f"{path}[{i}]") if res: diffs.extend(res) - if len(a) != len(b): - diffs.append( - f"list length mismatch at {path}: {len(a)} != {len(b)}") + if len(a) != len(b): # type: ignore + diffs.append(f"list length mismatch at {path}: {len(a)} != {len(b)}") # type: ignore else: if a != b: - if not strict and (a is None or a == '') and (b is None - or b == ''): + if not strict and (a is None or a == '') and (b is None or b == ''): return diffs diffs.append(f"Value mismatch at {path}: {a} != {b}") return diffs -def sort_and_join(input_list: List[str]) -> str: +def sort_and_join(input_list: list[str]) -> str: """ Sorts the input list of strings and joins them using the standard list delimiter. """ return fwo_const.list_delimiter.join(sorted(input_list)) -def generate_hash_from_dict(input_dict: dict) -> str: +def generate_hash_from_dict(input_dict: dict[Any, Any]) -> str: """ Generates a consistent hash from a dictionary by serializing it with sorted keys. """ dict_string = json.dumps(input_dict, sort_keys=True) return hashlib.sha256(dict_string.encode('utf-8')).hexdigest() diff --git a/roles/importer/files/importer/fwo_config.py b/roles/importer/files/importer/fwo_config.py index 1a7306eb66..de0cd4b23e 100644 --- a/roles/importer/files/importer/fwo_config.py +++ b/roles/importer/files/importer/fwo_config.py @@ -3,7 +3,7 @@ import sys, json from fwo_const import importer_pwd_file -def readConfig(fwo_config_filename='/etc/fworch/fworch.json'): +def readConfig(fwo_config_filename: str = '/etc/fworch/fworch.json') -> dict[str, str | int | None]: logger = getFwoLogger() try: # read fwo config (API URLs) @@ -27,7 +27,7 @@ def readConfig(fwo_config_filename='/etc/fworch/fworch.json'): except Exception: logger.error("unspecified error occurred while trying to read config file: "+ fwo_config_filename) sys.exit(1) - config = { + config: dict[str, str | int | None] = { "fwo_major_version": fwo_major_version, "user_management_api_base_url": user_management_api_base_url, "fwo_api_base_url": fwo_api_base_url, diff --git a/roles/importer/files/importer/fwo_const.py b/roles/importer/files/importer/fwo_const.py index d2d657fdb4..ca675de460 100644 --- a/roles/importer/files/importer/fwo_const.py +++ b/roles/importer/files/importer/fwo_const.py @@ -1,4 +1,6 @@ -from urllib.parse import urlparse + +from typing import Any + base_dir = '/usr/local/fworch' importer_base_dir = base_dir + '/importer' @@ -39,7 +41,7 @@ api_call_chunk_size = 1000 rule_num_numeric_steps = 1024.0 -emptyNormalizedFwConfigJsonDict = { +emptyNormalizedFwConfigJsonDict: dict[str, list[Any]] = { 'network_objects': [], 'service_objects': [], 'user_objects': [], diff --git a/roles/importer/files/importer/fwo_encrypt.py b/roles/importer/files/importer/fwo_encrypt.py index dcca60de02..bbfd2d6792 100644 --- a/roles/importer/files/importer/fwo_encrypt.py +++ b/roles/importer/files/importer/fwo_encrypt.py @@ -3,12 +3,11 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives import padding -import traceback from fwo_log import getFwoLogger from fwo_const import mainKeyFile # can be used for decrypting text encrypted with C# (mw-server) -def decrypt_aes_ciphertext(base64_encrypted_text, passphrase): +def decrypt_aes_ciphertext(base64_encrypted_text: str, passphrase: str) -> str: encrypted_data = base64.b64decode(base64_encrypted_text) ivLength = 16 # IV length for AES is 16 bytes @@ -24,7 +23,7 @@ def decrypt_aes_ciphertext(base64_encrypted_text, passphrase): decrypted_data = decryptor.update(encrypted_data[ivLength:]) + decryptor.finalize() # Remove padding - unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() + unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() #TODO: Check if block_size is correct #type: ignore try: unpadded_data = unpadder.update(decrypted_data) + unpadder.finalize() return unpadded_data.decode('utf-8') # Assuming plaintext is UTF-8 encoded @@ -33,7 +32,7 @@ def decrypt_aes_ciphertext(base64_encrypted_text, passphrase): # wrapper for trying the different decryption methods -def decrypt(encrypted_data, passphrase): +def decrypt(encrypted_data: str, passphrase: str) -> str: logger = getFwoLogger() try: decrypted = decrypt_aes_ciphertext(encrypted_data, passphrase) @@ -43,7 +42,7 @@ def decrypt(encrypted_data, passphrase): return encrypted_data -def read_main_key(filePath=mainKeyFile): +def read_main_key(filePath: str = mainKeyFile) -> str: with open(filePath, "r") as keyfile: mainKey = keyfile.read().rstrip(' \n') return mainKey diff --git a/roles/importer/files/importer/fwo_enums.py b/roles/importer/files/importer/fwo_enums.py index 1707e9c523..e7a1ecd351 100644 --- a/roles/importer/files/importer/fwo_enums.py +++ b/roles/importer/files/importer/fwo_enums.py @@ -22,7 +22,7 @@ class ConfFormat(Enum): CISCOFIREPOWER_LEGACY = 'CISCOFIREPOWER_LEGACY' @staticmethod - def IsLegacyConfigFormat(confFormatString): + def IsLegacyConfigFormat(confFormatString: str) -> bool: return ConfFormat(confFormatString) in [ConfFormat.NORMALIZED_LEGACY, ConfFormat.CHECKPOINT_LEGACY, ConfFormat.CISCOFIREPOWER_LEGACY, ConfFormat.FORTINET_LEGACY, ConfFormat.PALOALTO_LEGACY] diff --git a/roles/importer/files/importer/fwo_exceptions.py b/roles/importer/files/importer/fwo_exceptions.py index 625f24ec0d..7767565ea9 100644 --- a/roles/importer/files/importer/fwo_exceptions.py +++ b/roles/importer/files/importer/fwo_exceptions.py @@ -4,164 +4,164 @@ class FwLoginFailed(Exception): """Raised when login to FW management failed""" - def __init__(self, message="Login to FW management failed"): + def __init__(self, message: str = "Login to FW management failed"): self.message = message super().__init__(self.message) class FwApiCallFailed(Exception): """Raised when FW management API call failed""" - def __init__(self, message="An API call to the FW management failed"): + def __init__(self, message: str = "An API call to the FW management failed"): self.message = message super().__init__(self.message) class FwLogoutFailed(Exception): """Raised when logout from FW management failed""" - def __init__(self, message="Logout from FW management failed"): + def __init__(self, message: str = "Logout from FW management failed"): self.message = message super().__init__(self.message) class FwoNativeConfigFetchError(Exception): """Raised when getting native config from FW management fails, no rollback necessary""" - def __init__(self, message="Login to FW management failed"): + def __init__(self, message: str = "Login to FW management failed"): self.message = message super().__init__(self.message) class FwoNormalizedConfigParseError(Exception): """Raised while parsing normalized config""" - def __init__(self, message="Parsing normalized config failed"): + def __init__(self, message: str = "Parsing normalized config failed"): self.message = message super().__init__(self.message) class SecretDecryptionFailed(Exception): """Raised when the attempt to decrypt a secret with the given key fails""" - def __init__(self, message="Could not decrypt an API secret with given key"): + def __init__(self, message: str = "Could not decrypt an API secret with given key"): self.message = message super().__init__(self.message) class FwoApiLoginFailed(Exception): """Raised when login to FWO API fails""" - def __init__(self, message="Login to FWO API failed"): + def __init__(self, message: str = "Login to FWO API failed"): self.message = message super().__init__(self.message) class FwoApiFailedLockImport(Exception): """Raised when unable to lock import (import running?)""" - def __init__(self, message="Locking import failed - already running?"): + def __init__(self, message: str = "Locking import failed - already running?"): self.message = message super().__init__(self.message) class FwoApiFailedUnLockImport(Exception): """Raised when unable to remove import lock""" - def __init__(self, message="Unlocking import failed"): + def __init__(self, message: str = "Unlocking import failed"): self.message = message super().__init__(self.message) class FwoApiWriteError(Exception): """Raised when an FWO API mutation fails""" - def __init__(self, message="FWO API mutation failed"): + def __init__(self, message: str = "FWO API mutation failed"): self.message = message super().__init__(self.message) class FwoApiFailure(Exception): """Raised for any other FwoApi call exceptions""" - def __init__(self, message="There was an unclassified error while executing an FWO API call"): + def __init__(self, message: str = "There was an unclassified error while executing an FWO API call"): self.message = message super().__init__(self.message) class FwoApiTimeout(Exception): """Raised for 502 http error with proxy due to timeout""" - def __init__(self, message="reverse proxy timeout error during FWO API call - try increasing the reverse proxy timeout"): + def __init__(self, message: str = "reverse proxy timeout error during FWO API call - try increasing the reverse proxy timeout"): self.message = message super().__init__(self.message) class FwoApiServiceUnavailable(Exception): """Raised for 503 http error Serice unavailable""" - def __init__(self, message="FWO API Hasura container died"): + def __init__(self, message: str = "FWO API Hasura container died"): self.message = message super().__init__(self.message) class ConfigFileNotFound(Exception): """can only happen when specifying config file with -i switch""" - def __init__(self, message="Could not read config file"): + def __init__(self, message: str = "Could not read config file"): self.message = message super().__init__(self.message) class ImportRecursionLimitReached(Exception): """Raised when recursion of function inimport process reaches max allowed recursion limit""" - def __init__(self, message="Max recursion level reached - aborting"): + def __init__(self, message: str = "Max recursion level reached - aborting"): self.message = message super().__init__(self.message) class ImportInterruption(Exception): """Custom exception to signal an interrupted call requiring rollback.""" - def __init__(self, message=rollback_string): + def __init__(self, message: str = rollback_string): super().__init__(message) class FwoImporterError(Exception): """Custom exception to signal a failed import attempt.""" - def __init__(self, message=rollback_string): + def __init__(self, message: str = rollback_string): super().__init__(message) class FwoImporterErrorInconsistencies(Exception): """Custom exception to signal a failed import attempt.""" - def __init__(self, message=rollback_string): + def __init__(self, message: str = rollback_string): super().__init__(message) class RollbackNecessary(Exception): """Custom exception to signal a failed import attempt which needs a rollback.""" - def __init__(self, message="Rollback required."): + def __init__(self, message: str = "Rollback required."): super().__init__(message) class RollbackError(Exception): """Custom exception to signal a failed rollback attempt.""" - def __init__(self, message="Rollback failed."): + def __init__(self, message: str = "Rollback failed."): super().__init__(message) class FwApiError(Exception): """Custom exception to signal a failure during access checkpoint api.""" - def __init__(self, message="Error while trying to access firewall management API."): + def __init__(self, message: str = "Error while trying to access firewall management API."): super().__init__(message) class FwApiResponseDecodingError(Exception): """Custom exception to signal a failure during decoding checkpoint api response to JSON.""" - def __init__(self, message="Error while trying to decode firewall management API response into JSON."): + def __init__(self, message: str = "Error while trying to decode firewall management API response into JSON."): super().__init__(message) class FwoApiFailedDeleteOldImports(Exception): """Custom exception to signal a failure during deletion of old import data.""" - def __init__(self, message="Error while trying to remove old import data."): + def __init__(self, message: str = "Error while trying to remove old import data."): super().__init__(message) class FwoDuplicateKeyViolation(Exception): """Custom exception to signal a duplicate key violation during import.""" - def __init__(self, message="Error while trying to add data with duplicate keys"): + def __init__(self, message: str = "Error while trying to add data with duplicate keys"): super().__init__(message) class FwoUnknownDeviceForManager(Exception): """Custom exception to signal an unknown device during import.""" - def __init__(self, message="Could not find device in manager config"): + def __init__(self, message: str = "Could not find device in manager config"): super().__init__(message) class FwoDeviceWithoutLocalPackage(Exception): """Custom exception to signal a device without local package.""" - def __init__(self, message="Could not local package for device in manager config"): + def __init__(self, message: str = "Could not local package for device in manager config"): super().__init__(message) class ShutdownRequested(Exception): diff --git a/roles/importer/files/importer/fwo_file_import.py b/roles/importer/files/importer/fwo_file_import.py index a005fe0e72..66805151aa 100644 --- a/roles/importer/files/importer/fwo_file_import.py +++ b/roles/importer/files/importer/fwo_file_import.py @@ -1,16 +1,13 @@ """ read config from file and convert to non-legacy format (in case of legacy input) """ -from typing import Any, get_type_hints -from enum import Enum -import json, requests, requests.packages +import json, requests +from typing import Any from fwo_log import getFwoLogger import fwo_globals from fwo_exceptions import ConfigFileNotFound, FwoImporterError -from models.fwconfigmanagerlist import FwConfigManagerList from model_controllers.fwconfigmanagerlist_controller import FwConfigManagerListController -from models.fwconfig import FwConfig from fwconfig_base import ConfFormat import traceback @@ -99,7 +96,7 @@ def read_json_config_from_file(importState: ImportStateController) -> FwConfigMa ########### HELPERS ################## -def detect_legacy_format(configJson) -> ConfFormat: +def detect_legacy_format(configJson: dict[str, Any]) -> ConfFormat: result = ConfFormat.NORMALIZED_LEGACY @@ -111,9 +108,9 @@ def detect_legacy_format(configJson) -> ConfFormat: return result -def read_file(importState: ImportStateController) -> dict: +def read_file(importState: ImportStateController) -> dict[str, Any]: logger = getFwoLogger(debug_level=importState.DebugLevel) - configJson = {} + configJson: dict[str, Any] = {} if importState.ImportFileName=="": return configJson try: @@ -135,8 +132,8 @@ def read_file(importState: ImportStateController) -> dict: configJson = json.load(json_file) except requests.exceptions.RequestException: try: - r # check if response "r" is defined - importState.appendErrorString(f'got HTTP status code{str(r.status_code)} while trying to read config file from URL {importState.ImportFileName}') + r # check if response "r" is defined # type: ignore TODO: This practice is suspicious at best + importState.appendErrorString(f'got HTTP status code{str(r.status_code)} while trying to read config file from URL {importState.ImportFileName}') # type: ignore except NameError: importState.appendErrorString(f'got error while trying to read config file from URL {importState.ImportFileName}') importState.increaseErrorCounterByOne() diff --git a/roles/importer/files/importer/fwo_globals.py b/roles/importer/files/importer/fwo_globals.py index cbb22a4cc0..e5c6368da5 100644 --- a/roles/importer/files/importer/fwo_globals.py +++ b/roles/importer/files/importer/fwo_globals.py @@ -4,7 +4,7 @@ debug_level = 0 shutdown_requested = False -def set_global_values(verify_certs_in=None, suppress_cert_warnings_in=None, debug_level_in=0): +def set_global_values(verify_certs_in: bool | None = None, suppress_cert_warnings_in: bool | None = None, debug_level_in: int = 0): global verify_certs, suppress_cert_warnings, debug_level verify_certs = verify_certs_in suppress_cert_warnings = suppress_cert_warnings_in diff --git a/roles/importer/files/importer/fwo_log.py b/roles/importer/files/importer/fwo_log.py index 9cdc9f2899..f6a587c6d0 100644 --- a/roles/importer/files/importer/fwo_log.py +++ b/roles/importer/files/importer/fwo_log.py @@ -4,7 +4,9 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from services.uid2id_mapper import Uid2IdMapper + from importer.services.uid2id_mapper import Uid2IdMapper + from importer.model_controllers.import_state_controller import ImportStateController +from typing import Any, Literal class LogLock: @@ -53,7 +55,7 @@ def handle_log_lock(): stopwatch = -1 LogLock.semaphore.release() log_owned_by_external = False - except Exception as e: + except Exception as _: pass # Wait a second time.sleep(1) @@ -77,7 +79,7 @@ def handle_log_lock(): # LogLock.semaphore.release() -def getFwoLogger(debug_level=0): +def getFwoLogger(debug_level: int = 0): if int(debug_level) >= 1: log_level = logging.DEBUG else: @@ -103,9 +105,9 @@ def getFwoLogger(debug_level=0): return logger -def getFwoAlertLogger(debug_level=0): - debug_level=int(debug_level) - if debug_level>=1: +def getFwoAlertLogger(debug_level: int = 0): + debug_level = int(debug_level) # TODO: Check why str is passed sometimes or why the int cast is needed + if debug_level >= 1: llevel = logging.DEBUG else: llevel = logging.INFO @@ -139,10 +141,10 @@ class ChangeLogger: """ _instance = None - changed_nwobj_id_map: dict - changed_svc_id_map: dict + changed_nwobj_id_map: dict[int, int] + changed_svc_id_map: dict[int, int] _import_state = None - _uid2id_mapper: "Uid2IdMapper|None" = None + _uid2id_mapper: "Uid2IdMapper | None" = None def __new__(cls): """ @@ -157,15 +159,15 @@ def __new__(cls): return cls._instance - def create_change_id_maps(self, uid2id_mapper: "Uid2IdMapper", changed_nw_objs, changed_svcs, removedNwObjIds, removedNwSvcIds): - + def create_change_id_maps(self, uid2id_mapper: "Uid2IdMapper", changed_nw_objs: list[str], changed_svcs: list[str], removedNwObjIds: list[dict[str, Any]], removedNwSvcIds: list[dict[str, Any]]): + #TODO: removedNwObjUids? #TODO: removedNwObjUids? self._uid2id_mapper = uid2id_mapper self.changed_object_id_map = { next(removedNwObjId['obj_id'] for removedNwObjId in removedNwObjIds if removedNwObjId['obj_uid'] == old_item - ): self._uid2id_mapper.get_network_object_id(old_item) + ): self._uid2id_mapper.get_network_object_id(old_item) for old_item in changed_nw_objs } @@ -178,7 +180,7 @@ def create_change_id_maps(self, uid2id_mapper: "Uid2IdMapper", changed_nw_objs, } - def create_changelog_import_object(self, type, import_state, change_action, changeTyp, importTime, rule_id, rule_id_alternative = 0): + def create_changelog_import_object(self, type: str, import_state: "ImportStateController", change_action: str, changeTyp: Literal[2, 3], importTime: str, rule_id: int, rule_id_alternative: int = 0) -> dict[str, Any]: uniqueName = self._get_changelog_import_object_unique_name(rule_id) old_rule_id = None @@ -194,7 +196,7 @@ def create_changelog_import_object(self, type, import_state, change_action, chan if change_action == 'D': old_rule_id = rule_id - rule_changelog_object = { + rule_changelog_object: dict[str, Any] = { f"new_{type}_id": new_rule_id, f"old_{type}_id": old_rule_id, "control_id": self._import_state.ImportId, @@ -208,6 +210,6 @@ def create_changelog_import_object(self, type, import_state, change_action, chan return rule_changelog_object - def _get_changelog_import_object_unique_name(self, changelog_entity_id): + def _get_changelog_import_object_unique_name(self, changelog_entity_id: int) -> str: return str(changelog_entity_id) diff --git a/roles/importer/files/importer/fwo_signalling.py b/roles/importer/files/importer/fwo_signalling.py index 1e7944a442..82dcd3faca 100644 --- a/roles/importer/files/importer/fwo_signalling.py +++ b/roles/importer/files/importer/fwo_signalling.py @@ -1,8 +1,9 @@ import signal +from typing import Any import fwo_globals from fwo_exceptions import ShutdownRequested -def handle_shutdown_signal(signum, frame): +def handle_shutdown_signal(signum: int, frame: Any): fwo_globals.shutdown_requested = True raise ShutdownRequested diff --git a/roles/importer/files/importer/import-main-loop.py b/roles/importer/files/importer/import-main-loop.py index 8743f74431..2312444859 100755 --- a/roles/importer/files/importer/import-main-loop.py +++ b/roles/importer/files/importer/import-main-loop.py @@ -22,7 +22,7 @@ from services.enums import Services -def get_fwo_jwt(importUser, importPwd, userManagementApi) -> tuple [str, bool]: +def get_fwo_jwt(importUser: str, importPwd: str, userManagementApi: str) -> tuple [str, bool]: skipping = False try: jwt = FwoApi.login(importUser, importPwd, userManagementApi) @@ -136,7 +136,7 @@ def get_fwo_jwt(importUser, importPwd, userManagementApi) -> tuple [str, bool]: continue try: mgm_controller = ManagementController( - mgm_id=int(import_state.MgmDetails.Id), uid='', devices={}, + mgm_id=int(import_state.MgmDetails.Id), uid='', devices=[], device_info=DeviceInfo(), connection_info=ConnectionInfo(), importer_hostname='', diff --git a/roles/importer/files/importer/model_controllers/check_consistency.py b/roles/importer/files/importer/model_controllers/check_consistency.py index f6f6074500..728811fa4c 100644 --- a/roles/importer/files/importer/model_controllers/check_consistency.py +++ b/roles/importer/files/importer/model_controllers/check_consistency.py @@ -1,4 +1,5 @@ +from typing import Any import fwo_const from fwo_log import getFwoLogger from model_controllers.fwconfig_import import FwConfigImport @@ -6,11 +7,13 @@ from model_controllers.fwconfigmanagerlist_controller import FwConfigManagerListController from model_controllers.fwconfig_normalized_controller import FwConfigNormalizedController from model_controllers.fwconfigmanager_controller import FwConfigManager -from models.rulebase_link import RulebaseLink -from model_controllers.rulebase_link_controller import RulebaseLinkController from model_controllers.fwconfig_import_object import FwConfigImportObject from models.fwconfig_normalized import FwConfigNormalized from fwo_base import ConfFormat +from models.rulebase import Rulebase +from models.networkobject import NetworkObject +from models.gateway import Gateway +from models.rulebase_link import RulebaseLinkUidBased from services.service_provider import ServiceProvider from services.enums import Services from fwo_exceptions import FwoImporterErrorInconsistencies @@ -18,7 +21,7 @@ # this class is used for importing a config into the FWO API class FwConfigImportCheckConsistency(FwConfigImport): - issues: dict = {} + issues: dict[str, Any] = {} maps: FwConfigImportObject # = FwConfigImportObject() config: FwConfigNormalizedController = FwConfigNormalizedController(ConfFormat.NORMALIZED, FwConfigNormalized()) @@ -59,14 +62,14 @@ def checkConfigConsistency(self, config: FwConfigManagerListController): def checkNetworkObjectConsistency(self, config: FwConfigManagerListController): # check if all uid refs are valid - global_objects = set() + global_objects: set[str] = set() single_config: FwConfigNormalized # add all new obj refs from all rules for mgr in sorted(config.ManagerSet, key=lambda m: not getattr(m, 'IsSuperManager', False)): if mgr.IsSuperManager: global_objects = config.get_all_network_object_uids(mgr.ManagerUid) - all_used_obj_refs = [] + all_used_obj_refs: list[str] = [] for single_config in mgr.Configs: for rb in single_config.rulebases: all_used_obj_refs += self._collect_all_used_objects_from_rules(rb) @@ -74,9 +77,7 @@ def checkNetworkObjectConsistency(self, config: FwConfigManagerListController): all_used_obj_refs += self._collect_all_used_objects_from_groups(single_config) # now make list unique and get all refs not contained in network_objects - all_used_obj_refs = set(all_used_obj_refs) - - unresolvable_nw_obj_refs = all_used_obj_refs - config.get_all_network_object_uids(mgr.ManagerUid) - global_objects + unresolvable_nw_obj_refs = set(all_used_obj_refs) - config.get_all_network_object_uids(mgr.ManagerUid) - global_objects if len(unresolvable_nw_obj_refs)>0: self.issues.update({'unresolvableNwObRefs': list(unresolvable_nw_obj_refs)}) @@ -85,7 +86,7 @@ def checkNetworkObjectConsistency(self, config: FwConfigManagerListController): def _check_network_object_types_exist(self, mgr: FwConfigManager): - allUsedObjTypes: set = set() + allUsedObjTypes: set[str] = set() for single_config in mgr.Configs: for objId in single_config.network_objects: @@ -96,18 +97,19 @@ def _check_network_object_types_exist(self, mgr: FwConfigManager): @staticmethod - def _collect_all_used_objects_from_groups(single_config): - all_used_obj_refs = [] + def _collect_all_used_objects_from_groups(single_config: FwConfigNormalized) -> list[str]: + all_used_obj_refs: list[str] = [] # add all nw obj refs from groups for obj_id in single_config.network_objects: if single_config.network_objects[obj_id].obj_typ=='group': - if single_config.network_objects[obj_id].obj_member_refs is not None and len(single_config.network_objects[obj_id].obj_member_refs)>0: - all_used_obj_refs += single_config.network_objects[obj_id].obj_member_refs.split(fwo_const.list_delimiter) + obj_member_refs = single_config.network_objects[obj_id].obj_member_refs + if obj_member_refs is not None and len(obj_member_refs)>0: + all_used_obj_refs += obj_member_refs.split(fwo_const.list_delimiter) return all_used_obj_refs - def _collect_all_used_objects_from_rules(self, rb): - all_used_obj_refs = [] + def _collect_all_used_objects_from_rules(self, rb: Rulebase) -> list[str]: + all_used_obj_refs: list[str] = [] for rule_id in rb.rules: all_used_obj_refs += rb.rules[rule_id].rule_src_refs.split(fwo_const.list_delimiter) all_used_obj_refs += rb.rules[rule_id].rule_dst_refs.split(fwo_const.list_delimiter) @@ -115,9 +117,9 @@ def _collect_all_used_objects_from_rules(self, rb): return all_used_obj_refs - def _check_objects_with_missing_ips(self, single_config): + def _check_objects_with_missing_ips(self, single_config: FwConfigManager): # check if there are any objects with obj_typ<>group and empty ip addresses (breaking constraint) - nonGroupNwObjWithMissingIps = [] + nonGroupNwObjWithMissingIps: list[NetworkObject] = [] for conf in single_config.Configs: for objId in conf.network_objects: if conf.network_objects[objId].obj_typ!='group': @@ -131,7 +133,7 @@ def _check_objects_with_missing_ips(self, single_config): def checkServiceObjectConsistency(self, config: FwConfigManagerListController): # check if all uid refs are valid - global_objects = set() + global_objects: set[str] = set() for mgr in sorted(config.ManagerSet, key=lambda m: not getattr(m, 'IsSuperManager', False)): if len(mgr.Configs)==0: @@ -155,11 +157,10 @@ def checkServiceObjectConsistency(self, config: FwConfigManagerListController): def _check_service_object_types_exist(self, single_config: FwConfigNormalized): # check that all obj_typ exist - all_used_obj_types = set() + all_used_obj_types: set[str] = set() for obj_id in single_config.service_objects: all_used_obj_types.add(single_config.service_objects[obj_id].svc_typ) - all_used_obj_types = list(set(all_used_obj_types)) - missing_obj_types = all_used_obj_types - self.maps.ServiceObjectTypeMap.keys() + missing_obj_types = list(all_used_obj_types) - self.maps.ServiceObjectTypeMap.keys() if len(missing_obj_types)>0: self.issues.update({'unresolvableSvcObjTypes': list(missing_obj_types)}) @@ -178,7 +179,7 @@ def _collect_all_service_object_refs_from_groups(single_config: FwConfigNormaliz @staticmethod - def _collect_service_object_refs_from_rules(single_config) -> set[str]: + def _collect_service_object_refs_from_rules(single_config: FwConfigNormalized) -> set[str]: all_used_obj_refs: set[str] = set() for rb in single_config.rulebases: for ruleId in rb.rules: @@ -187,10 +188,10 @@ def _collect_service_object_refs_from_rules(single_config) -> set[str]: def checkUserObjectConsistency(self, config: FwConfigManagerListController): - global_objects = set() + global_objects: set[str] = set() # add all user refs from all rules for mgr in sorted(config.ManagerSet, key=lambda m: not getattr(m, 'IsSuperManager', False)): - all_used_obj_refs = [] + all_used_obj_refs: list[str] = [] if mgr.IsSuperManager: global_objects = config.get_all_user_object_uids(mgr.ManagerUid) for single_config in mgr.Configs: @@ -199,14 +200,13 @@ def checkUserObjectConsistency(self, config: FwConfigManagerListController): self._check_user_types_exist(single_config) # now make list unique and get all refs not contained in users - all_used_obj_refs = set(all_used_obj_refs) - unresolvable_obj_refs = all_used_obj_refs - config.get_all_user_object_uids(mgr.ManagerUid) - global_objects + unresolvable_obj_refs = set(all_used_obj_refs) - config.get_all_user_object_uids(mgr.ManagerUid) - global_objects if len(unresolvable_obj_refs)>0: self.issues.update({'unresolvableUserObjRefs': list(unresolvable_obj_refs)}) - def _collect_users_from_rules(self, single_config): - all_used_obj_refs = [] + def _collect_users_from_rules(self, single_config: FwConfigNormalized) -> list[str]: + all_used_obj_refs: list[str] = [] for rb in single_config.rulebases: for ruleId in rb.rules: if fwo_const.user_delimiter in rb.rules[ruleId].rule_src_refs: @@ -215,24 +215,23 @@ def _collect_users_from_rules(self, single_config): return all_used_obj_refs - def _collect_users_from_groups(self, single_config: FwConfigNormalized, all_used_obj_refs): + def _collect_users_from_groups(self, single_config: FwConfigNormalized, all_used_obj_refs: list[str]): return - def _check_user_types_exist(self, single_config): + def _check_user_types_exist(self, single_config: FwConfigNormalized): # check that all obj_typ exist - allUsedObjTypes = set() + allUsedObjTypes: set[str] = set() for objId in single_config.users: - allUsedObjTypes.add(single_config.users[objId].user_typ) - allUsedObjTypes = list(set(allUsedObjTypes)) # make list unique - missingObjTypes = allUsedObjTypes - self.maps.UserObjectTypeMap.keys() + allUsedObjTypes.add(single_config.users[objId].user_typ) # make list unique + missingObjTypes = list(set(allUsedObjTypes)) - self.maps.UserObjectTypeMap.keys() #TODO: why list(set())? if len(missingObjTypes)>0: self.issues.update({'unresolvableUserObjTypes': list(missingObjTypes)}) @staticmethod - def _collectUsersFromRule(listOfElements): - userRefs = [] + def _collectUsersFromRule(listOfElements: list[str]) -> list[str]: + userRefs: list[str] = [] for el in listOfElements: splitResult = el.split(fwo_const.user_delimiter) if len(splitResult)==2: @@ -242,7 +241,7 @@ def _collectUsersFromRule(listOfElements): def checkZoneObjectConsistency(self, config: FwConfigManagerListController): - global_objects = set() + global_objects: set[str] = set() for mgr in sorted(config.ManagerSet, key=lambda m: not getattr(m, 'IsSuperManager', False)): if len(mgr.Configs)==0: continue @@ -263,7 +262,7 @@ def checkZoneObjectConsistency(self, config: FwConfigManagerListController): @staticmethod def _collect_zone_refs_from_rules(single_config: FwConfigNormalized) -> set[str]: - all_used_zones_refs = set() + all_used_zones_refs: set[str] = set() for rb in single_config.rulebases: for rule_id in rb.rules: rule = rb.rules[rule_id] @@ -276,7 +275,7 @@ def _collect_zone_refs_from_rules(single_config: FwConfigNormalized) -> set[str] # check if all color refs are valid (in the DB) # fix=True means that missing color refs will be replaced by the default color (black) - def checkColorConsistency(self, config: FwConfigManagerListController, fix=True): + def checkColorConsistency(self, config: FwConfigManagerListController, fix: bool = True): self.import_state.SetColorRefMap(self.import_state.api_call) # collect all colors @@ -298,16 +297,16 @@ def checkColorConsistency(self, config: FwConfigManagerListController, fix=True) @staticmethod - def _collect_all_used_colors(single_config): - allUsedNwObjColorRefSet = set() - allUsedSvcColorRefSet = set() - allUsedUserColorRefSet = set() + def _collect_all_used_colors(single_config: FwConfigNormalized): + allUsedNwObjColorRefSet: set[str] = set() + allUsedSvcColorRefSet: set[str] = set() + allUsedUserColorRefSet: set[str] = set() for uid in single_config.network_objects: - if single_config.network_objects[uid].obj_color is not None: + if single_config.network_objects[uid].obj_color is not None: # type: ignore #TODO: obj_color cant be None allUsedNwObjColorRefSet.add(single_config.network_objects[uid].obj_color) for uid in single_config.service_objects: - if single_config.service_objects[uid].svc_color is not None: + if single_config.service_objects[uid].svc_color is not None: # type: ignore #TODO: svc_color cant be None allUsedSvcColorRefSet.add(single_config.service_objects[uid].svc_color) for uid in single_config.users: if single_config.users[uid].user_color is not None: @@ -316,33 +315,33 @@ def _collect_all_used_colors(single_config): return allUsedNwObjColorRefSet, allUsedSvcColorRefSet, allUsedUserColorRefSet - def _check_resolvability_of_used_colors(self, allUsedNwObjColorRefSet, allUsedSvcColorRefSet, allUsedUserColorRefSet): - unresolvableNwObjColors = [] - unresolvableSvcColors = [] - unresolvableUserColors = [] + def _check_resolvability_of_used_colors(self, allUsedNwObjColorRefSet: set[str], allUsedSvcColorRefSet: set[str], allUsedUserColorRefSet: set[str]): + unresolvableNwObjColors: list[str] = [] + unresolvableSvcColors: list[str] = [] + unresolvableUserColors: list[str] = [] # check all nwobj color refs for color_string in allUsedNwObjColorRefSet: color_id = self.import_state.lookupColorId(color_string) - if color_id is None: + if color_id is None: # type: ignore # TODO: lookupColorId cant return None unresolvableNwObjColors.append(color_string) # check all nwobj color refs for color_string in allUsedSvcColorRefSet: color_id = self.import_state.lookupColorId(color_string) - if color_id is None: + if color_id is None: # type: ignore # TODO: lookupColorId cant return None unresolvableSvcColors.append(color_string) # check all user color refs for color_string in allUsedUserColorRefSet: color_id = self.import_state.lookupColorId(color_string) - if color_id is None: + if color_id is None: # type: ignore # TODO: lookupColorId cant return None unresolvableUserColors.append(color_string) return unresolvableNwObjColors, unresolvableSvcColors, unresolvableUserColors @staticmethod - def _fix_colors(config, unresolvable_nw_obj_colors, unresolvable_svc_colors, unresolvable_user_colors): + def _fix_colors(config: FwConfigNormalized, unresolvable_nw_obj_colors: list[str], unresolvable_svc_colors: list[str], unresolvable_user_colors: list[str]): # Replace unresolvable network object colors for obj in config.network_objects.values(): if obj.obj_color in unresolvable_nw_obj_colors: @@ -358,9 +357,9 @@ def _fix_colors(config, unresolvable_nw_obj_colors, unresolvable_svc_colors, unr @staticmethod - def _extract_rule_track_n_action_refs(rulebases): - track_refs = [] - action_refs = [] + def _extract_rule_track_n_action_refs(rulebases: list[Rulebase]) -> tuple[list[str], list[str]]: + track_refs: list[str] = [] + action_refs: list[str] = [] for rb in rulebases: track_refs.extend(rule.rule_track for rule in rb.rules.values()) action_refs.extend(rule.rule_action for rule in rb.rules.values()) @@ -368,8 +367,8 @@ def _extract_rule_track_n_action_refs(rulebases): def check_rulebase_consistency(self, config: FwConfigManagerListController): - all_used_track_refs = [] - all_used_action_refs = [] + all_used_track_refs: list[str] = [] + all_used_action_refs: list[str] = [] for mgr in config.ManagerSet: for single_config in mgr.Configs: @@ -400,7 +399,7 @@ def check_gateway_consistency(self, config: FwConfigManagerListController): # - the same submanager or # - the super manager but not another sub manager def check_rulebase_link_consistency(self, config: FwConfigManagerListController): - broken_rulebase_links = [] + broken_rulebase_links: list[dict[str, Any]] = [] all_rulebase_uids, all_rule_uids = self._collect_uids(config) @@ -417,15 +416,15 @@ def check_rulebase_link_consistency(self, config: FwConfigManagerListController) self.issues.update({'brokenRulebaseLinks': broken_rulebase_links}) - def _check_rulebase_links_for_gateway(self, gw, broken_rulebase_links, all_rule_uids, all_rulebase_uids): + def _check_rulebase_links_for_gateway(self, gw: Gateway, broken_rulebase_links: list[dict[str, Any]], all_rule_uids: set[str], all_rulebase_uids: set[str]): if not gw.ImportDisabled: for rbl in gw.RulebaseLinks: self._check_rulebase_link(gw, rbl, broken_rulebase_links, all_rule_uids, all_rulebase_uids) - def _collect_uids(self, config): - all_rulebase_uids = set() - all_rule_uids = set() + def _collect_uids(self, config: FwConfigManagerListController): + all_rulebase_uids: set[str] = set() + all_rule_uids: set[str] = set() for mgr in config.ManagerSet: if self.import_state.MgmDetails.ImportDisabled: continue @@ -439,16 +438,16 @@ def _collect_uids(self, config): return all_rulebase_uids, all_rule_uids - def _check_rulebase_link(self, gw, rbl, broken_rulebase_links, all_rule_uids, all_rulebase_uids): + def _check_rulebase_link(self, gw: Gateway, rbl: RulebaseLinkUidBased, broken_rulebase_links: list[dict[str, Any]], all_rule_uids: set[str], all_rulebase_uids: set[str]): if rbl.from_rulebase_uid is not None and rbl.from_rulebase_uid != '' and rbl.from_rulebase_uid not in all_rulebase_uids: self._add_issue(broken_rulebase_links, rbl, gw, 'from_rulebase_uid broken') - if rbl.to_rulebase_uid is not None and rbl.to_rulebase_uid != '' and rbl.to_rulebase_uid not in all_rulebase_uids: + if rbl.to_rulebase_uid is not None and rbl.to_rulebase_uid != '' and rbl.to_rulebase_uid not in all_rulebase_uids: # type: ignore # TODO: to_rulebase_uid cant be None self._add_issue(broken_rulebase_links, rbl, gw, 'to_rulebase_uid broken') if rbl.from_rule_uid is not None and rbl.from_rule_uid != '' and rbl.from_rule_uid not in all_rule_uids: self._add_issue(broken_rulebase_links, rbl, gw, 'from_rule_uid broken') @staticmethod - def _add_issue(broken_rulebase_links, rbl, gw, error_txt): + def _add_issue(broken_rulebase_links: list[dict[str, Any]], rbl: RulebaseLinkUidBased, gw: Gateway, error_txt: str): rbl_dict = rbl.toDict() rbl_dict.update({'error': error_txt}) rbl_dict.update({'gw': f'{gw.Name} ({gw.Uid})'}) diff --git a/roles/importer/files/importer/model_controllers/fwconfig_controller.py b/roles/importer/files/importer/model_controllers/fwconfig_controller.py index 468f737fe5..2685e09571 100644 --- a/roles/importer/files/importer/model_controllers/fwconfig_controller.py +++ b/roles/importer/files/importer/model_controllers/fwconfig_controller.py @@ -1,18 +1,17 @@ import json from fwo_base import ConfFormat, ConfigAction from models.rulebase import Rulebase -from models.fwconfig import FwConfig from netaddr import IPNetwork class FwoEncoder(json.JSONEncoder): - def default(self, obj): + def default(self, obj: object) -> object: # type: ignore if isinstance(obj, ConfigAction) or isinstance(obj, ConfFormat): return obj.name if isinstance(obj, Rulebase): - return obj.toJson() + return obj.toJson() # type: ignore if isinstance(obj, IPNetwork): return str(obj) diff --git a/roles/importer/files/importer/model_controllers/fwconfig_import.py b/roles/importer/files/importer/model_controllers/fwconfig_import.py index 15e7ff9337..05b14d5da0 100644 --- a/roles/importer/files/importer/model_controllers/fwconfig_import.py +++ b/roles/importer/files/importer/model_controllers/fwconfig_import.py @@ -1,5 +1,5 @@ import traceback -from typing import Optional +from typing import Any import fwo_const from fwo_api_call import FwoApiCall @@ -9,7 +9,7 @@ from fwo_exceptions import ImportInterruption from fwo_log import getFwoLogger from model_controllers.import_state_controller import ImportStateController -from fwo_base import ConfigAction, ConfFormat, find_all_diffs +from fwo_base import ConfigAction, find_all_diffs from models.fwconfig_normalized import FwConfigNormalized from model_controllers.fwconfig_import_object import FwConfigImportObject from model_controllers.fwconfig_import_rule import FwConfigImportRule @@ -27,7 +27,7 @@ class FwConfigImport(): import_state: ImportStateController - NormalizedConfig: Optional[FwConfigNormalized] + NormalizedConfig: FwConfigNormalized | None _fw_config_import_rule: FwConfigImportRule _fw_config_import_object: FwConfigImportObject @@ -44,6 +44,7 @@ def __init__(self): if self._global_state.import_state is None: raise FwoImporterError("import_state not set in global state") self.import_state = self._global_state.import_state + self.NormalizedConfig = self._global_state.normalized_config self._fw_config_import_object = FwConfigImportObject() @@ -66,8 +67,6 @@ def import_single_config(self, single_manager: FwConfigManager): def import_management_set(self, import_state: ImportStateController, service_provider: ServiceProvider, mgr_set: FwConfigManagerListController): - global_state = service_provider.get_service(Services.GLOBAL_STATE) - for manager in sorted(mgr_set.ManagerSet, key=lambda m: not getattr(m, 'IsSuperManager', False)): """ the following loop is a preparation for future functionality @@ -118,7 +117,7 @@ def clear_management(self) -> FwConfigManagerListController: if len(self.import_state.MgmDetails.SubManagerIds)>0: # Read config fwo_api = FwoApi(self.import_state.FwoConfig.FwoApiUri, self.import_state.Jwt) - fwo_api_call = FwoApiCall(fwo_api) + _ = FwoApiCall(fwo_api) #TODO why not used ?? # # Authenticate to get JWT # try: # jwt = fwo_api.login(importer_user_name, fwoConfig.ImporterPassword, fwoConfig.FwoUserMgmtApiUri) @@ -129,7 +128,7 @@ def clear_management(self) -> FwConfigManagerListController: for subManagerId in self.import_state.MgmDetails.SubManagerIds: # Fetch sub management details mgm_controller = ManagementController( - mgm_id=int(subManagerId), uid='', devices={}, + mgm_id=int(subManagerId), uid='', devices=[], device_info=DeviceInfo(), connection_info=ConnectionInfo(), importer_hostname='', @@ -141,7 +140,7 @@ def clear_management(self) -> FwConfigManagerListController: mgm_details = ManagementController.fromJson(mgm_details_raw) configNormalized.addManager( manager=FwConfigManager( - ManagerUid=ManagementController.calcManagerUidHash(mgm_details_raw), + ManagerUid=ManagementController.calcManagerUidHash(mgm_details_raw), #type: ignore # TODO: check: should be mgm_details ManagerName=mgm_details.Name, IsSuperManager=mgm_details.IsSuperManager, SubManagerIds=mgm_details.SubManagerIds, @@ -186,10 +185,10 @@ def updateDiffs(self, prev_config: FwConfigNormalized, prev_global_config: FwCon self._fw_config_import_gateway.update_gateway_diffs() # get new rules details from API (for obj refs as well as enforcing gateways) - errors, changes, newRules = self._fw_config_import_rule.getRulesByIdWithRefUids(newRuleIds) + _, _, newRules = self._fw_config_import_rule.getRulesByIdWithRefUids(newRuleIds) enforcingController = RuleEnforcedOnGatewayController(self.import_state) - ids = enforcingController.add_new_rule_enforced_on_gateway_refs(newRules, self.import_state) + enforcingController.add_new_rule_enforced_on_gateway_refs(newRules, self.import_state) # cleanup configs which do not need to be retained according to data retention time @@ -210,7 +209,7 @@ def deleteOldImports(self) -> None: logger.error(f"error while trying to delete old imports for mgm {str(self.import_state.MgmDetails.Id)}") fwo_api_call.create_data_issue(self.import_state.FwoConfig.FwoApiUri, self.import_state.Jwt, mgm_id=int(self.import_state.MgmDetails.Id), severity=1, description="failed to get import lock for management id " + str(mgmId)) - fwo_api_call.set_alert(import_id=self.import_state.ImportId, title="import error", mgm_id=str(mgmId), severity=1, \ + fwo_api_call.set_alert(import_id=self.import_state.ImportId, title="import error", mgm_id=mgmId, severity=1, \ description="fwo_api: failed to get import lock", source='import', alertCode=15, mgm_details=self.import_state.MgmDetails) raise FwoApiFailedDeleteOldImports(f"management id: {mgmId}") from None @@ -238,7 +237,7 @@ def write_latest_config(self): getFwoLogger().warning(f"error while trying to delete latest config for mgm_id: {self.import_state.ImportId}") insertMutation = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "import/storeLatestConfig.graphql"]) try: - query_variables = { + query_variables: dict[str, Any] = { 'mgmId': self.import_state.MgmDetails.CurrentMgmId, 'importId': self.import_state.ImportId, 'config': self.NormalizedConfig.model_dump_json() @@ -303,16 +302,7 @@ def get_latest_import_id(self) -> int|None: # return previous config or empty config if there is none; only returns the config of a single management def get_latest_config(self) -> FwConfigNormalized: mgm_id = self.import_state.MgmDetails.CurrentMgmId - prev_config = FwConfigNormalized(**{ - 'action': ConfigAction.INSERT, - 'network_objects': {}, - 'service_objects': {}, - 'users': {}, - 'zone_objects': {}, - 'rules': [], - 'gateways': [], - 'ConfigFormat': ConfFormat.NORMALIZED_LEGACY - }) + prev_config = FwConfigNormalized() logger = getFwoLogger(debug_level=self.import_state.DebugLevel) latest_import_id = self.get_latest_import_id() diff --git a/roles/importer/files/importer/model_controllers/fwconfig_import_gateway.py b/roles/importer/files/importer/model_controllers/fwconfig_import_gateway.py index cc584d8954..f2d8addab6 100644 --- a/roles/importer/files/importer/model_controllers/fwconfig_import_gateway.py +++ b/roles/importer/files/importer/model_controllers/fwconfig_import_gateway.py @@ -1,9 +1,11 @@ +from logging import Logger +from typing import Any from fwo_log import getFwoLogger from model_controllers.rulebase_link_controller import RulebaseLinkController from models.gateway import Gateway -from models.rulebase_link import RulebaseLink, RulebaseLinkUidBased +from models.rulebase_link import RulebaseLink, RulebaseLinkUidBased # TODO check if we need RulebaseLinkUidBased as well from fwo_exceptions import FwoImporterError from services.enums import Services @@ -44,7 +46,7 @@ def update_gateway_diffs(self): # self.ImportDetails.Stats.addError('simulate error') - def update_rulebase_link_diffs(self): + def update_rulebase_link_diffs(self) -> tuple[list[dict[str, Any]], list[int | None]]: if self._global_state.import_state is None: #TODO: should rework global state to not need these checks ? - #3154 raise FwoImporterError("ImportState is None in update_rulebase_link_diffs") @@ -53,8 +55,8 @@ def update_rulebase_link_diffs(self): if self._global_state.previous_config is None: raise FwoImporterError("previous_config is None in update_rulebase_link_diffs") - required_inserts: list[RulebaseLinkUidBased] = [] - required_removes: list[int] = [] + required_inserts: list[dict[str, Any]] = [] + required_removes: list[int | None] = [] logger = getFwoLogger(debug_level=self._global_state.import_state.DebugLevel) @@ -82,8 +84,7 @@ def update_rulebase_link_diffs(self): return required_inserts, required_removes - def _create_insert_args(self, normalized_gateway: Gateway, previous_gateway: Gateway|None, gw_id, logger, arg_list): - + def _create_insert_args(self, normalized_gateway: Gateway, previous_gateway: Gateway|None, gw_id: int | None, logger: Logger, arg_list: list[dict[str, Any]]): rulebase_links = [] for link in normalized_gateway.RulebaseLinks: @@ -92,9 +93,9 @@ def _create_insert_args(self, normalized_gateway: Gateway, previous_gateway: Gat self._try_add_single_link(arg_list, link, rulebase_links, gw_id, True, logger) - def _create_remove_args(self, normalized_gateway: Gateway, previous_gateway: Gateway, gw_id, logger, arg_list): + def _create_remove_args(self, normalized_gateway: Gateway, previous_gateway: Gateway, gw_id: int | None, logger: Logger, arg_list: list[int | None]): - removed_rulebase_links = [] + removed_rulebase_links: list[dict[str, Any]] = [] for link in previous_gateway.RulebaseLinks: self._try_add_single_link(removed_rulebase_links, link, normalized_gateway.RulebaseLinks, gw_id, False, logger) @@ -104,32 +105,32 @@ def _create_remove_args(self, normalized_gateway: Gateway, previous_gateway: Gat arg_list.append(link_in_db.id) - def _try_add_single_link(self, rb_link_list, link, link_list, gw_id, is_insert, logger): + def _try_add_single_link(self, rb_link_list: list[dict[str, Any]], link: RulebaseLinkUidBased, link_list: list[RulebaseLinkUidBased], gw_id: int | None, is_insert: bool, logger: Logger): if self._global_state.import_state is None: raise FwoImporterError("ImportState is None in _try_add_single_link") # If rule changed we need the id of the old version, since the rulebase links still have the old fks (for updates) - from_rule_id = self._global_state.import_state.removed_rules_map.get(link.from_rule_uid, None) + from_rule_id = self._global_state.import_state.removed_rules_map.get(link.from_rule_uid, None) if link.from_rule_uid is not None else None # If rule is unchanged or new id can be fetched from RuleMap, because it has been updated already if not from_rule_id or is_insert: - from_rule_id = self._global_state.import_state.lookupRule(link.from_rule_uid) + from_rule_id = self._global_state.import_state.lookupRule(link.from_rule_uid) if link.from_rule_uid is not None else None if link.from_rulebase_uid is None or link.from_rulebase_uid == '': from_rulebase_id = None else: from_rulebase_id = self._global_state.import_state.lookupRulebaseId(link.from_rulebase_uid) to_rulebase_id = self._global_state.import_state.lookupRulebaseId(link.to_rulebase_uid) - if to_rulebase_id is None: - self._global_state.import_state.Stats.addError(f"toRulebaseId is None for link {link}") - return link_type_id = self._global_state.import_state.lookupLinkType(link.link_type) - if link_type_id is None or type(link_type_id) is not int: + if type(link_type_id) is not int: logger.warning(f"did not find a link_type_id for link_type {link.link_type}") if not self._link_is_in_link_list(link, link_list): + if gw_id is None: + logger.warning(f"did not find a gwId for UID {link}") + return rb_link_list.append(RulebaseLink(gw_id=gw_id, from_rule_id=from_rule_id, to_rulebase_id=to_rulebase_id, @@ -144,7 +145,7 @@ def _try_add_single_link(self, rb_link_list, link, link_list, gw_id, is_insert, logger.debug(f"link {link} was added") - def _link_is_in_link_list(self, link: RulebaseLinkUidBased, link_list: list[RulebaseLinkUidBased]): + def _link_is_in_link_list(self, link: RulebaseLinkUidBased, link_list: list[RulebaseLinkUidBased]) -> bool: if link_list: existing_link = next(( @@ -159,7 +160,7 @@ def _link_is_in_link_list(self, link: RulebaseLinkUidBased, link_list: list[Rule return False - def _try_get_id_based_link(self, link: RulebaseLinkUidBased, link_list: list[RulebaseLink]): + def _try_get_id_based_link(self, link: dict[str, Any], link_list: list[RulebaseLink]): return next(( existing_link @@ -169,11 +170,11 @@ def _try_get_id_based_link(self, link: RulebaseLinkUidBased, link_list: list[Rul def update_interface_diffs(self): - logger = getFwoLogger(debug_level=self._global_state.import_state.DebugLevel) + logger = getFwoLogger(debug_level=self._global_state.import_state.DebugLevel) #type: ignore # TODO: needs to be implemented def update_routing_diffs(self): - logger = getFwoLogger(debug_level=self._global_state.import_state.DebugLevel) + logger = getFwoLogger(debug_level=self._global_state.import_state.DebugLevel) #type: ignore # TODO: needs to be implemented diff --git a/roles/importer/files/importer/model_controllers/fwconfig_import_object.py b/roles/importer/files/importer/model_controllers/fwconfig_import_object.py index bb84a86f12..2067799072 100644 --- a/roles/importer/files/importer/model_controllers/fwconfig_import_object.py +++ b/roles/importer/files/importer/model_controllers/fwconfig_import_object.py @@ -1,6 +1,6 @@ from enum import Enum import traceback -import time, datetime +import datetime import json from typing import Any @@ -11,7 +11,7 @@ from models.fwconfigmanager import FwConfigManager from models.serviceobject import ServiceObjectForImport import fwo_const -from fwo_api_call import FwoApiCall, FwoApi +from fwo_api_call import FwoApi from fwo_exceptions import FwoDuplicateKeyViolation, FwoImporterError from services.group_flats_mapper import GroupFlatsMapper from services.uid2id_mapper import Uid2IdMapper @@ -28,7 +28,7 @@ class FwConfigImportObject(): import_state: ImportStateController normalized_config: FwConfigNormalized - global_normalized_config: FwConfigNormalized|None = None + global_normalized_config: FwConfigNormalized | None = None group_flats_mapper: GroupFlatsMapper prev_group_flats_mapper: GroupFlatsMapper uid2id_mapper: Uid2IdMapper @@ -60,13 +60,13 @@ def updateObjectDiffs(self, prev_config: FwConfigNormalized, prev_global_config: # calculate network object diffs # here we are handling the previous config as a dict for a while # previousNwObjects = prevConfig.network_objects - deletedNwobjUids = list(prev_config.network_objects.keys() - self.normalized_config.network_objects.keys()) - newNwobjUids = list(self.normalized_config.network_objects.keys() - prev_config.network_objects.keys()) - nwobjUidsInBoth = list(self.normalized_config.network_objects.keys() & prev_config.network_objects.keys()) + deletedNwobjUids: list[str] = list(prev_config.network_objects.keys() - self.normalized_config.network_objects.keys()) + newNwobjUids: list[str] = list(self.normalized_config.network_objects.keys() - prev_config.network_objects.keys()) + nwobjUidsInBoth: list[str] = list(self.normalized_config.network_objects.keys() & prev_config.network_objects.keys()) # For correct changelog and stats. - changed_nw_objs = [] - changed_svcs = [] + changed_nw_objs: list[str] = [] + changed_svcs: list[str] = [] # decide if it is prudent to mix changed, deleted and added rules here: for nwObjUid in nwobjUidsInBoth: @@ -76,9 +76,9 @@ def updateObjectDiffs(self, prev_config: FwConfigNormalized, prev_global_config: changed_nw_objs.append(nwObjUid) # calculate service object diffs - deletedSvcObjUids = list(prev_config.service_objects.keys() - self.normalized_config.service_objects.keys()) - newSvcObjUids = list(self.normalized_config.service_objects.keys() - prev_config.service_objects.keys()) - svcObjUidsInBoth = list(self.normalized_config.service_objects.keys() & prev_config.service_objects.keys()) + deletedSvcObjUids: list[str] = list(prev_config.service_objects.keys() - self.normalized_config.service_objects.keys()) + newSvcObjUids: list[str] = list(self.normalized_config.service_objects.keys() - prev_config.service_objects.keys()) + svcObjUidsInBoth: list[str] = list(self.normalized_config.service_objects.keys() & prev_config.service_objects.keys()) for nwSvcUid in svcObjUidsInBoth: if self.normalized_config.service_objects[nwSvcUid] != prev_config.service_objects[nwSvcUid]: @@ -87,9 +87,9 @@ def updateObjectDiffs(self, prev_config: FwConfigNormalized, prev_global_config: changed_svcs.append(nwSvcUid) # calculate user diffs - deletedUserUids = list(prev_config.users.keys() - self.normalized_config.users.keys()) - newUserUids = list(self.normalized_config.users.keys() - prev_config.users.keys()) - userUidsInBoth = list(self.normalized_config.users.keys() & prev_config.users.keys()) + deletedUserUids: list[str] = list(prev_config.users.keys() - self.normalized_config.users.keys()) + newUserUids: list[str] = list(self.normalized_config.users.keys() - prev_config.users.keys()) + userUidsInBoth: list[str] = list(self.normalized_config.users.keys() & prev_config.users.keys()) for userUid in userUidsInBoth: if self.normalized_config.users[userUid] != prev_config.users[userUid]: newUserUids.append(userUid) @@ -113,10 +113,10 @@ def updateObjectDiffs(self, prev_config: FwConfigNormalized, prev_global_config: self.remove_outdated_memberships(prev_config, Type.USER) # calculate zone object diffs - deleted_zone_names = list(prev_config.zone_objects.keys() - self.normalized_config.zone_objects.keys()) - new_zone_names = list(self.normalized_config.zone_objects.keys() - prev_config.zone_objects.keys()) - zone_names_in_both = list(self.normalized_config.zone_objects.keys() & prev_config.zone_objects.keys()) - changed_zones = [] + deleted_zone_names: list[str] = list(prev_config.zone_objects.keys() - self.normalized_config.zone_objects.keys()) + new_zone_names: list[str] = list(self.normalized_config.zone_objects.keys() - prev_config.zone_objects.keys()) + zone_names_in_both: list[str] = list(self.normalized_config.zone_objects.keys() & prev_config.zone_objects.keys()) + changed_zones: list[str] = [] for zone_name in zone_names_in_both: if self.normalized_config.zone_objects[zone_name] != prev_config.zone_objects[zone_name]: @@ -125,7 +125,7 @@ def updateObjectDiffs(self, prev_config: FwConfigNormalized, prev_global_config: changed_zones.append(zone_name) # add newly created objects - newNwObjIds, newNwSvcIds, newUserIds, new_zone_ids, removedNwObjIds, removedNwSvcIds, removedUserIds, removed_zone_ids = \ + newNwObjIds, newNwSvcIds, newUserIds, new_zone_ids, removedNwObjIds, removedNwSvcIds, _, _ = \ self.updateObjectsViaApi(single_manager, newNwobjUids, newSvcObjUids, newUserUids, new_zone_names, deletedNwobjUids, deletedSvcObjUids, deletedUserUids, deleted_zone_names) self.uid2id_mapper.add_network_object_mappings(newNwObjIds, is_global=single_manager.IsSuperManager) @@ -136,7 +136,7 @@ def updateObjectDiffs(self, prev_config: FwConfigNormalized, prev_global_config: # insert new and updated group memberships self.addGroupMemberships(prev_config, Type.NETWORK_OBJECT) self.addGroupMemberships(prev_config, Type.SERVICE_OBJECT) - self.addGroupMemberships(prev_config, Type.USER) + self.addGroupMemberships(prev_config, Type.USER) # these objects have really been deleted so there should be no refs to them anywhere! verify this @@ -178,7 +178,7 @@ def updateObjectDiffs(self, prev_config: FwConfigNormalized, prev_global_config: self.import_state.Stats.ServiceObjectChangeCount = len(change_logger.changed_service_id_map.items()) - def GetNetworkObjTypeMap(self): + def GetNetworkObjTypeMap(self) -> dict[str, int]: query = "query getNetworkObjTypeMap { stm_obj_typ { obj_typ_name obj_typ_id } }" try: result = self.import_state.api_call.call(query=query, query_variables={}) @@ -187,12 +187,12 @@ def GetNetworkObjTypeMap(self): logger.error(f"Error while getting stm_obj_typ: str{e}") return {} - map = {} + map: dict[str, Any] = {} for nwType in result['data']['stm_obj_typ']: map.update({nwType['obj_typ_name']: nwType['obj_typ_id']}) return map - def GetServiceObjTypeMap(self): + def GetServiceObjTypeMap(self) -> dict[str, int]: query = "query getServiceObjTypeMap { stm_svc_typ { svc_typ_name svc_typ_id } }" try: result = self.import_state.api_call.call(query=query, query_variables={}) @@ -201,12 +201,12 @@ def GetServiceObjTypeMap(self): logger.error(f"Error while getting stm_svc_typ: {str(e)}") return {} - map = {} + map: dict[str, Any] = {} for svcType in result['data']['stm_svc_typ']: map.update({svcType['svc_typ_name']: svcType['svc_typ_id']}) return map - def GetUserObjTypeMap(self): + def GetUserObjTypeMap(self) -> dict[str, int]: query = "query getUserObjTypeMap { stm_usr_typ { usr_typ_name usr_typ_id } }" try: result = self.import_state.api_call.call(query=query, query_variables={}) @@ -215,12 +215,12 @@ def GetUserObjTypeMap(self): logger.error(f"Error while getting stm_usr_typ: {str(e)}") return {} - map = {} + map: dict[str, Any] = {} for usrType in result['data']['stm_usr_typ']: map.update({usrType['usr_typ_name']: usrType['usr_typ_id']}) return map - def GetProtocolMap(self): + def GetProtocolMap(self) -> dict[str, int]: query = "query getIpProtocols { stm_ip_proto { ip_proto_id ip_proto_name } }" try: result = self.import_state.api_call.call(query=query, query_variables={}) @@ -229,12 +229,12 @@ def GetProtocolMap(self): logger.error(f"Error while getting stm_ip_proto: {str(e)}") return {} - map = {} + map: dict[str, Any] = {} for proto in result['data']['stm_ip_proto']: map.update({proto['ip_proto_name'].lower(): proto['ip_proto_id']}) return map - def updateObjectsViaApi(self, single_manager, newNwObjectUids, newSvcObjectUids, newUserUids, new_zone_names, removedNwObjectUids, removedSvcObjectUids, removedUserUids, removed_zone_names): + def updateObjectsViaApi(self, single_manager: FwConfigManager, newNwObjectUids: list[str], newSvcObjectUids: list[str], newUserUids: list[str], new_zone_names: list[str], removedNwObjectUids: list[str], removedSvcObjectUids: list[str], removedUserUids: list[str], removed_zone_names: list[str]): # here we also mark old objects removed before adding the new versions logger = getFwoLogger(debug_level=self.import_state.DebugLevel) newNwObjIds = [] @@ -246,8 +246,10 @@ def updateObjectsViaApi(self, single_manager, newNwObjectUids, newSvcObjectUids, removedUserIds = [] removed_zone_ids = [] this_managements_id = self.import_state.lookupManagementId(single_manager.ManagerUid) - import_mutation = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "allObjects/upsertObjects.graphql"]) - query_variables = { + if this_managements_id is None: + raise FwoImporterError(f"failed to update objects in updateObjectsViaApi: no management id found for manager uid '{single_manager.ManagerUid}'") + import_mutation = FwoApi.get_graphql_code(file_list=[fwo_const.graphql_query_path + "allObjects/upsertObjects.graphql"]) + query_variables: dict[str, Any] = { 'mgmId': this_managements_id, 'importId': self.import_state.ImportId, 'newNwObjects': self.prepareNewNwObjects(newNwObjectUids, this_managements_id), @@ -270,7 +272,7 @@ def updateObjectsViaApi(self, single_manager, newNwObjectUids, newSvcObjectUids, if 'errors' in import_result: raise FwoImporterError(f"failed to update objects in updateObjectsViaApi: {str(import_result['errors'])}") else: - changes = int(import_result['data']['insert_object']['affected_rows']) + \ + _ = int(import_result['data']['insert_object']['affected_rows']) + \ int(import_result['data']['insert_service']['affected_rows']) + \ int(import_result['data']['insert_usr']['affected_rows']) + \ int(import_result['data']['update_object']['affected_rows']) + \ @@ -291,8 +293,8 @@ def updateObjectsViaApi(self, single_manager, newNwObjectUids, newSvcObjectUids, return newNwObjIds, newNwSvcIds, newUserIds, new_zone_ids, removedNwObjIds, removedNwSvcIds, removedUserIds, removed_zone_ids - def prepareNewNwObjects(self, newNwobjUids, mgm_id): - newNwObjs = [] + def prepareNewNwObjects(self, newNwobjUids: list[str], mgm_id: int) -> list[dict[str, Any]]: + newNwObjs: list[dict[str, Any]] = [] for nwobjUid in newNwobjUids: newNwObj = NetworkObjectForImport(nwObject=self.normalized_config.network_objects[nwobjUid], mgmId=mgm_id, @@ -304,8 +306,8 @@ def prepareNewNwObjects(self, newNwobjUids, mgm_id): return newNwObjs - def prepareNewSvcObjects(self, newSvcobjUids, mgm_id): - newObjs = [] + def prepareNewSvcObjects(self, newSvcobjUids: list[str], mgm_id: int) -> list[dict[str, Any]]: + newObjs: list[dict[str, Any]] = [] for uid in newSvcobjUids: newObjs.append(ServiceObjectForImport(svcObject=self.normalized_config.service_objects[uid], mgmId=mgm_id, @@ -314,9 +316,9 @@ def prepareNewSvcObjects(self, newSvcobjUids, mgm_id): typId=self.lookupSvcType(self.normalized_config.service_objects[uid].svc_typ), ).toDict()) return newObjs - - def prepareNewUserObjects(self, newUserUids, mgm_id): - newObjs = [] + + def prepareNewUserObjects(self, newUserUids: list[str], mgm_id: int) -> list[dict[str, Any]]: + newObjs: list[dict[str, Any]] = [] for uid in newUserUids: newObjs.append({ 'user_uid': uid, @@ -327,10 +329,10 @@ def prepareNewUserObjects(self, newUserUids, mgm_id): 'user_name': self.normalized_config.users[uid]['user_name'], }) return newObjs - - - def prepare_new_zones(self, new_zone_names, mgm_id): - new_objects = [] + + + def prepare_new_zones(self, new_zone_names: list[str], mgm_id: int) -> list[dict[str, Any]]: + new_objects: list[dict[str, Any]] = [] for uid in new_zone_names: new_objects.append({ 'mgm_id': mgm_id, @@ -349,21 +351,21 @@ def get_config_objects(self, type: Type, prevConfig: FwConfigNormalized): if type == Type.USER: return prevConfig.users, self.normalized_config.users - def get_id(self, type, uid, before_update = False): + def get_id(self, type: Type, uid: str, before_update: bool = False) -> int | None: if type == Type.NETWORK_OBJECT: return self.uid2id_mapper.get_network_object_id(uid, before_update) if type == Type.SERVICE_OBJECT: return self.uid2id_mapper.get_service_object_id(uid, before_update) return self.uid2id_mapper.get_user_id(uid, before_update) - def get_local_id(self, type, uid, before_update = False): + def get_local_id(self, type: Type, uid: str, before_update: bool = False) -> int | None: if type == Type.NETWORK_OBJECT: return self.uid2id_mapper.get_network_object_id(uid, before_update, local_only=True) if type == Type.SERVICE_OBJECT: return self.uid2id_mapper.get_service_object_id(uid, before_update, local_only=True) return self.uid2id_mapper.get_user_id(uid, before_update, local_only=True) - def is_group(self, type: Type, obj): + def is_group(self, type: Type, obj: Any) -> bool: if type == Type.NETWORK_OBJECT: return obj.obj_typ == "group" if type == Type.SERVICE_OBJECT: @@ -372,26 +374,26 @@ def is_group(self, type: Type, obj): return obj.get('user_typ', None) == "group" - def get_refs(self, type: Type, obj): + def get_refs(self, type: Type, obj: Any) -> str | None: if type == Type.NETWORK_OBJECT: return obj.obj_member_refs if type == Type.SERVICE_OBJECT: return obj.svc_member_refs return obj.get('user_member_refs', None) - - def get_members(self, type, refs) -> list[str]: + + def get_members(self, type: Type, refs: str | None) -> list[str]: if type == Type.NETWORK_OBJECT: return [member.split(fwo_const.user_delimiter)[0] for member in refs.split(fwo_const.list_delimiter) if member] if refs else [] return refs.split(fwo_const.list_delimiter) if refs else [] - def get_flats(self, type, uid): + def get_flats(self, type: Type, uid: str) -> list[str]: if type == Type.NETWORK_OBJECT: return self.group_flats_mapper.get_network_object_flats([uid]) if type == Type.SERVICE_OBJECT: return self.group_flats_mapper.get_service_object_flats([uid]) return self.group_flats_mapper.get_user_flats([uid]) - - def get_prev_flats(self, type, uid): + + def get_prev_flats(self, type: Type, uid: str) -> list[str]: if type == Type.NETWORK_OBJECT: return self.prev_group_flats_mapper.get_network_object_flats([uid]) if type == Type.SERVICE_OBJECT: @@ -409,8 +411,8 @@ def get_prefix(self, type: Type): def remove_outdated_memberships(self, prev_config: FwConfigNormalized, type: Type): errors = 0 changes = 0 - removed_members = [] - removed_flats = [] + removed_members: list[dict[str, Any]] = [] + removed_flats: list[dict[str, Any]] = [] prev_config_objects, current_config_objects = self.get_config_objects(type, prev_config) prefix = self.get_prefix(type) @@ -441,7 +443,7 @@ def remove_outdated_memberships(self, prev_config: FwConfigNormalized, type: Typ }} }} """ - query_variables = { + query_variables: dict[str, Any] = { 'importId': self.import_state.ImportId, 'removedMembers': removed_members, 'removedFlats': removed_flats @@ -462,7 +464,7 @@ def remove_outdated_memberships(self, prev_config: FwConfigNormalized, type: Typ return errors, changes - def find_removed_objects(self, current_config_objects, prev_config_objects: dict[str,Any], removed_members:list, removed_flats: list, + def find_removed_objects(self, current_config_objects: dict[str, Any], prev_config_objects: dict[str, Any], removed_members: list[dict[str, Any]], removed_flats: list[dict[str, Any]], prefix: str, uid: str, type: Type) -> None: if not self.is_group(type, prev_config_objects[uid]): return @@ -497,7 +499,7 @@ def find_removed_objects(self, current_config_objects, prev_config_objects: dict }) - def addGroupMemberships(self, prev_config, obj_type: Type): + def addGroupMemberships(self, prev_config: FwConfigNormalized, obj_type: Type) -> tuple[int, int]: """ This function is used to update group memberships for nwobjs, services or users in the database. It adds group memberships and flats for new and updated members. @@ -505,8 +507,8 @@ def addGroupMemberships(self, prev_config, obj_type: Type): prev_config (FwConfigNormalized): The previous normalized config. """ errors = 0 - new_group_members = [] - new_group_member_flats = [] + new_group_members: list[dict[str, Any]] = [] + new_group_member_flats: list[dict[str, Any]] = [] prev_config_objects, current_config_objects = self.get_config_objects(obj_type, prev_config) prefix = self.get_prefix(obj_type) for uid in current_config_objects.keys(): @@ -523,6 +525,10 @@ def addGroupMemberships(self, prev_config, obj_type: Type): prev_flat_member_uids = self.get_prev_flats(obj_type, uid) group_id = self.get_id(obj_type, uid) + if group_id is None: + logger = getFwoLogger() + logger.error(f"failed to add group memberships: no id found for group uid '{uid}'") + continue self.collect_group_members(group_id, current_config_objects, new_group_members, member_uids, obj_type, prefix, prev_member_uids, prev_config_objects) flat_member_uids = self.get_flats(obj_type, uid) self.collect_flat_group_members(group_id, current_config_objects, new_group_member_flats, flat_member_uids, obj_type, prefix, prev_flat_member_uids, prev_config_objects) @@ -533,7 +539,7 @@ def addGroupMemberships(self, prev_config, obj_type: Type): return self.write_member_updates(new_group_members, new_group_member_flats, prefix, errors) - def collect_flat_group_members(self, group_id, current_config_objects, new_group_member_flats, flat_member_uids, obj_type, prefix, prev_flat_member_uids, prev_config_objects): + def collect_flat_group_members(self, group_id: int, current_config_objects: dict[str, Any], new_group_member_flats: list[dict[str, Any]], flat_member_uids: list[str], obj_type: Type, prefix: str, prev_flat_member_uids: list[str], prev_config_objects: dict[str, Any]): for flat_member_uid in flat_member_uids: if flat_member_uid in prev_flat_member_uids and prev_config_objects[flat_member_uid] == current_config_objects[flat_member_uid]: continue # flat member was not added or changed @@ -546,7 +552,7 @@ def collect_flat_group_members(self, group_id, current_config_objects, new_group }) - def collect_group_members(self, group_id, current_config_objects, new_group_members, member_uids, obj_type, prefix, prev_member_uids, prev_config_objects): + def collect_group_members(self, group_id: int, current_config_objects: dict[str, Any], new_group_members: list[dict[str, Any]], member_uids: list[str], obj_type: Type, prefix: str, prev_member_uids: list[str], prev_config_objects: dict[str, Any]): for member_uid in member_uids: if member_uid in prev_member_uids and prev_config_objects[member_uid] == current_config_objects[member_uid]: continue # member was not added or changed @@ -559,7 +565,7 @@ def collect_group_members(self, group_id, current_config_objects, new_group_memb }) - def write_member_updates(self, new_group_members, new_group_member_flats, prefix, errors): + def write_member_updates(self, new_group_members: list[dict[str, Any]], new_group_member_flats: list[dict[str, Any]], prefix: str, errors: int) -> tuple[int, int]: logger = getFwoLogger() changes = 0 import_mutation = f""" @@ -595,15 +601,15 @@ def write_member_updates(self, new_group_members, new_group_member_flats, prefix return errors, changes - def lookupObjType(self, objTypeString):# -> Any: + def lookupObjType(self, objTypeString: str) -> int: # TODO: might check for miss here as this is a mandatory field! return self.NetworkObjectTypeMap.get(objTypeString, -1) - def lookupSvcType(self, svcTypeString): + def lookupSvcType(self, svcTypeString: str) -> int: # TODO: might check for miss here as this is a mandatory field! return self.ServiceObjectTypeMap.get(svcTypeString, -1) - - def lookupUserType(self, userTypeString): + + def lookupUserType(self, userTypeString: str) -> int: return self.UserObjectTypeMap.get(userTypeString, -1) def lookupObjIdToUidAndPolicyName(self, objId: int) -> str: @@ -614,18 +620,15 @@ def lookupObjIdToUidAndPolicyName(self, objId: int) -> str: def lookupSvcIdToUidAndPolicyName(self, svcId: int): return str(svcId) # mock - def lookupProtoNameToId(self, protoString): + def lookupProtoNameToId(self, protoString: str | int) -> int | None: if isinstance(protoString, int): # logger = getFwoLogger() # logger.warning(f"found protocol with an id as name: {str(protoString)}") return protoString # already an int, do nothing else: - if protoString == None: - return None - else: - return self.ProtocolMap.get(protoString.lower(), None) + return self.ProtocolMap.get(protoString.lower(), None) - def prepareChangelogObjects(self, nwObjIdsAdded, svcObjIdsAdded, nwObjIdsRemoved, svcObjIdsRemoved): + def prepareChangelogObjects(self, nwObjIdsAdded: list[dict[str, int]], svcObjIdsAdded: list[dict[str, int]], nwObjIdsRemoved: list[dict[str, int]], svcObjIdsRemoved: list[dict[str, int]]) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """ insert into stm_change_type (change_type_id,change_type_name) VALUES (1,'factory settings'); insert into stm_change_type (change_type_id,change_type_name) VALUES (2,'initial import'); @@ -633,8 +636,8 @@ def prepareChangelogObjects(self, nwObjIdsAdded, svcObjIdsAdded, nwObjIdsRemoved """ # TODO: deal with object changes where we need old and new obj id - nwObjs = [] - svcObjs = [] + nwObjs: list[dict[str, Any]] = [] + svcObjs: list[dict[str, Any]] = [] importTime = datetime.datetime.now().isoformat() changeTyp = 3 # standard change_logger = ChangeLogger() @@ -667,7 +670,7 @@ def prepareChangelogObjects(self, nwObjIdsAdded, svcObjIdsAdded, nwObjIdsRemoved return nwObjs, svcObjs - def addChangelogObjects(self, nwObjIdsAdded, svcObjIdsAdded, nwObjIdsRemoved, svcObjIdsRemoved): + def addChangelogObjects(self, nwObjIdsAdded: list[dict[str, int]], svcObjIdsAdded: list[dict[str, int]], nwObjIdsRemoved: list[dict[str, int]], svcObjIdsRemoved: list[dict[str, int]]): logger = getFwoLogger() errors = 0 diff --git a/roles/importer/files/importer/model_controllers/fwconfig_import_rule.py b/roles/importer/files/importer/model_controllers/fwconfig_import_rule.py index 244f3c6019..2658b45d30 100644 --- a/roles/importer/files/importer/model_controllers/fwconfig_import_rule.py +++ b/roles/importer/files/importer/model_controllers/fwconfig_import_rule.py @@ -2,7 +2,7 @@ import traceback from difflib import ndiff import json -from typing import List, Optional +from typing import Generator, Any import fwo_globals import fwo_const @@ -19,6 +19,8 @@ from models.rule_to import RuleTo from models.rule_service import RuleService from models.rule import RuleNormalized +from models.networkobject import NetworkObject +from models.serviceobject import ServiceObject from services.global_state import GlobalState from services.group_flats_mapper import GroupFlatsMapper from services.enums import Services @@ -38,10 +40,10 @@ class RefType(Enum): # this class is used for importing rules and rule refs into the FWO API class FwConfigImportRule(): - _changed_rule_id_map: dict + _changed_rule_id_map: dict[int, int] global_state: GlobalState import_details: ImportStateController - normalized_config: FwConfigNormalized + normalized_config: FwConfigNormalized | None = None uid2id_mapper: Uid2IdMapper group_flats_mapper: GroupFlatsMapper prev_group_flats_mapper: GroupFlatsMapper @@ -61,7 +63,7 @@ def __init__(self): self.prev_group_flats_mapper = service_provider.get_service(Services.PREV_GROUP_FLATS_MAPPER, self.import_details.ImportId) self.rule_order_service = service_provider.get_service(Services.RULE_ORDER_SERVICE, self.import_details.ImportId) - def updateRulebaseDiffs(self, prevConfig: FwConfigNormalized): + def updateRulebaseDiffs(self, prevConfig: FwConfigNormalized) -> list[int]: logger = getFwoLogger(debug_level=self.import_details.DebugLevel) @@ -73,7 +75,7 @@ def updateRulebaseDiffs(self, prevConfig: FwConfigNormalized): ruleUidsInBoth: dict[str, list[str]] = {} previous_rulebase_uids: list[str] = [] current_rulebase_uids: list[str] = [] - new_hit_information = [] + new_hit_information: list[dict[str, Any]] = [] rule_order_diffs: dict[str, dict[str, list[str]]] = self.rule_order_service.update_rule_order_diffs(self.import_details.DebugLevel) @@ -86,13 +88,14 @@ def updateRulebaseDiffs(self, prevConfig: FwConfigNormalized): current_rulebase_uids.append(rulebase.uid) for rulebase_uid in previous_rulebase_uids: - current_rulebase = self.normalized_config.get_rulebase(rulebase_uid) + current_rulebase = self.normalized_config.get_rulebase_or_none(rulebase_uid) if current_rulebase is None: - continue # rulebase has been deleted + logger.info(f"current rulebase has been deleted: {rulebase_uid}") + continue if rulebase_uid in current_rulebase_uids: # deal with policies contained both in this and previous config previous_rulebase = prevConfig.get_rulebase(rulebase_uid) - ruleUidsInBoth.update({ rulebase_uid: list(current_rulebase.rules.keys() & previous_rulebase.rules.keys()) }) # type: ignore + ruleUidsInBoth.update({ rulebase_uid: list(current_rulebase.rules.keys() & previous_rulebase.rules.keys()) }) else: logger.info(f"previous rulebase has been deleted: {current_rulebase.name} (id:{rulebase_uid})") @@ -115,7 +118,7 @@ def updateRulebaseDiffs(self, prevConfig: FwConfigNormalized): newRulebases = self.getRules(rule_order_diffs["new_rule_uids"]) # update rule_metadata before adding rules - num_added_metadata_rules, new_rule_metadata_ids = self.addNewRuleMetadata(newRulebases) + _, _ = self.addNewRuleMetadata(newRulebases) _ = self.update_rule_metadata_last_hit(new_hit_information) # # now update the database with all rule diffs @@ -125,10 +128,10 @@ def updateRulebaseDiffs(self, prevConfig: FwConfigNormalized): num_changed_rules, old_rule_ids, updated_rule_ids = self.create_new_rule_version(changedRuleUids) self.uid2id_mapper.add_rule_mappings(new_rule_ids + updated_rule_ids) - num_new_refs = self.add_new_refs(prevConfig) + _ = self.add_new_refs(prevConfig) num_deleted_rules, removed_rule_ids = self.mark_rules_removed(rule_order_diffs["deleted_rule_uids"]) - num_removed_refs = self.remove_outdated_refs(prevConfig) + _ = self.remove_outdated_refs(prevConfig) _, num_moved_rules, _ = self.verify_rules_moved(changedRuleUids) @@ -159,7 +162,7 @@ def _create_removed_rules_map(self, removed_rule_ids: list[int]): - def _collect_uncaught_moves(self, movedRuleUids, changedRuleUids): + def _collect_uncaught_moves(self, movedRuleUids: dict[str, list[str]], changedRuleUids: dict[str, list[str]]): for rulebaseId in movedRuleUids: for ruleUid in movedRuleUids[rulebaseId]: if ruleUid not in changedRuleUids.get(rulebaseId, []): @@ -167,7 +170,7 @@ def _collect_uncaught_moves(self, movedRuleUids, changedRuleUids): changedRuleUids[rulebaseId] = [] changedRuleUids[rulebaseId].append(ruleUid) - def collect_all_hit_information(self, prev_config: FwConfigNormalized, new_hit_information: list[dict]): + def collect_all_hit_information(self, prev_config: FwConfigNormalized, new_hit_information: list[dict[str, Any]]): """ Consolidated hit information collection for ALL rules that need hit updates. @@ -175,9 +178,9 @@ def collect_all_hit_information(self, prev_config: FwConfigNormalized, new_hit_i prev_config: Previous configuration for comparison new_hit_information: List to append hit update information to """ - processed_rules = set() + processed_rules: set[str] = set() - def add_hit_update(new_hit_information: list[dict], rule: RuleNormalized): + def add_hit_update(new_hit_information: list[dict[str, Any]], rule: RuleNormalized): """Add a hit information update entry for a rule.""" new_hit_information.append({ "where": { "rule_uid": { "_eq": rule.rule_uid } }, @@ -185,8 +188,11 @@ def add_hit_update(new_hit_information: list[dict], rule: RuleNormalized): }) # check all rulebases in current config + if self.normalized_config is None: + raise FwoImporterError("cannot collect hit information: normalized_config is None") + for current_rulebase in self.normalized_config.rulebases: - previous_rulebase = prev_config.get_rulebase(current_rulebase.uid) + previous_rulebase = prev_config.get_rulebase_or_none(current_rulebase.uid) for rule_uid in current_rulebase.rules: current_rule = current_rulebase.rules[rule_uid] @@ -201,7 +207,7 @@ def add_hit_update(new_hit_information: list[dict], rule: RuleNormalized): add_hit_update(new_hit_information, current_rule) processed_rules.add(rule_uid) - def update_rule_metadata_last_hit(self, new_hit_information: list[dict]) -> int: + def update_rule_metadata_last_hit(self, new_hit_information: list[dict[str, Any]]) -> int: """ Updates rule_metadata.rule_last_hit for all rules with hit information changes. This method executes the actual database updates for hit information. @@ -231,26 +237,21 @@ def update_rule_metadata_last_hit(self, new_hit_information: list[dict]) -> int: return changes @staticmethod - def collect_changed_rules(rule_uid, current_rulebase, previous_rulebase, rulebase_id, changed_rule_uids): + def collect_changed_rules(rule_uid: str, current_rulebase: Rulebase, previous_rulebase: Rulebase, rulebase_id: str, changed_rule_uids: dict[str, list[str]]): if current_rulebase.rules[rule_uid] != previous_rulebase.rules[rule_uid]: changed_rule_uids[rulebase_id].append(rule_uid) @staticmethod - def preserve_rule_num_numeric(current_rulebase, previous_rulebase, rule_uid): + def preserve_rule_num_numeric(current_rulebase: Rulebase, previous_rulebase: Rulebase, rule_uid: str): if current_rulebase.rules[rule_uid].rule_num_numeric == 0: current_rulebase.rules[rule_uid].rule_num_numeric = previous_rulebase.rules[rule_uid].rule_num_numeric - - def get_members(self, type, refs) -> list[str]: - if type == type.NETWORK_OBJECT: - return [member.split(fwo_const.user_delimiter)[0] for member in refs.split(fwo_const.list_delimiter) if member] if refs else [] - return refs.split(fwo_const.list_delimiter) if refs else [] - def get_rule_refs(self, rule, is_prev=False) -> dict[RefType, list[str]]: - froms = [] - tos = [] - users = [] + def get_rule_refs(self, rule: RuleNormalized, is_prev: bool = False) -> dict[RefType, list[tuple[str, str | None]] | list[str]]: + froms: list[tuple[str, str | None]] = [] + tos: list[tuple[str, str | None]] = [] + users: list[str] = [] for src_ref in rule.rule_src_refs.split(fwo_const.list_delimiter): user_ref = None if fwo_const.user_delimiter in src_ref: @@ -281,18 +282,26 @@ def get_rule_refs(self, rule, is_prev=False) -> dict[RefType, list[str]]: RefType.USER_RESOLVED: user_resolveds } - def get_ref_objs(self, ref_type, ref_uid, prev_config: FwConfigNormalized): + def get_ref_objs(self, ref_type: RefType, ref_uid: tuple[str, str | None] | str , prev_config: FwConfigNormalized) -> tuple[tuple[None | NetworkObject, None | Any], tuple[None | NetworkObject , None | Any]] | tuple[None | NetworkObject | ServiceObject, None | Any]: #TODO Any is user type but there is no user Type + + if self.normalized_config is None: + raise FwoImporterError("cannot get ref objs: normalized_config is None") + if ref_type == RefType.SRC or ref_type == RefType.DST: nwobj_uid, user_uid = ref_uid + return (prev_config.network_objects.get(nwobj_uid, None), prev_config.users.get(user_uid, None) if user_uid else None), \ (self.normalized_config.network_objects.get(nwobj_uid, None), self.normalized_config.users.get(user_uid, None) if user_uid else None) + + if ref_type == RefType.NWOBJ_RESOLVED: - return prev_config.network_objects.get(ref_uid, None), self.normalized_config.network_objects.get(ref_uid, None) + return prev_config.network_objects.get(ref_uid, None), self.normalized_config.network_objects.get(ref_uid, None) # type: ignore TODO: change ref_uid to str only + if ref_type == RefType.SVC or ref_type == RefType.SVC_RESOLVED: - return prev_config.service_objects.get(ref_uid, None), self.normalized_config.service_objects.get(ref_uid, None) - return prev_config.users.get(ref_uid, None), self.normalized_config.users.get(ref_uid, None) + return prev_config.service_objects.get(ref_uid, None), self.normalized_config.service_objects.get(ref_uid, None) # type: ignore + return prev_config.users.get(ref_uid, None), self.normalized_config.users.get(ref_uid, None) # type: ignore - def get_ref_remove_statement(self, ref_type, rule_uid, ref_uid): + def get_ref_remove_statement(self, ref_type: RefType, rule_uid: str, ref_uid: tuple[str, str | None] | str) -> dict[str, Any]: if ref_type == RefType.SRC or ref_type == RefType.DST: nwobj_uid, user_uid = ref_uid statement = { @@ -310,26 +319,26 @@ def get_ref_remove_statement(self, ref_type, rule_uid, ref_uid): return { "_and": [ {"rule_id": {"_eq": self.uid2id_mapper.get_rule_id(rule_uid, before_update=True)}}, - {"svc_id": {"_eq": self.uid2id_mapper.get_service_object_id(ref_uid, before_update=True)}} + {"svc_id": {"_eq": self.uid2id_mapper.get_service_object_id(ref_uid, before_update=True)}} # type: ignore # ref_uid is str here ] } elif ref_type == RefType.NWOBJ_RESOLVED: return { "_and": [ {"rule_id": {"_eq": self.uid2id_mapper.get_rule_id(rule_uid, before_update=True)}}, - {"obj_id": {"_eq": self.uid2id_mapper.get_network_object_id(ref_uid, before_update=True)}} + {"obj_id": {"_eq": self.uid2id_mapper.get_network_object_id(ref_uid, before_update=True)}} # type: ignore # ref_uid is str here ] } elif ref_type == RefType.USER_RESOLVED: return { "_and": [ {"rule_id": {"_eq": self.uid2id_mapper.get_rule_id(rule_uid, before_update=True)}}, - {"user_id": {"_eq": self.uid2id_mapper.get_user_id(ref_uid, before_update=True)}} + {"user_id": {"_eq": self.uid2id_mapper.get_user_id(ref_uid, before_update=True)}} # type: ignore # ref_uid is str here ] } - def get_outdated_refs_to_remove(self, prev_rule: RuleNormalized, rule: RuleNormalized|None, prev_config, remove_all): + def get_outdated_refs_to_remove(self, prev_rule: RuleNormalized, rule: RuleNormalized | None, prev_config: FwConfigNormalized, remove_all: bool) -> dict[RefType, list[dict[str, Any]]]: """ Get the references that need to be removed for a rule based on comparison with the previous rule. Args: @@ -338,11 +347,15 @@ def get_outdated_refs_to_remove(self, prev_rule: RuleNormalized, rule: RuleNorma prev_config (FwConfigNormalized): The previous configuration containing the rules. remove_all (bool): If True, all references will be removed. If False, it will check for changes in references that need to be removed. """ - ref_uids = { ref_type: [] for ref_type in RefType } + ref_uids: dict[RefType, list[tuple[str, str | None]] | list[str]] = { ref_type: [] for ref_type in RefType } + + if rule is None: + raise FwoImporterError("cannot get outdated refs to remove: rule is None") + if not remove_all: ref_uids = self.get_rule_refs(rule) prev_ref_uids = self.get_rule_refs(prev_rule, is_prev=True) - refs_to_remove = {} + refs_to_remove: dict[RefType, list[dict[str, Any]]] = {} for ref_type in RefType: refs_to_remove[ref_type] = [] for prev_ref_uid in prev_ref_uids[ref_type]: @@ -351,13 +364,19 @@ def get_outdated_refs_to_remove(self, prev_rule: RuleNormalized, rule: RuleNorma if prev_ref_obj == ref_obj: continue # ref not removed or changed # ref removed or changed + if prev_rule.rule_uid is None: + raise FwoImporterError(f"previous reference UID is None: {prev_ref_uid} in rule {prev_rule.rule_uid}") refs_to_remove[ref_type].append(self.get_ref_remove_statement(ref_type, prev_rule.rule_uid, prev_ref_uid)) return refs_to_remove def remove_outdated_refs(self, prev_config: FwConfigNormalized): - all_refs_to_remove = {ref_type: [] for ref_type in RefType} + all_refs_to_remove: dict[RefType, list[dict[str, Any]]] = {ref_type: [] for ref_type in RefType} for prev_rulebase in prev_config.rulebases: - rules = next((rb.rules for rb in self.normalized_config.rulebases if rb.uid == prev_rulebase.uid), {}) + if self.normalized_config is None: + raise FwoImporterError("cannot remove outdated refs: normalized_config is None") + rules = next((rb.rules for rb in self.normalized_config.rulebases if rb.uid == prev_rulebase.uid), None) + if rules is None: + continue for prev_rule in prev_rulebase.rules.values(): uid = prev_rule.rule_uid if uid is None: @@ -372,7 +391,7 @@ def remove_outdated_refs(self, prev_config: FwConfigNormalized): import_mutation = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "rule/updateRuleRefs.graphql"]) - query_variables = { + query_variables: dict[str, Any] = { 'importId': self.import_details.ImportId, 'ruleFroms': all_refs_to_remove[RefType.SRC], 'ruleTos': all_refs_to_remove[RefType.DST], @@ -391,21 +410,22 @@ def remove_outdated_refs(self, prev_config: FwConfigNormalized): else: return sum((import_result['data'][f"update_{ref_type.value}"].get('affected_rows', 0) for ref_type in RefType)) - def get_ref_add_statement(self, ref_type, rule, ref_uid): + def get_ref_add_statement(self, ref_type: RefType, rule: RuleNormalized, ref_uid: tuple[str, str | None] | str) -> dict[str, Any]: + + if rule.rule_uid is None: + raise FwoImporterError(f"rule UID is None: {rule} in rulebase during get_ref_add_statement") # should not happen + if ref_type == RefType.SRC: nwobj_uid, user_uid = ref_uid - obj_id = self.uid2id_mapper.get_network_object_id(nwobj_uid) - if obj_id is None: - self.import_details.Stats.addError(f"Network object {nwobj_uid} not found for rule {rule.rule_uid}") - raise FwoImporterError(f"Network object {nwobj_uid} not found for rule {rule.rule_uid}") + _ = self.uid2id_mapper.get_network_object_id(nwobj_uid) # check if nwobj exists new_ref_dict = RuleFrom( - rule_id=self.uid2id_mapper.get_rule_id(rule.rule_uid), + rule_id=self.uid2id_mapper.get_rule_id(rule.rule_uid), obj_id=self.uid2id_mapper.get_network_object_id(nwobj_uid), user_id=self.uid2id_mapper.get_user_id(user_uid) if user_uid else None, rf_create=self.import_details.ImportId, rf_last_seen=self.import_details.ImportId, #TODO: to be removed in the future negated=rule.rule_src_neg - ).dict() + ).model_dump() return new_ref_dict elif ref_type == RefType.DST: nwobj_uid, user_uid = ref_uid @@ -416,40 +436,40 @@ def get_ref_add_statement(self, ref_type, rule, ref_uid): rt_create=self.import_details.ImportId, rt_last_seen=self.import_details.ImportId, #TODO: to be removed in the future negated=rule.rule_dst_neg - ).dict() + ).model_dump() return new_ref_dict elif ref_type == RefType.SVC: new_ref_dict = RuleService( rule_id=self.uid2id_mapper.get_rule_id(rule.rule_uid), - svc_id=self.uid2id_mapper.get_service_object_id(ref_uid), + svc_id=self.uid2id_mapper.get_service_object_id(ref_uid), # type: ignore # ref_uid is str here TODO: Cleanup ref_uid dict rs_create=self.import_details.ImportId, rs_last_seen=self.import_details.ImportId, #TODO: to be removed in the future - ).dict() + ).model_dump() return new_ref_dict elif ref_type == RefType.NWOBJ_RESOLVED: return { "mgm_id": self.import_details.MgmDetails.CurrentMgmId, "rule_id": self.uid2id_mapper.get_rule_id(rule.rule_uid), - "obj_id": self.uid2id_mapper.get_network_object_id(ref_uid), + "obj_id": self.uid2id_mapper.get_network_object_id(ref_uid), # type: ignore # ref_uid is str here TODO: Cleanup ref_uid dict "created": self.import_details.ImportId, } elif ref_type == RefType.SVC_RESOLVED: return { "mgm_id": self.import_details.MgmDetails.CurrentMgmId, "rule_id": self.uid2id_mapper.get_rule_id(rule.rule_uid), - "svc_id": self.uid2id_mapper.get_service_object_id(ref_uid), + "svc_id": self.uid2id_mapper.get_service_object_id(ref_uid), # type: ignore # ref_uid is str here TODO: Cleanup ref_uid dict "created": self.import_details.ImportId, } elif ref_type == RefType.USER_RESOLVED: return { "mgm_id": self.import_details.MgmDetails.CurrentMgmId, "rule_id": self.uid2id_mapper.get_rule_id(rule.rule_uid), - "user_id": self.uid2id_mapper.get_user_id(ref_uid), + "user_id": self.uid2id_mapper.get_user_id(ref_uid), # type: ignore # ref_uid is str here TODO: Cleanup ref_uid dict "created": self.import_details.ImportId, } - def get_new_refs_to_add(self, rule, prev_rule, prev_config, add_all): + def get_new_refs_to_add(self, rule: RuleNormalized, prev_rule: RuleNormalized | None, prev_config: FwConfigNormalized, add_all: bool) -> dict[RefType, list[dict[str, Any]]]: """ Get the references that need to be added for a rule based on comparison with the previous rule. Args: @@ -458,11 +478,11 @@ def get_new_refs_to_add(self, rule, prev_rule, prev_config, add_all): prev_config (FwConfigNormalized): The previous configuration containing the rules. add_all (bool): If True, all references will be added. If False, it will check for changes in references that need to be added. """ - prev_ref_uids = { ref_type: [] for ref_type in RefType } - if not add_all: + prev_ref_uids: dict[RefType, list[tuple[str, str | None]] | list[str]] = { ref_type: [] for ref_type in RefType } + if not add_all and prev_rule is not None: prev_ref_uids = self.get_rule_refs(prev_rule, is_prev=True) ref_uids = self.get_rule_refs(rule) - refs_to_add = {} + refs_to_add: dict[RefType, list[dict[str, Any]]] = {} for ref_type in RefType: refs_to_add[ref_type] = [] for ref_uid in ref_uids[ref_type]: @@ -475,9 +495,13 @@ def get_new_refs_to_add(self, rule, prev_rule, prev_config, add_all): return refs_to_add def add_new_refs(self, prev_config: FwConfigNormalized): - all_refs_to_add = {ref_type: [] for ref_type in RefType} + all_refs_to_add: dict[RefType, list[dict[str, Any]]] = {ref_type: [] for ref_type in RefType} + if self.normalized_config is None: + raise FwoImporterError("cannot add new refs: normalized_config is None") for rulebase in self.normalized_config.rulebases: - prev_rules = next((rb.rules for rb in prev_config.rulebases if rb.uid == rulebase.uid), {}) + prev_rules = next((rb.rules for rb in prev_config.rulebases if rb.uid == rulebase.uid), None) + if prev_rules is None: + continue for rule in rulebase.rules.values(): uid = rule.rule_uid if uid is None: @@ -510,9 +534,8 @@ def add_new_refs(self, prev_config: FwConfigNormalized): return sum((import_result['data'][f"insert_{ref_type.value}"].get('affected_rows', 0) for ref_type in RefType)) - def getRulesByIdWithRefUids(self, ruleIds: list[int]) -> tuple[int, int, list[Rule]]: + def getRulesByIdWithRefUids(self, ruleIds: list[int]) -> tuple[int, int, list[dict[str, Any]]]: #TODO: change return type to list[Rule] and cast logger = getFwoLogger() - rulesToBeReferenced = [] getRuleUidRefsQuery = FwoApi.get_graphql_code([fwo_const.graphql_query_path + "rule/getRulesByIdWithRefUids.graphql"]) query_variables = { 'ruleIds': ruleIds } @@ -520,7 +543,7 @@ def getRulesByIdWithRefUids(self, ruleIds: list[int]) -> tuple[int, int, list[Ru import_result = self.import_details.api_call.call(getRuleUidRefsQuery, query_variables=query_variables) if 'errors' in import_result: logger.exception(f"fwconfig_import_rule:getRulesByIdWithRefUids - error in addNewRules: {str(import_result['errors'])}") - return 1, 0, rulesToBeReferenced + return 1, 0, [] else: return 0, 0, import_result['data']['rule'] except Exception: @@ -528,9 +551,13 @@ def getRulesByIdWithRefUids(self, ruleIds: list[int]) -> tuple[int, int, list[Ru raise - def getRules(self, ruleUids) -> list[Rulebase]: + def getRules(self, ruleUids: dict[str, list[str]]) -> list[Rulebase]: #TODO: seems unnecessary, as the rulebases should already have been created this way in the normalized config - rulebases = [] + rulebases: list[Rulebase] = [] + + if self.normalized_config is None: + raise FwoImporterError("cannot get rules: normalized_config is None") + for rb in self.normalized_config.rulebases: if rb.uid in ruleUids: filtered_rules = {uid: rule for uid, rule in rb.rules.items() if uid in ruleUids[rb.uid]} @@ -548,13 +575,13 @@ def getRules(self, ruleUids) -> list[Rulebase]: # assuming input of form: # {'rule-uid1': {'rule_num': 17', ... }, 'rule-uid2': {'rule_num': 8, ...}, ... } @staticmethod - def ruleDictToOrderedListOfRuleUids(rules): + def ruleDictToOrderedListOfRuleUids(rules: dict[str, dict[str, Any]]) -> list[str]: return sorted(rules, key=lambda x: rules[x]['rule_num']) @staticmethod - def listDiff(oldRules, newRules): + def listDiff(oldRules: list[str], newRules: list[str]) -> list[tuple[str, str]]: diff = list(ndiff(oldRules, newRules)) - changes = [] + changes: list[tuple[str, str]] = [] for change in diff: if change.startswith("- "): @@ -566,7 +593,7 @@ def listDiff(oldRules, newRules): return changes - def _find_following_rules(self, ruleUid, previousRulebase, rulebaseId): + def _find_following_rules(self, ruleUid: str, previousRulebase: dict[str, int], rulebaseId: str) -> Generator[str, None, None]: """ Helper method to find the next rule in self that has an existing rule number. @@ -575,9 +602,9 @@ def _find_following_rules(self, ruleUid, previousRulebase, rulebaseId): :return: Generator yielding rule IDs that appear after `current_rule_id` in self.new_rules. """ found = False + if self.normalized_config is None: + raise FwoImporterError("cannot find following rules: normalized_config is None") current_rulebase = self.normalized_config.get_rulebase(rulebaseId) - if current_rulebase is None: - raise FwoImporterError(f"rulebase with id {rulebaseId} not found in current config") for currentUid in current_rulebase.rules: if currentUid == ruleUid: found = True @@ -586,12 +613,12 @@ def _find_following_rules(self, ruleUid, previousRulebase, rulebaseId): # adds new rule_metadatum to the database - def addNewRuleMetadata(self, newRules: list[Rulebase]): + def addNewRuleMetadata(self, newRules: list[Rulebase]) -> tuple[int, list[int]]: logger = getFwoLogger() - changes = 0 - newRuleMetaDataIds = [] - newRuleIds = [] - + changes: int = 0 + newRuleMetaDataIds: list[int] = [] + newRuleIds: list[int] = [] + addNewRuleMetadataMutation = """mutation upsertRuleMetadata($ruleMetadata: [rule_metadata_insert_input!]!) { insert_rule_metadata(objects: $ruleMetadata, on_conflict: {constraint: rule_metadata_rule_uid_unique, update_columns: [rule_last_modified]}) { affected_rows @@ -602,7 +629,7 @@ def addNewRuleMetadata(self, newRules: list[Rulebase]): } """ - addNewRuleMetadata: list[dict] = self.PrepareNewRuleMetadata(newRules) + addNewRuleMetadata: list[dict[str, Any]] = self.PrepareNewRuleMetadata(newRules) query_variables = { 'ruleMetadata': addNewRuleMetadata } if fwo_globals.debug_level>9: @@ -626,8 +653,8 @@ def addNewRuleMetadata(self, newRules: list[Rulebase]): def add_rulebases_without_rules(self, newRules: list[Rulebase]): logger = getFwoLogger() - changes = 0 - newRulebaseIds = [] + changes: int = 0 + newRulebaseIds: list[int] = [] addRulebasesWithoutRulesMutation = """mutation upsertRulebaseWithoutRules($rulebases: [rulebase_insert_input!]!) { insert_rulebase( @@ -671,20 +698,20 @@ def add_rulebases_without_rules(self, newRules: list[Rulebase]): # as we cannot add the rules for all rulebases in one go (using a constraint from the rule table), # we need to add them per rulebase separately #TODO: separation because of constraint still needed? - def add_rules_within_rulebases(self, rulebases: List[Rulebase]) -> tuple[int, list[dict]]: + def add_rules_within_rulebases(self, rulebases: list[Rulebase]) -> tuple[int, list[dict[str, Any]]]: """ Adds rules within the given rulebases to the database. Args: - rulebases (List[Rulebase]): List of Rulebase objects containing rules to be added + rulebases (list[Rulebase]): List of Rulebase objects containing rules to be added Returns: tuple[int, list[dict]]: A tuple containing the number of changes made and a list of dictionaries, each with 'rule_id' and 'rule_uid' for each newly added rule. """ logger = getFwoLogger() - changes = 0 - newRuleIds = [] + changes: int = 0 + newRuleIds: list[dict[str, Any]] = [] upsertRulebaseWithRules = """mutation upsertRules($rules: [rule_insert_input!]!) { insert_rule( @@ -712,16 +739,16 @@ def add_rules_within_rulebases(self, rulebases: List[Rulebase]) -> tuple[int, li # adds only new rules to the database # unchanged or deleted rules are not touched here - def add_new_rules(self, rulebases: list[Rulebase]) -> tuple[int, list[dict]]: + def add_new_rules(self, rulebases: list[Rulebase]) -> tuple[int, list[dict[str, Any]]]: #TODO: currently brute-forcing all rulebases and rules and depending on constraints to avoid duplicates. seems inefficient. - changes1, newRulebaseIds = self.add_rulebases_without_rules(rulebases) + changes1, _ = self.add_rulebases_without_rules(rulebases) changes2, newRuleIds = self.add_rules_within_rulebases(rulebases) return changes1 + changes2, newRuleIds - def PrepareNewRuleMetadata(self, newRules: list[Rulebase]) -> list[dict]: - newRuleMetadata: list[dict] = [] + def PrepareNewRuleMetadata(self, newRules: list[Rulebase]) -> list[dict[str, Any]]: + newRuleMetadata: list[dict[str, Any]] = [] now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") for rulebase in newRules: @@ -754,9 +781,8 @@ def PrepareNewRulebases(self, newRulebases: list[Rulebase]) -> list[RulebaseForI return newRulesForImport def mark_rules_removed(self, removedRuleUids: dict[str, list[str]]) -> tuple[int, list[int]]: - logger = getFwoLogger() changes = 0 - collectedRemovedRuleIds = [] + collectedRemovedRuleIds: list[int] = [] # TODO: make sure not to mark new (changed) rules as removed (order of calls!) @@ -771,7 +797,7 @@ def mark_rules_removed(self, removedRuleUids: dict[str, list[str]]) -> tuple[int } } """ - query_variables = { 'importId': self.import_details.ImportId, + query_variables: dict[str, Any] = { 'importId': self.import_details.ImportId, 'mgmId': self.import_details.MgmDetails.CurrentMgmId, 'uids': list(removedRuleUids[rbName]) } @@ -789,7 +815,7 @@ def mark_rules_removed(self, removedRuleUids: dict[str, list[str]]) -> tuple[int return changes, collectedRemovedRuleIds - def create_new_rule_version(self, rule_uids: dict[str, list[str]]) -> tuple[int, list[int], list[dict]]: + def create_new_rule_version(self, rule_uids: dict[str, list[str]]) -> tuple[int, list[int], list[dict[str, Any]]]: """ Creates new versions of rules specified in rule_uids by inserting new rule entries and marking the old ones as removed. @@ -799,7 +825,6 @@ def create_new_rule_version(self, rule_uids: dict[str, list[str]]) -> tuple[int, Returns: tuple[int, list[int], list[dict]]: A tuple containing the number of changes made, a list of old rule IDs that were changed, and a list of newly inserted rule entries. """ - logger = getFwoLogger() self._changed_rule_id_map = {} if len(rule_uids) == 0: @@ -859,7 +884,7 @@ def create_new_rule_version(self, rule_uids: dict[str, list[str]]) -> tuple[int, import_rules.extend(import_rules_of_rulebase) - create_new_rule_version_variables = { + create_new_rule_version_variables: dict[str, Any] = { "objects": [rule.model_dump() for rule in import_rules], "uids": [rule.rule_uid for rule in import_rules], "mgmId": self.import_details.MgmDetails.CurrentMgmId, @@ -874,8 +899,8 @@ def create_new_rule_version(self, rule_uids: dict[str, list[str]]) -> tuple[int, raise FwoApiWriteError(f"failed to create new rule versions: {str(create_new_rule_version_result['errors'])}") else: changes = int(create_new_rule_version_result['data']['update_rule']['affected_rows']) - update_rules_return = create_new_rule_version_result['data']['update_rule']['returning'] - insert_rules_return = create_new_rule_version_result['data']['insert_rule']['returning'] + update_rules_return: list[dict[str, Any]] = create_new_rule_version_result['data']['update_rule']['returning'] + insert_rules_return: list[dict[str, Any]] = create_new_rule_version_result['data']['insert_rule']['returning'] self._changed_rule_id_map = { update_item['rule_id']: next( @@ -887,20 +912,20 @@ def create_new_rule_version(self, rule_uids: dict[str, list[str]]) -> tuple[int, } - collected_changed_rule_ids = list(self._changed_rule_id_map.keys()) or [] + collected_changed_rule_ids: list[int] = list(self._changed_rule_id_map.keys()) or [] return changes, collected_changed_rule_ids, insert_rules_return - def update_rule_enforced_on_gateway_after_move(self, insert_rules_return, update_rules_return): + def update_rule_enforced_on_gateway_after_move(self, insert_rules_return: list[dict[str, Any]], update_rules_return: list[dict[str, Any]]) -> tuple[int, int, list[str]]: """ Updates the db table rule_enforced_on_gateway by creating new entries for a list of rule_ids and setting the old versions of said rules removed. """ logger = getFwoLogger() - id_map = {} + id_map: dict[int, int] = {} for insert_rules_return_entry in insert_rules_return: id_map[ @@ -930,7 +955,7 @@ def update_rule_enforced_on_gateway_after_move(self, insert_rules_return, update } """ - set_rule_enforced_on_gateway_entries_removed_variables = { + set_rule_enforced_on_gateway_entries_removed_variables: dict[str, Any] = { "rule_ids": list(id_map.values()), "importId": self.import_details.ImportId, } @@ -951,8 +976,8 @@ def update_rule_enforced_on_gateway_after_move(self, insert_rules_return, update if 'errors' in set_rule_enforced_on_gateway_entries_removed_result: logger.exception(f"fwo_api:update_rule_enforced_on_gateway_after_move - error while updating moved rules refs: {str(set_rule_enforced_on_gateway_entries_removed_result['errors'])}") return 1, 0, [] - - insert_rule_enforced_on_gateway_entries_variables = { + + insert_rule_enforced_on_gateway_entries_variables: dict[str, Any] = { "new_entries": [ { "rule_id": new_id, @@ -980,7 +1005,7 @@ def verify_rules_moved(self, changed_rule_uids: dict[str, list[str]]) -> tuple[i error_count_move = 0 number_of_moved_rules = 0 - moved_rule_uids = [] + moved_rule_uids: list[str] = [] changed_rule_uids_flat = [ uid @@ -1006,7 +1031,7 @@ def verify_rules_moved(self, changed_rule_uids: dict[str, list[str]]) -> tuple[i # TODO: limit query to a single rulebase - def GetRuleNumMap(self): + def GetRuleNumMap(self) -> dict[str, dict[str, float]]: query = "query getRuleNumMap($mgmId: Int) { rule(where:{mgm_id:{_eq:$mgmId}}) { rule_uid rulebase_id rule_num_numeric } }" try: result = self.import_details.api_call.call(query=query, query_variables={"mgmId": self.import_details.MgmDetails.CurrentMgmId}) @@ -1014,29 +1039,29 @@ def GetRuleNumMap(self): logger = getFwoLogger() logger.error(f'Error while getting rule number map') return {} - - map = {} + + map: dict[str, dict[str, float]] = {} for ruleNum in result['data']['rule']: if ruleNum['rulebase_id'] not in map: map.update({ ruleNum['rulebase_id']: {} }) # initialize rulebase map[ruleNum['rulebase_id']].update({ ruleNum['rule_uid']: ruleNum['rule_num_numeric']}) return map - def GetNextRuleNumMap(self): # TODO: implement! + def GetNextRuleNumMap(self) -> dict[str, float]: #TODO: implement! query = "query getRuleNumMap { rule { rule_uid rule_num_numeric } }" try: - result = self.import_details.api_call.call(query=query, query_variables={}) + _ = self.import_details.api_call.call(query=query, query_variables={}) except Exception: logger = getFwoLogger() logger.error(f'Error while getting rule number') return {} - - map = {} + + map: dict[str, float] = {} # for ruleNum in result['data']['rule']: # map.update({ruleNum['rule_uid']: ruleNum['rule_num_numeric']}) return map - def GetRuleTypeMap(self): + def GetRuleTypeMap(self) -> dict[str, int]: query = "query getTrackMap { stm_track { track_name track_id } }" try: result = self.import_details.api_call.call(query=query, query_variables={}) @@ -1045,13 +1070,13 @@ def GetRuleTypeMap(self): logger.error(f'Error while getting stm_track') return {} - map = {} + map: dict[str, int] = {} for track in result['data']['stm_track']: map.update({track['track_name']: track['track_id']}) return map - def getCurrentRules(self, importId, mgmId, rulebaseName): - query_variables = { + def getCurrentRules(self, importId: int, mgmId: int, rulebaseName: str) -> list[list[Any]] | None: + query_variables: dict[str, Any] = { "importId": importId, "mgmId": mgmId, "rulebaseName": rulebaseName @@ -1084,15 +1109,15 @@ def getCurrentRules(self, importId, mgmId, rulebaseName): logger.error(f'could not find rules in query result: {queryResult}') self.import_details.increaseErrorCounterByOne() return - - rules = [] + + rules: list[list[Any]] = [] for rule in ruleList: - rules.append([rule['rule']['rule_num'], rule['rule']['rule_num_numeric'], rule['rule']['rule_uid']]) + rules.append([rule['rule']['rule_num'], rule['rule']['rule_num_numeric'], rule['rule']['rule_uid']]) # TODO: change to tuple? return rules - - def insertRulebase(self, ruleBaseName, isGlobal=False): + + def insertRulebase(self, ruleBaseName: str, isGlobal: bool = False): # call for each rulebase to add - query_variables = { + query_variables: dict[str, Any] = { "rulebase": { "is_global": isGlobal, "mgm_id": self.import_details.MgmDetails.CurrentMgmId, @@ -1122,8 +1147,8 @@ def insertRulebase(self, ruleBaseName, isGlobal=False): return self.import_details.api_call.call(mutation, query_variables=query_variables) - def importInsertRulebaseOnGateway(self, rulebaseId, devId, orderNo=0): - query_variables = { + def importInsertRulebaseOnGateway(self, rulebaseId: int, devId: int, orderNo: int = 0): + query_variables: dict[str, Any] = { "rulebase2gateway": [ { "dev_id": devId, @@ -1141,10 +1166,10 @@ def importInsertRulebaseOnGateway(self, rulebaseId, devId, orderNo=0): return self.import_details.api_call.call(mutation, query_variables=query_variables) - def _get_list_of_enforced_gateways(self, rule: RuleNormalized, importDetails: ImportStateController) -> Optional[List[int]]: + def _get_list_of_enforced_gateways(self, rule: RuleNormalized, importDetails: ImportStateController) -> list[int] | None: if rule.rule_installon is None: return None - enforced_gw_ids = [] + enforced_gw_ids: list[int] = [] for gwUid in rule.rule_installon.split(fwo_const.list_delimiter): gwId = importDetails.lookupGatewayId(gwUid) if gwId is None: @@ -1157,11 +1182,9 @@ def _get_list_of_enforced_gateways(self, rule: RuleNormalized, importDetails: Im return enforced_gw_ids - def prepare_rules_for_import(self, rules: list[RuleNormalized], rulebase_uid: str) -> List[Rule]: + def prepare_rules_for_import(self, rules: list[RuleNormalized], rulebase_uid: str) -> list[Rule]: # get rulebase_id for rulebaseUid rulebase_id = self.import_details.lookupRulebaseId(rulebase_uid) - if rulebase_id is None: - raise FwoApiWriteError(f"could not find rulebase id for rulebase uid {rulebase_uid} during rule import preparation") prepared_rules = [ self.prepare_single_rule_for_import(rule, self.import_details, rulebase_id) @@ -1226,7 +1249,7 @@ def prepare_single_rule_for_import(self, rule: RuleNormalized, importDetails: Im return rule_for_import - def write_changelog_rules(self, added_rules_ids, removed_rules_ids): + def write_changelog_rules(self, added_rules_ids: list[int], removed_rules_ids: list[int]) -> int: logger = getFwoLogger() errors = 0 @@ -1251,13 +1274,13 @@ def write_changelog_rules(self, added_rules_ids, removed_rules_ids): return errors - def prepare_changelog_rules_insert_objects(self, added_rules_ids, removed_rules_ids): + def prepare_changelog_rules_insert_objects(self, added_rules_ids: list[int], removed_rules_ids: list[int]) -> list[dict[str, Any]]: """ Creates two lists of insert arguments for the changelog_rules db table, one for new rules, one for deleted. """ change_logger = ChangeLogger() - changelog_rule_insert_objects = [] + changelog_rule_insert_objects: list[dict[str, Any]] = [] importTime = datetime.now().isoformat() changeTyp = 3 diff --git a/roles/importer/files/importer/model_controllers/fwconfig_import_ruleorder.py b/roles/importer/files/importer/model_controllers/fwconfig_import_ruleorder.py index dc52620fa0..9e7e1842a8 100644 --- a/roles/importer/files/importer/model_controllers/fwconfig_import_ruleorder.py +++ b/roles/importer/files/importer/model_controllers/fwconfig_import_ruleorder.py @@ -1,13 +1,19 @@ +from typing import Any from fwo_const import rule_num_numeric_steps -from models.fwconfig_normalized import FwConfigNormalized + from models.rule import RuleNormalized +from models.rulebase import Rulebase +from services.global_state import GlobalState from fwo_exceptions import FwoApiFailure from fwo_log import getFwoLogger -from services.global_state import GlobalState from services.service_provider import ServiceProvider from services.enums import Services +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from models.fwconfig_normalized import FwConfigNormalized + class RuleOrderService: """ A singleton service that holds data and provides logic to compute rule order values. @@ -15,15 +21,15 @@ class RuleOrderService: _service_provider: ServiceProvider _global_state: GlobalState - _normalized_config: FwConfigNormalized | None - _previous_config: FwConfigNormalized | None + _normalized_config: 'FwConfigNormalized | None' + _previous_config: 'FwConfigNormalized | None' _target_rule_uids: list[str] _target_rules_flat: list[RuleNormalized] _source_rule_uids: list[str] _source_rules_flat: list[RuleNormalized] - _min_moves: dict[str, list] + _min_moves: dict[str, Any] _deleted_rule_uids: dict[str, list[str]] _new_rule_uids: dict[str, list[str]] @@ -186,10 +192,12 @@ def _set_initial_rule_num_numerics(self): changed_rule.rule_num_numeric = current_rule_num_numeric - def _update_rule_on_move_or_insert(self, rule_uid, target_rulebase_uid): + def _update_rule_on_move_or_insert(self, rule_uid: str, target_rulebase_uid: str) -> None: next_rules_rule_num_numeric = 0.0 previous_rule_num_numeric = 0.0 + if self._normalized_config is None or self._previous_config is None: + raise ValueError("Config objects in global state not correctly initialized — expected non-None values.") target_rulebase = next((rulebase for rulebase in self._normalized_config.rulebases if rulebase.uid == target_rulebase_uid), None) unchanged_target_rulebase = next((rulebase for rulebase in self._previous_config.rulebases if rulebase.uid == target_rulebase_uid), None) @@ -235,6 +243,8 @@ def _update_rule_on_consecutive_insert(self, rule_uid: str, rulebase_uid: str) - prev_rule_num_numeric = 0 next_rule_num_numeric = 0 + if self._normalized_config is None: + raise ValueError("Config objects in global state not correctly initialized — expected non-None values.") target_rulebase = next(rulebase for rulebase in self._normalized_config.rulebases if rulebase.uid == rulebase_uid) while prev_rule_num_numeric == 0: @@ -266,14 +276,18 @@ def _update_rule_on_consecutive_insert(self, rule_uid: str, rulebase_uid: str) - rule.rule_num_numeric = (prev_rule_num_numeric + next_rule_num_numeric) / 2 - def _parse_rule_uids_and_objects_from_config(self, config: FwConfigNormalized): + def _parse_rule_uids_and_objects_from_config(self, config: 'FwConfigNormalized') -> tuple[list[str], list[RuleNormalized]]: uids_and_rules = [ (rule_uid, rule) for rulebase in config.rulebases for rule_uid, rule in rulebase.rules.items() ] - return map(list, zip(*uids_and_rules)) if uids_and_rules else ([], []) + if not uids_and_rules: + return ([], []) + + uids, rules = zip(*uids_and_rules) + return (list(uids), list(rules)) @@ -309,7 +323,7 @@ def _is_part_of_consecutive_insert(self, rule_uid: str): return True - def _get_adjacent_list_element(self, lst, index): + def _get_adjacent_list_element(self, lst:list[str], index: int) -> tuple[str | None, str | None]: if not lst or index < 0 or index >= len(lst): return None, None @@ -317,8 +331,7 @@ def _get_adjacent_list_element(self, lst, index): next_item = lst[index + 1] if index + 1 < len(lst) else None return prev_item, next_item - - def _get_index_and_rule_object_from_flat_list(self, flat_list, rule_uid): + def _get_index_and_rule_object_from_flat_list(self, flat_list: list[RuleNormalized], rule_uid: str): return next( (i, rule) for i, rule in enumerate(flat_list) if rule.rule_uid == rule_uid ) @@ -326,10 +339,10 @@ def _get_index_and_rule_object_from_flat_list(self, flat_list, rule_uid): def _get_relevant_rule_num_numeric( self, - rule_uid, - flat_list, + rule_uid: str, + flat_list: list[RuleNormalized] | None, #TODO flat_list should not be needed here ascending: bool, - target_rulebase + target_rulebase: Rulebase ) -> float: """ Returns the relevant rule_num_numeric for rule_uid. @@ -359,11 +372,11 @@ def _get_relevant_rule_num_numeric( return float(rule.rule_num_numeric) - def _compute_num_for_changed_rule(self, rule_uid, ascending: bool, target_rulebase) -> float: + def _compute_num_for_changed_rule(self, rule_uid: str, ascending: bool, target_rulebase: Rulebase) -> float: """Calculates rule_num_numeric for a new/moved rule relative to its neighbors in the target.""" # Get rule & neighbors in the target index, changed_rule = self._get_index_and_rule_object_from_flat_list( - target_rulebase.rules.values(), rule_uid + list(target_rulebase.rules.values()), rule_uid ) prev_uid, next_uid = self._get_adjacent_list_element(list(target_rulebase.rules.keys()), index) @@ -373,7 +386,7 @@ def _compute_num_for_changed_rule(self, rule_uid, ascending: bool, target_ruleba return self._num_for_descending_case(changed_rule, prev_uid, target_rulebase) - def _num_for_ascending_case(self, changed_rule, next_uid, target_rulebase) -> float: + def _num_for_ascending_case(self, changed_rule: RuleNormalized, next_uid: str | None, target_rulebase: Rulebase) -> float: """ Ascending: - If a next neighbor exists, recursively use its relevant value @@ -394,7 +407,7 @@ def _num_for_ascending_case(self, changed_rule, next_uid, target_rulebase) -> fl return 0.0 - def _num_for_descending_case(self, changed_rule, prev_uid, target_rulebase) -> float: + def _num_for_descending_case(self, changed_rule: RuleNormalized, prev_uid: str | None, target_rulebase: Rulebase) -> float: """ Descending: - If a previous neighbor exists, recursively use its relevant value @@ -418,7 +431,7 @@ def _num_for_descending_case(self, changed_rule, prev_uid, target_rulebase) -> f return 0 - def _max_num_numeric_rule(self, target_rulebase): + def _max_num_numeric_rule(self, target_rulebase: Rulebase): """Return the rule with the maximum rule_num_numeric, or None if empty.""" return max( (r for r in target_rulebase.rules.values()), @@ -427,7 +440,7 @@ def _max_num_numeric_rule(self, target_rulebase): ) - def _min_nonzero_num_numeric_rule(self, target_rulebase): + def _min_nonzero_num_numeric_rule(self, target_rulebase: Rulebase): """Return the rule with the minimum non-zero rule_num_numeric, or None if none exist.""" return min( (r for r in target_rulebase.rules.values() if getattr(r, "rule_num_numeric", 0) != 0), @@ -436,7 +449,7 @@ def _min_nonzero_num_numeric_rule(self, target_rulebase): ) - def _is_rule_uid_in_return_object(self, rule_uid, return_object): + def _is_rule_uid_in_return_object(self, rule_uid: str, return_object: Any) -> bool: for rule_uids in return_object.values(): for _rule_uid in rule_uids: if rule_uid == _rule_uid: diff --git a/roles/importer/files/importer/model_controllers/fwconfig_normalized_controller.py b/roles/importer/files/importer/model_controllers/fwconfig_normalized_controller.py index d67d42c6fe..ee6579575a 100644 --- a/roles/importer/files/importer/model_controllers/fwconfig_normalized_controller.py +++ b/roles/importer/files/importer/model_controllers/fwconfig_normalized_controller.py @@ -1,6 +1,5 @@ +from typing import Any from fwo_log import getFwoLogger -from model_controllers.import_state_controller import ImportStateController -from models.gateway import Gateway from fwo_base import ConfFormat from models.fwconfig_normalized import FwConfigNormalized @@ -14,9 +13,9 @@ def __init__(self, ConfigFormat: ConfFormat, fwConfig: FwConfigNormalized): self.NormalizedConfig = fwConfig @staticmethod - def convertListToDict(listIn: list, idField: str) -> dict: + def convertListToDict(listIn: list[Any], idField: str) -> dict[Any, Any]: logger = getFwoLogger() - result = {} + result: dict[Any, Any] = {} for item in listIn: if idField in item: key = item[idField] @@ -26,28 +25,28 @@ def convertListToDict(listIn: list, idField: str) -> dict: return result # { listIn[idField]: listIn for listIn in listIn } def __str__(self): - return f"{self.action}({str(self.network_objects)})" + return f"{self.action}({str(self.network_objects)})" # TODO self.action not defined? # type: ignore @staticmethod - def deleteControlIdFromDictList(dictListInOut: dict): + def deleteControlIdFromDictList(dictListInOut: dict[Any, Any] | list[Any]) -> dict[Any, Any] | list[Any]: if isinstance(dictListInOut, list): - deleteListDictElements(dictListInOut, ['control_id']) + deleteListDictElements(dictListInOut, ['control_id']) # TODO deleteListDictElements not defined elif isinstance(dictListInOut, dict): - deleteDictElements(dictListInOut, ['control_id']) + deleteDictElements(dictListInOut, ['control_id']) # TODO deleteListDictElements not defined return dictListInOut def split(self): return [self] # for now not implemented @staticmethod - def join(configList): - resultingConfig = FwConfigNormalized() + def join(configList: list[FwConfigNormalized]): + resultingConfig = FwConfigNormalized() for conf in configList: - resultingConfig.addElements(conf) + resultingConfig.addElements(conf) # TODO addElements not defined return resultingConfig - def addElements(self, config): - self.network_objects += config.Networks + def addElements(self, config: FwConfigNormalized): + self.network_objects += config.Networks # TODO: all members are not defined self.service_objects += config.Services self.users += config.Users self.zone_objects += config.Zones diff --git a/roles/importer/files/importer/model_controllers/fwconfigmanager_controller.py b/roles/importer/files/importer/model_controllers/fwconfigmanager_controller.py index 14c8737437..65bf40fae2 100644 --- a/roles/importer/files/importer/model_controllers/fwconfigmanager_controller.py +++ b/roles/importer/files/importer/model_controllers/fwconfigmanager_controller.py @@ -1,3 +1,4 @@ +from typing import Any from models.fwconfig_normalized import FwConfigNormalized from models.fwconfigmanager import FwConfigManager @@ -14,13 +15,13 @@ class FwConfigManagerController(FwConfigManager): } @classmethod - def fromJson(cls, jsonDict): - ManagerUid = jsonDict['manager_uid'] - ManagerName = jsonDict['mgm_name'] - IsGlobal = jsonDict['is_global'] - DependantManagerUids = jsonDict['dependant_manager_uids'] - Configs = jsonDict['configs'] - return cls(ManagerUid, ManagerName, IsGlobal, DependantManagerUids, Configs) + def fromJson(cls, jsonDict: dict[str, Any]) -> 'FwConfigManagerController': + ManagerUid: str = jsonDict['manager_uid'] + ManagerName: str = jsonDict['mgm_name'] + IsGlobal: bool = jsonDict['is_global'] + DependantManagerUids: list[str] = jsonDict['dependant_manager_uids'] + Configs: list[FwConfigNormalized] = jsonDict['configs'] + return cls(ManagerUid, ManagerName, IsGlobal, DependantManagerUids, Configs)#type: ignore # TODO: this class does not have a Constructor! def __str__(self): return f"{self.ManagerUid}({str(self.Configs)})" diff --git a/roles/importer/files/importer/model_controllers/fwconfigmanagerlist_controller.py b/roles/importer/files/importer/model_controllers/fwconfigmanagerlist_controller.py index fff2ae1144..0233bcbb3b 100644 --- a/roles/importer/files/importer/model_controllers/fwconfigmanagerlist_controller.py +++ b/roles/importer/files/importer/model_controllers/fwconfigmanagerlist_controller.py @@ -1,24 +1,18 @@ import json -import jsonpickle import time import traceback from copy import deepcopy +from typing import Any import fwo_globals from fwo_log import getFwoLogger -from fwo_exceptions import FwoImporterError -from model_controllers.interface_controller import InterfaceSerializable -from model_controllers.route_controller import RouteSerializable -from fwo_base import split_list, serializeDictToClassRecursively, deserializeClassToDictRecursively -from fwo_const import max_objs_per_chunk, import_tmp_path +from fwo_base import serializeDictToClassRecursively, deserializeClassToDictRecursively +from fwo_const import import_tmp_path from model_controllers.import_state_controller import ImportStateController -from model_controllers.management_controller import Management -from models.fwconfig_normalized import FwConfig, FwConfigNormalized from models.fwconfigmanagerlist import FwConfigManagerList from models.fwconfigmanager import FwConfigManager from model_controllers.fwconfig_controller import FwoEncoder -from model_controllers.management_controller import ManagementController from fwo_base import ConfFormat """ @@ -32,7 +26,7 @@ def __str__(self): def toJson(self): return deserializeClassToDictRecursively(self) - def toJsonString(self, prettyPrint=False): + def toJsonString(self, prettyPrint: bool=False): jsonDict = self.toJson() if prettyPrint: return json.dumps(jsonDict, indent=2, cls=FwoEncoder) @@ -44,7 +38,7 @@ def mergeConfigs(self, conf2: 'FwConfigManagerListController'): self.ManagerSet.extend(conf2.ManagerSet) @staticmethod - def generate_empty_config(is_super_manager=False): + def generate_empty_config(is_super_manager: bool=False): """ Generates an empty FwConfigManagerListController with a single empty FwConfigManager. """ @@ -67,7 +61,7 @@ def toJsonLegacy(self): return deserializeClassToDictRecursively(self) # to be re-written: - def toJsonStringLegacy(self, prettyPrint=False): + def toJsonStringLegacy(self, prettyPrint: bool=False): jsonDict = self.toJson() if prettyPrint: return json.dumps(jsonDict, indent=2, cls=FwoEncoder) @@ -75,11 +69,11 @@ def toJsonStringLegacy(self, prettyPrint=False): return json.dumps(jsonDict, cls=FwoEncoder) - def get_all_zone_names(self, mgr_uid): + def get_all_zone_names(self, mgr_uid: str) -> set[str]: """ Returns a list of all zone UIDs in the configuration. """ - all_zone_names = [] + all_zone_names: list[str] = [] for mgr in self.ManagerSet: if mgr.IsSuperManager or mgr.ManagerUid==mgr_uid: for single_config in mgr.Configs: @@ -87,11 +81,11 @@ def get_all_zone_names(self, mgr_uid): return set(all_zone_names) - def get_all_network_object_uids(self, mgr_uid): + def get_all_network_object_uids(self, mgr_uid: str) -> set[str]: """ Returns a list of all network objects in the configuration. """ - all_network_objects = [] + all_network_objects: list[str] = [] for mgr in self.ManagerSet: if mgr.IsSuperManager or mgr.ManagerUid==mgr_uid: for single_config in mgr.Configs: @@ -99,11 +93,11 @@ def get_all_network_object_uids(self, mgr_uid): return set(all_network_objects) - def get_all_service_object_uids(self, mgr_uid): + def get_all_service_object_uids(self, mgr_uid: str) -> set[str]: """ Returns a list of all service objects in the configuration. """ - all_service_objects = [] + all_service_objects: list[str] = [] for mgr in self.ManagerSet: if mgr.IsSuperManager or mgr.ManagerUid==mgr_uid: for single_config in mgr.Configs: @@ -111,11 +105,11 @@ def get_all_service_object_uids(self, mgr_uid): return set(all_service_objects) - def get_all_user_object_uids(self, mgr_uid): + def get_all_user_object_uids(self, mgr_uid: str) -> set[str]: """ Returns a list of all user objects in the configuration. """ - all_user_objects = [] + all_user_objects: list[str] = [] for mgr in self.ManagerSet: if mgr.IsSuperManager or mgr.ManagerUid==mgr_uid: for single_config in mgr.Configs: @@ -123,7 +117,7 @@ def get_all_user_object_uids(self, mgr_uid): return set(all_user_objects) - def addManager(self, manager): + def addManager(self, manager: FwConfigManager): self.ManagerSet.append(manager) def getFirstManager(self): @@ -141,7 +135,7 @@ def getPolicyUidFromRulebaseName(rb_name: str) -> str: return rb_name @classmethod - def FromJson(cls, jsonIn): + def FromJson(cls, jsonIn: dict[str, Any]) -> 'FwConfigManagerListController': return serializeDictToClassRecursively(jsonIn, cls) diff --git a/roles/importer/files/importer/model_controllers/fworch_config_controller.py b/roles/importer/files/importer/model_controllers/fworch_config_controller.py index 8c59c8f42a..0997ccc661 100644 --- a/roles/importer/files/importer/model_controllers/fworch_config_controller.py +++ b/roles/importer/files/importer/model_controllers/fworch_config_controller.py @@ -1,3 +1,4 @@ +from typing import Any from models.fworch_config import FworchConfig """ @@ -6,11 +7,11 @@ """ class FworchConfigController(FworchConfig): - def __init__(self, fwoApiUri, fwoUserMgmtApiUri, importerPwd, apiFetchSize=500): + def __init__(self, fwoApiUri: str | None, fwoUserMgmtApiUri: str | None, importerPwd: str | None , apiFetchSize: int = 500): if fwoApiUri is not None: self.FwoApiUri = fwoApiUri else: - self.FwoApiUFwoUserMgmtApiri = None + self.FwoApiUFwoUserMgmtApiri = None #TODO: Mispell? FwoApiUFwoUserMgmtApiUri if fwoUserMgmtApiUri is not None: self.FwoUserMgmtApiUri = fwoUserMgmtApiUri else: @@ -19,7 +20,7 @@ def __init__(self, fwoApiUri, fwoUserMgmtApiUri, importerPwd, apiFetchSize=500): self.ApiFetchSize = apiFetchSize @classmethod - def fromJson(cls, json_dict): + def fromJson(cls, json_dict: dict[str, Any]) -> "FworchConfigController": fwoApiUri = json_dict['fwo_api_base_url'] fwoUserMgmtApiUri = json_dict['user_management_api_base_url'] if 'importerPassword' in json_dict: @@ -30,7 +31,7 @@ def fromJson(cls, json_dict): return cls(fwoApiUri, fwoUserMgmtApiUri, fwoImporterPwd) def __str__(self): - return f"{self.FwoApiUri}, {self.FwoUserMgmtApi}, {self.ApiFetchSize}" - - def setImporterPwd(self, importerPassword): + return f"{self.FwoApiUri}, {self.FwoUserMgmtApi}, {self.ApiFetchSize}" # type: ignore + #TODO Mispell? FwoUserMgmtApi? + def setImporterPwd(self, importerPassword: str | None): self.ImporterPassword = importerPassword diff --git a/roles/importer/files/importer/model_controllers/gateway_controller.py b/roles/importer/files/importer/model_controllers/gateway_controller.py index b2e3d5b98e..7425f91e2d 100644 --- a/roles/importer/files/importer/model_controllers/gateway_controller.py +++ b/roles/importer/files/importer/model_controllers/gateway_controller.py @@ -6,13 +6,13 @@ def __init__(self, gw: Gateway): self.Gateway = gw @staticmethod - def replaceNoneWithEmpty(s): + def replaceNoneWithEmpty(s: str | None) -> str: if s is None or s == '': return '' else: return str(s) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Gateway): return ( self.Name == other.Name and diff --git a/roles/importer/files/importer/model_controllers/import_state_controller.py b/roles/importer/files/importer/model_controllers/import_state_controller.py index 16212fec11..7eda358df1 100644 --- a/roles/importer/files/importer/model_controllers/import_state_controller.py +++ b/roles/importer/files/importer/model_controllers/import_state_controller.py @@ -1,5 +1,6 @@ import time from datetime import datetime, timezone +from typing import Any from dateutil import parser import urllib3 @@ -26,8 +27,8 @@ class ImportStateController(ImportState): api_call: FwoApiCall management_map: dict[str, int] # maps management uid to management id - def __init__(self, debugLevel, configChangedSinceLastImport, fwoConfig, mgmDetails, jwt, force, - version=8, isFullImport=False, isInitialImport=False, isClearingImport=False, verifyCerts=False, LastSuccessfulImport=None): + def __init__(self, debugLevel: int, configChangedSinceLastImport: bool, fwoConfig: FworchConfigController, mgmDetails: dict[str, Any], jwt: str, force: bool, + version: int, isFullImport: bool = False, isInitialImport: bool = False, isClearingImport: bool = False, verifyCerts: bool = False, LastSuccessfulImport: str | None = None): self.Stats = ImportStatisticsController() self.StartTime = int(time.time()) self.DebugLevel = debugLevel @@ -43,28 +44,28 @@ def __init__(self, debugLevel, configChangedSinceLastImport, fwoConfig, mgmDetai self.IsFullImport = isFullImport self.IsInitialImport = isInitialImport self.IsClearingImport = isClearingImport - self.RulbaseToGatewayMap = {} + self.RulbaseToGatewayMap: dict[int, list[int]] = {} self.LastSuccessfulImport = LastSuccessfulImport self.api_connection = FwoApi(fwoConfig.FwoApiUri, jwt) self.api_call = FwoApiCall(self.api_connection) - self.removed_rules_map = {} + self.removed_rules_map: dict[str, int] = {} def __str__(self): return f"{str(self.MgmDetails)}(import_id={self.ImportId})" - def setImportFileName(self, importFileName): + def setImportFileName(self, importFileName: str): self.ImportFileName = importFileName - def setImportId(self, importId): + def setImportId(self, importId: int): self.ImportId = importId - def increaseErrorCounter(self, errorNo): + def increaseErrorCounter(self, errorNo: int): self.Stats.ErrorCount = self.Stats.ErrorCount + errorNo def increaseErrorCounterByOne(self): self.increaseErrorCounter(1) - def appendErrorString(self, errorStr): + def appendErrorString(self, errorStr: str): self.Stats.ErrorDetails.append(errorStr) def getErrors(self): @@ -72,8 +73,8 @@ def getErrors(self): def getErrorString(self): return str(self.Stats.ErrorDetails) - - def addError(self, error, log=False): + + def addError(self, error: str, log: bool = False): self.increaseErrorCounterByOne() self.appendErrorString(str(error)) if log and not self.Stats.ErrorAlreadyLogged: @@ -83,31 +84,28 @@ def addError(self, error, log=False): @classmethod - def initializeImport(cls, mgmId, fwo_api_uri, jwt, - debugLevel=0, suppressCertWarnings=False, - sslVerification=False, force=False, version=8, - isClearingImport=False, isFullImport=False, isInitialImport=False, + def initializeImport(cls, mgmId: int | None, fwo_api_uri: str, jwt: str, + debugLevel: int = 0, suppressCertWarnings: bool = False, + sslVerification: bool = False, force: bool = False, version: int = 8, + isClearingImport: bool = False, isFullImport: bool = False, isInitialImport: bool = False, ): - def _check_input_parameters(mgmId): - if mgmId is None: - raise ValueError("parameter mgm_id is mandatory") - logger = getFwoLogger() - _check_input_parameters(mgmId) + if mgmId is None: + raise ValueError("parameter mgm_id is mandatory") fwoConfig = FworchConfigController.fromJson(readConfig(fwo_config_filename)) api_conn = FwoApi(ApiUri=fwoConfig.FwoApiUri, Jwt=jwt) api_call = FwoApiCall(api_conn) # set global https connection values - fwo_globals.set_global_values (suppress_cert_warnings_in=suppressCertWarnings, verify_certs_in=sslVerification, debug_level_in=debugLevel) + fwo_globals.set_global_values(suppress_cert_warnings_in=suppressCertWarnings, verify_certs_in=sslVerification, debug_level_in=debugLevel) if fwo_globals.suppress_cert_warnings: urllib3.disable_warnings() # suppress ssl warnings only try: # get mgm_details (fw-type, port, ip, user credentials): mgm_controller = ManagementController( - mgm_id=int(mgmId), uid='', devices={}, + mgm_id=int(mgmId), uid='', devices={}, #type: ignore # TODO: why int cast here? device_info=DeviceInfo(), connection_info=ConnectionInfo(), importer_hostname='', @@ -116,12 +114,12 @@ def _check_input_parameters(mgmId): domain_info=DomainInfo() ) mgmDetails = mgm_controller.get_mgm_details(api_conn, mgmId, debugLevel) - except Exception as e: + except Exception as _: logger.error(f"import_management - error while getting fw management details for mgm={str(mgmId)}: {str(traceback.format_exc())}") raise try: # get last import data - last_import_id, last_import_date = api_call.get_last_complete_import({"mgmId": int(mgmId)}, debug_level=0) + _, last_import_date = api_call.get_last_complete_import({"mgmId": int(mgmId)}, debug_level=0) except Exception: logger.error("import_management - error while getting last import data for mgm=" + str(mgmId) ) raise @@ -144,7 +142,7 @@ def _check_input_parameters(mgmId): result.getPastImportInfos() result.setCoreData() - if type(result) is str: + if type(result) is str: # type: ignore # TODO: This should never happen logger.error("error while getting import state") raise FwoImporterError("error while getting import state") @@ -200,7 +198,7 @@ def setCoreData(self): self.SetRulebaseMap(api_call) self.SetRuleMap(api_call) - def SetActionMap(self, api_call): + def SetActionMap(self, api_call: FwoApiCall): query = "query getActionMap { stm_action { action_name action_id allowed } }" try: result = api_call.call(query=query, query_variables={}) @@ -214,7 +212,7 @@ def SetActionMap(self, api_call): map.update({action['action_name']: action['action_id']}) self.Actions = map - def SetTrackMap(self, api_call): + def SetTrackMap(self, api_call: FwoApiCall): query = "query getTrackMap { stm_track { track_name track_id } }" try: result = api_call.call(query=query, query_variables={}) @@ -228,7 +226,7 @@ def SetTrackMap(self, api_call): track_map.update({track['track_name']: track['track_id']}) self.Tracks = track_map - def SetLinkTypeMap(self, api_call): + def SetLinkTypeMap(self, api_call: FwoApiCall): query = "query getLinkType { stm_link_type { id name } }" try: result = api_call.call(query=query, query_variables={}) @@ -242,7 +240,7 @@ def SetLinkTypeMap(self, api_call): link_map.update({track['name']: track['id']}) self.LinkTypes = link_map - def SetColorRefMap(self, api_call): + def SetColorRefMap(self, api_call: FwoApiCall): get_colors_query = FwoApi.get_graphql_code([graphql_query_path + "stmTables/getColors.graphql"]) try: @@ -263,7 +261,7 @@ def SetColorRefMap(self, api_call): # TODO: map update inconsistencies: import_state is global over all sub managers, so map needs to be updated for each sub manager # currently, this is done in fwconfig_import_rule. But what about other maps? - see #3646 # TODO: global rulebases not yet included - def SetRulebaseMap(self, api_call): + def SetRulebaseMap(self, api_call: FwoApiCall) -> None: logger = getFwoLogger() # TODO: maps need to be updated directly after data changes @@ -275,7 +273,7 @@ def SetRulebaseMap(self, api_call): self.RulebaseMap = {} raise - m = {} + m: dict[str, int] = {} for rulebase in result['data']['rulebase']: rbid = rulebase['id'] m.update({rulebase['uid']: rbid}) @@ -286,7 +284,7 @@ def SetRulebaseMap(self, api_call): # limited to the current mgm_id # creats a dict with key = rule.uid and value = rule.id # should be called sparsely, as there might be a lot of rules for a mgmt - def SetRuleMap(self, api_call): + def SetRuleMap(self, api_call: FwoApi) -> None: query = """query getRuleMap($mgmId: Int) { rule(where:{mgm_id: {_eq: $mgmId}, removed:{_is_null:true }}) { rule_id rule_uid } }""" try: result = api_call.call(query=query, query_variables= {"mgmId": self.MgmDetails.Id}) @@ -296,7 +294,7 @@ def SetRuleMap(self, api_call): self.RuleMap = {} raise - m = {} + m: dict[str, int] = {} for rule in result['data']['rule']: m.update({rule['rule_uid']: rule['rule_id']}) self.RuleMap = m @@ -304,7 +302,7 @@ def SetRuleMap(self, api_call): # getting all gateways (not limitited to the current mgm_id) to support super managements # creates a dict with key = gateway.uid and value = gateway.id # and also key = gateway.name and value = gateway.id - def SetGatewayMap(self, api_call): + def SetGatewayMap(self, api_call: FwoApiCall): query = """ query getGatewayMap { device { @@ -331,7 +329,7 @@ def SetGatewayMap(self, api_call): # getting all managements (not limitited to the current mgm_id) to support super managements # creates a dict with key = management.uid and value = management.id - def SetManagementMap(self, api_call): + def SetManagementMap(self, api_call: FwoApiCall): query = """ query getManagementMap($mgmId: Int!) { management(where: {mgm_id: {_eq: $mgmId}}) { @@ -360,10 +358,10 @@ def SetManagementMap(self, api_call): self.ManagementMap = m - def lookupRule(self, ruleUid): + def lookupRule(self, ruleUid: str) -> int | None: return self.RuleMap.get(ruleUid, None) - def lookupAction(self, actionStr): + def lookupAction(self, actionStr: str) -> int: action_id = self.Actions.get(actionStr.lower(), None) if action_id is None: logger = getFwoLogger() @@ -371,7 +369,7 @@ def lookupAction(self, actionStr): raise FwoImporterError(f"Action {actionStr} not found") return action_id - def lookupTrack(self, trackStr): + def lookupTrack(self, trackStr: str) -> int: track_id = self.Tracks.get(trackStr.lower(), None) if track_id is None: logger = getFwoLogger() @@ -379,7 +377,7 @@ def lookupTrack(self, trackStr): raise FwoImporterError(f"Track {trackStr} not found") return track_id - def lookupRulebaseId(self, rulebaseUid) -> int: + def lookupRulebaseId(self, rulebaseUid: str) -> int: rulebaseId = self.RulebaseMap.get(rulebaseUid, None) if rulebaseId is None: logger = getFwoLogger() @@ -387,10 +385,10 @@ def lookupRulebaseId(self, rulebaseUid) -> int: raise FwoImporterError(f"Rulebase {rulebaseUid} not found in {len(self.RulebaseMap)} known rulebases") return rulebaseId - def lookupLinkType(self, linkUid): + def lookupLinkType(self, linkUid: str) -> int: return self.LinkTypes.get(linkUid, -1) - def lookupGatewayId(self, gwUid: str) -> int|None: + def lookupGatewayId(self, gwUid: str) -> int | None: mgm_id = self.MgmDetails.CurrentMgmId gws_for_mgm = self.GatewayMap.get(mgm_id, {}) gw_id = gws_for_mgm.get(gwUid, None) @@ -406,14 +404,14 @@ def lookup_all_gateway_ids(self) -> list[int]: gw_ids = list(gws_for_mgm.values()) return gw_ids - def lookupManagementId(self, mgmUid): + def lookupManagementId(self, mgmUid: str) -> int | None: if not self.ManagementMap.get(mgmUid, None): logger = getFwoLogger() logger.error(f"fwo_api:import_latest_config - no mgm id found for current manager uid '{mgmUid}'") return self.ManagementMap.get(mgmUid, None) - def lookupColorId(self, color_str): + def lookupColorId(self, color_str: str) -> int: return self.ColorMap.get(color_str, 1) # 1 = forground color black diff --git a/roles/importer/files/importer/model_controllers/import_statistics_controller.py b/roles/importer/files/importer/model_controllers/import_statistics_controller.py index e63544b89d..06778d81c2 100644 --- a/roles/importer/files/importer/model_controllers/import_statistics_controller.py +++ b/roles/importer/files/importer/model_controllers/import_statistics_controller.py @@ -48,7 +48,7 @@ def getRuleChangeNumber(self): self.rule_enforce_change_count + self.rulebase_add_count + self.rulebase_change_count + self.rulebase_delete_count def getChangeDetails(self): - result = {} + result: dict[str, int] = {} self.collect_nw_obj_change_details(result) self.collect_svc_obj_change_details(result) self.collect_usr_obj_change_details(result) @@ -57,9 +57,7 @@ def getChangeDetails(self): return result - def collect_nw_obj_change_details(self, result): - if result is None: - result = {} + def collect_nw_obj_change_details(self, result: dict[str, int]): if self.NetworkObjectAddCount > 0: result['NetworkObjectAddCount'] = self.NetworkObjectAddCount if self.NetworkObjectDeleteCount > 0: @@ -68,9 +66,7 @@ def collect_nw_obj_change_details(self, result): result['NetworkObjectChangeCount'] = self.NetworkObjectChangeCount - def collect_svc_obj_change_details(self, result): - if result is None: - result = {} + def collect_svc_obj_change_details(self, result: dict[str, int]): if self.ServiceObjectAddCount > 0: result['ServiceObjectAddCount'] = self.ServiceObjectAddCount if self.ServiceObjectDeleteCount > 0: @@ -79,9 +75,7 @@ def collect_svc_obj_change_details(self, result): result['ServiceObjectChangeCount'] = self.ServiceObjectChangeCount - def collect_usr_obj_change_details(self, result): - if result is None: - result = {} + def collect_usr_obj_change_details(self, result: dict[str, int]): if self.UserObjectAddCount > 0: result['UserObjectAddCount'] = self.UserObjectAddCount if self.UserObjectDeleteCount > 0: @@ -90,9 +84,7 @@ def collect_usr_obj_change_details(self, result): result['UserObjectChangeCount'] = self.UserObjectChangeCount - def collect_zone_obj_change_details(self, result): - if result is None: - result = {} + def collect_zone_obj_change_details(self, result: dict[str, int]): if self.ZoneObjectAddCount > 0: result['ZoneObjectAddCount'] = self.ZoneObjectAddCount if self.ZoneObjectDeleteCount > 0: @@ -101,9 +93,7 @@ def collect_zone_obj_change_details(self, result): result['ZoneObjectChangeCount'] = self.ZoneObjectChangeCount - def collect_rule_change_details(self, result): - if result is None: - result = {} + def collect_rule_change_details(self, result: dict[str, int]): if self.RuleAddCount > 0: result['RuleAddCount'] = self.RuleAddCount if self.RuleDeleteCount > 0: diff --git a/roles/importer/files/importer/model_controllers/interface_controller.py b/roles/importer/files/importer/model_controllers/interface_controller.py index 47eaff8e56..b0e334e47d 100644 --- a/roles/importer/files/importer/model_controllers/interface_controller.py +++ b/roles/importer/files/importer/model_controllers/interface_controller.py @@ -1,9 +1,10 @@ +from typing import Any from fwo_log import getFwoLogger -from netaddr import IPAddress, IPNetwork +from netaddr import IPAddress class Interface: - def __init__(self, device_id, name, ip, netmask_bits, state_up=True, ip_version=4): + def __init__(self, device_id: int, name: str, ip: IPAddress, netmask_bits: int, state_up: bool = True, ip_version: int = 4): self.routing_device = int(device_id) # check if routing device id exists? self.name = str(name) @@ -26,7 +27,14 @@ def __init__(self, device_id, name, ip, netmask_bits, state_up=True, ip_version= class InterfaceSerializable(Interface): - def __init__(self, ifaceIn): + name : str + routing_device : int + ip : str + netmask_bits : int + state_up : bool + ip_version : int + #TYPING: check if these types are correct + def __init__(self, ifaceIn: dict[Any, Any] | Interface): if type(ifaceIn) is dict: self.name = ifaceIn['name'] self.routing_device = ifaceIn['routing_device'] diff --git a/roles/importer/files/importer/model_controllers/management_controller.py b/roles/importer/files/importer/model_controllers/management_controller.py index abdbc1a1d9..f3bb6dca2b 100644 --- a/roles/importer/files/importer/model_controllers/management_controller.py +++ b/roles/importer/files/importer/model_controllers/management_controller.py @@ -1,5 +1,6 @@ import hashlib from dataclasses import dataclass +from typing import Any from models.management import Management from fwo_exceptions import FwLoginFailed @@ -41,7 +42,7 @@ class DomainInfo: domain_uid: str = '' class ManagementController(Management): - def __init__(self, mgm_id: int, uid: str, devices: dict, device_info: DeviceInfo, + def __init__(self, mgm_id: int, uid: str, devices: list[dict[str, Any]], device_info: DeviceInfo, connection_info: ConnectionInfo, importer_hostname: str, credential_info: CredentialInfo, manager_info: ManagerInfo, domain_info: DomainInfo, import_disabled: bool = False): @@ -83,7 +84,7 @@ def __init__(self, mgm_id: int, uid: str, devices: dict, device_info: DeviceInfo self.DomainUid = domain_info.domain_uid @classmethod - def fromJson(cls, json_dict: dict): + def fromJson(cls, json_dict: dict[str, Any]) -> "ManagementController": device_info = DeviceInfo( name=json_dict['name'], type_name=json_dict['deviceType']['name'], @@ -152,16 +153,16 @@ def buildFwApiString(self): raise FwLoginFailed(f"Unsupported device type: {self.DeviceTypeName}") - def getDomainString(self): - return self.DomainUid if self.DomainUid != None else self.DomainName + def getDomainString(self) -> str: + return self.DomainUid if self.DomainUid != None else self.DomainName # type: ignore #TODO: check if None check is needed if yes, change type @classmethod def buildGatewayList(cls, mgmDetails: "ManagementController") -> list['Gateway']: - devs = [] + devs: list['Gateway'] = [] for dev in mgmDetails.Devices: # check if gateway import is enabled - if 'do_not_import' in dev and dev['do_not_import']: # TODO: get this key from the device + if 'do_not_import' in dev and dev['do_not_import']: continue devs.append(Gateway(Name = dev['name'], Uid = f"{dev['name']}/{mgmDetails.calcManagerUidHash()}")) return devs @@ -170,14 +171,14 @@ def buildGatewayList(cls, mgmDetails: "ManagementController") -> list['Gateway'] def calcManagerUidHash(self): combination = f""" {replaceNoneWithEmpty(self.Hostname)} - {replaceNoneWithEmpty(self.Port)} + {replaceNoneWithEmpty(str(self.Port))} {replaceNoneWithEmpty(self.DomainUid)} {replaceNoneWithEmpty(self.DomainName)} """ return hashlib.sha256(combination.encode()).hexdigest() - def get_mgm_details(self, api_conn, mgm_id, debug_level=0): + def get_mgm_details(self, api_conn: FwoApi, mgm_id: int, debug_level: int = 0) -> dict[str, Any]: service_provider = ServiceProvider() _global_state = service_provider.get_service(Services.GLOBAL_STATE) @@ -190,7 +191,7 @@ def get_mgm_details(self, api_conn, mgm_id, debug_level=0): graphql_query_path + "device/fragments/importCredentials.graphql"]) api_call_result = api_conn.call(getMgmDetailsQuery, query_variables={'mgmId': mgm_id }) - if api_call_result is None or 'data' not in api_call_result or 'management' not in api_call_result['data'] or len(api_call_result['data']['management'])<1: + if api_call_result is None or 'data' not in api_call_result or 'management' not in api_call_result['data'] or len(api_call_result['data']['management'])<1: #type: ignore #TODO: check if api_call_result can be None raise FwoApiFailure('did not succeed in getting management details from FWO API') if not '://' in api_call_result['data']['management'][0]['hostname']: diff --git a/roles/importer/files/importer/model_controllers/rollback.py b/roles/importer/files/importer/model_controllers/rollback.py index f355342cd7..a88a222374 100644 --- a/roles/importer/files/importer/model_controllers/rollback.py +++ b/roles/importer/files/importer/model_controllers/rollback.py @@ -21,7 +21,7 @@ def __init__(self): # also deletes latest_config for this management # TODO: also take super management id into account as second option - def rollbackCurrentImport(self) -> None: + def rollbackCurrentImport(self) -> None | int: logger = getFwoLogger() rollbackMutation = FwoApi.get_graphql_code([f"{fwo_const.graphql_query_path}import/rollbackImport.graphql"]) try: diff --git a/roles/importer/files/importer/model_controllers/route_controller.py b/roles/importer/files/importer/model_controllers/route_controller.py index 94bcd46e19..dfb38024cb 100644 --- a/roles/importer/files/importer/model_controllers/route_controller.py +++ b/roles/importer/files/importer/model_controllers/route_controller.py @@ -1,10 +1,13 @@ +from typing import Any from fwo_log import getFwoLogger from netaddr import IPAddress, IPNetwork +from model_controllers.interface_controller import InterfaceSerializable + class Route: - def __init__(self, device_id, target_gateway, destination, - static=True, source=None, interface=None, metric=None, distance=None, ip_version=4): + def __init__(self, device_id: int, target_gateway: str, destination: str, + static: bool = True, source: str | None = None, interface: str | None = None, metric: int | None = None, distance: int | None = None, ip_version: int = 4): self.routing_device = int(device_id) if interface is not None: self.interface = str(interface) @@ -41,7 +44,7 @@ def isDefaultRouteV6(self): return self.ip_version==6 and self.destination == IPNetwork('::/0') - def routeMatches(self, destination, dev_id): + def routeMatches(self, destination: str, dev_id: int) -> bool: ip_n = IPNetwork(self.destination).cidr dest_n = IPNetwork(destination).cidr return dev_id == self.routing_device and (ip_n in dest_n or dest_n in ip_n) @@ -52,7 +55,7 @@ def getRouteDestination(self): class RouteSerializable(Route): - def __init__(self, routeIn): + def __init__(self, routeIn: dict[str, Any] | Route): if type(routeIn) is dict: self.routing_device = routeIn['routing_device'] self.interface = routeIn['interface'] @@ -81,7 +84,7 @@ def __init__(self, routeIn): self.ip_version = routeIn.ip_version -def getRouteDestination(obj): +def getRouteDestination(obj: Route): return obj.destination @@ -101,7 +104,7 @@ def getRouteDestination(obj): # return default_route_v4.append(default_route_v6) -def get_matching_route_obj(destination_ip, routing_table, dev_id): +def get_matching_route_obj(destination_ip: str, routing_table: list[Route], dev_id: int) -> Route | None: logger = getFwoLogger() @@ -118,7 +121,7 @@ def get_matching_route_obj(destination_ip, routing_table, dev_id): return None -def get_ip_of_interface_obj(interface_name, dev_id, interface_list=[]): +def get_ip_of_interface_obj(interface_name: str | None, dev_id: int, interface_list: list[InterfaceSerializable]) -> str | None: interface_details = next((sub for sub in interface_list if sub.name == interface_name and sub.routing_device==dev_id), None) if interface_details is not None: diff --git a/roles/importer/files/importer/model_controllers/rule_enforced_on_gateway_controller.py b/roles/importer/files/importer/model_controllers/rule_enforced_on_gateway_controller.py index fe27604ef3..49682fef9a 100644 --- a/roles/importer/files/importer/model_controllers/rule_enforced_on_gateway_controller.py +++ b/roles/importer/files/importer/model_controllers/rule_enforced_on_gateway_controller.py @@ -3,17 +3,15 @@ import fwo_const from model_controllers.import_state_controller import ImportStateController -from models.rule_enforced_on_gateway import RuleEnforcedOnGateway from fwo_log import getFwoLogger from model_controllers.rulebase_link_controller import RulebaseLinkController -from models.rule import Rule class RuleEnforcedOnGatewayController: def __init__(self, import_state: ImportStateController): self.import_details: ImportStateController = import_state - def add_new_rule_enforced_on_gateway_refs(self, new_rules, import_state): + def add_new_rule_enforced_on_gateway_refs(self, new_rules: list[dict[str, Any]], import_state: ImportStateController): """ Main function to add new rule-to-gateway references. """ @@ -32,7 +30,7 @@ def add_new_rule_enforced_on_gateway_refs(self, new_rules, import_state): # Step 4: Insert the references into the database self.insert_rule_to_gateway_references(rule_to_gw_refs) - def initialize_rulebase_link_controller(self, import_state): + def initialize_rulebase_link_controller(self, import_state: ImportStateController) -> RulebaseLinkController: """ Initialize the RulebaseLinkController and set the map of enforcing gateways. """ @@ -40,14 +38,14 @@ def initialize_rulebase_link_controller(self, import_state): rb_link_controller.set_map_of_all_enforcing_gateway_ids_for_rulebase_id(import_state) return rb_link_controller - def prepare_rule_to_gateway_references(self, new_rules, rb_link_controller): + def prepare_rule_to_gateway_references(self, new_rules: list[dict[str, Any]], rb_link_controller: RulebaseLinkController) -> list[dict[str, Any]]: """ Prepare the list of rule-to-gateway references based on the rules and their 'install on' settings. """ - rule_to_gw_refs = [] + rule_to_gw_refs: list[dict[str, Any]] = [] for rule in new_rules: - if 'rule_installon' in rule: - if rule['rule_installon'] is None: + if 'rule_installon' in rule: # TODO rule should not be a dict + if rule['rule_installon'] is None: # TODO rule should not be a dict self.handle_rule_without_installon(rule, rb_link_controller, rule_to_gw_refs) else: self.handle_rule_with_installon(rule, rule_to_gw_refs) @@ -55,7 +53,7 @@ def prepare_rule_to_gateway_references(self, new_rules, rb_link_controller): def handle_rule_without_installon(self, - rule: dict, + rule: dict[str, Any], # TODO rule should not be a dict rb_link_controller: RulebaseLinkController, rule_to_gw_refs: list[dict[str, Any]] ) -> None: @@ -67,7 +65,7 @@ def handle_rule_without_installon(self, def handle_rule_with_installon(self, - rule: dict, + rule: dict[str, Any], # TODO rule should not be a dict rule_to_gw_refs: list[dict[str, Any]] ) -> None: """ @@ -82,7 +80,7 @@ def handle_rule_with_installon(self, logger.warning(f"Found a broken reference to a non-existing gateway (uid={gw_uid}). Ignoring.") - def create_rule_to_gateway_reference(self, rule, gw_id) -> dict[str, Any]: + def create_rule_to_gateway_reference(self, rule: dict[str, Any], gw_id: int) -> dict[str, Any]: """ Create a dictionary representing a rule-to-gateway reference. """ @@ -94,7 +92,7 @@ def create_rule_to_gateway_reference(self, rule, gw_id) -> dict[str, Any]: } - def insert_rule_to_gateway_references(self, rule_to_gw_refs): + def insert_rule_to_gateway_references(self, rule_to_gw_refs: list[dict[str, Any]]) -> None: """ Insert the rule-to-gateway references into the database. """ @@ -115,7 +113,7 @@ def insert_rule_to_gateway_references(self, rule_to_gw_refs): raise - def insert_rules_enforced_on_gateway(self, enforcements: list[dict]) -> dict[str, Any]: + def insert_rules_enforced_on_gateway(self, enforcements: list[dict[str, Any]]) -> dict[str, Any]: """ Insert rules enforced on gateways into the database. """ diff --git a/roles/importer/files/importer/model_controllers/rulebase_link_controller.py b/roles/importer/files/importer/model_controllers/rulebase_link_controller.py index 103c4cd58a..2f2c25564b 100644 --- a/roles/importer/files/importer/model_controllers/rulebase_link_controller.py +++ b/roles/importer/files/importer/model_controllers/rulebase_link_controller.py @@ -1,5 +1,6 @@ # from pydantic import BaseModel +from typing import Any from models.rulebase_link import RulebaseLink, parse_rulebase_links from model_controllers.import_state_controller import ImportStateController from fwo_log import getFwoLogger @@ -8,10 +9,10 @@ class RulebaseLinkController(): - rulbase_to_gateway_map: dict = {} + rulbase_to_gateway_map: dict[int, list[int]] = {} rb_links: list[RulebaseLink] - def insert_rulebase_links(self, import_state: ImportStateController, rb_links: list[RulebaseLink]): + def insert_rulebase_links(self, import_state: ImportStateController, rb_links: list[dict[str, Any]]) -> None: logger = getFwoLogger() query_variables = { "rulebaseLinks": rb_links } if len(rb_links) == 0: @@ -26,9 +27,9 @@ def insert_rulebase_links(self, import_state: ImportStateController, rb_links: l import_state.Stats.rulebase_link_add_count += changes - def remove_rulebase_links(self, import_state: ImportStateController, removed_rb_links_ids: list[int]): + def remove_rulebase_links(self, import_state: ImportStateController, removed_rb_links_ids: list[int | None]) -> None: logger = getFwoLogger() - query_variables = { "removedRulebaseLinks": removed_rb_links_ids, "importId": import_state.ImportId } + query_variables: dict[str, Any] = { "removedRulebaseLinks": removed_rb_links_ids, "importId": import_state.ImportId } if len(removed_rb_links_ids) == 0: return mutation = FwoApi.get_graphql_code([f"{fwo_const.graphql_query_path}rule/removeRulebaseLinks.graphql"]) @@ -51,7 +52,7 @@ def get_rulebase_links(self, import_state: ImportStateController): # we always need to provide gwIds since rulebase_links may be duplicate across different gateways query_variables = { "gwIds": gw_ids} - query = FwoApi.get_graphql_code([f"{fwo_const.graphql_query_path}rule/getRulebaseLinks.graphql"]) + query = FwoApi.get_graphql_code(file_list=[f"{fwo_const.graphql_query_path}rule/getRulebaseLinks.graphql"]) links = import_state.api_call.call(query, query_variables=query_variables) if 'errors' in links: import_state.Stats.addError(f"fwo_api:getRulebaseLinks - error while getting rulebaseLinks: {str(links['errors'])}") @@ -74,6 +75,6 @@ def set_map_of_all_enforcing_gateway_ids_for_rulebase_id(self, importState: Impo self.rulbase_to_gateway_map[rulebase_id].append(gw_id) - def get_gw_ids_for_rulebase_id(self, rulebase_id): + def get_gw_ids_for_rulebase_id(self, rulebase_id: int) -> list[int]: return self.rulbase_to_gateway_map.get(rulebase_id, []) diff --git a/roles/importer/files/importer/model_controllers/rulebase_link_map.py b/roles/importer/files/importer/model_controllers/rulebase_link_map.py index e8f6c7b56d..9b083174cd 100644 --- a/roles/importer/files/importer/model_controllers/rulebase_link_map.py +++ b/roles/importer/files/importer/model_controllers/rulebase_link_map.py @@ -1,17 +1,14 @@ +from typing import Any from fwo_log import getFwoLogger -from models.rulebase_link import RulebaseLink from model_controllers.import_state_controller import ImportStateController -from models.import_state import ImportState -from model_controllers.import_statistics_controller import ImportStatisticsController -from fwo_api_call import FwoApiCall class RulebaseLinkMap(): - def getRulebaseLinks(self, importState: ImportStateController, gwIds: list[int] = []): + def getRulebaseLinks(self, importState: ImportStateController, gwIds: list[int] = []) -> list[dict[str, Any]]: logger = getFwoLogger() query_variables = { "gwIds": gwIds} - rbLinks = [] + rbLinks: list[dict[str, Any]] = [] query = """ query getRulebaseLinks($gwIds: [Int!]) { @@ -35,6 +32,6 @@ def getRulebaseLinks(self, importState: ImportStateController, gwIds: list[int] # TODO: implement SetMapOfAllEnforcingGatewayIdsForRulebaseId - def GetGwIdsForRulebaseId(self, rulebaseId, importState: ImportStateController): + def GetGwIdsForRulebaseId(self, rulebaseId: int, importState: ImportStateController) -> list[int]: return importState.RulbaseToGatewayMap.get(rulebaseId, []) \ No newline at end of file diff --git a/roles/importer/files/importer/models/caseinsensitiveenum.py b/roles/importer/files/importer/models/caseinsensitiveenum.py index e306d92a1d..de73231a81 100644 --- a/roles/importer/files/importer/models/caseinsensitiveenum.py +++ b/roles/importer/files/importer/models/caseinsensitiveenum.py @@ -2,7 +2,7 @@ class CaseInsensitiveEnum(str, Enum): @classmethod - def _missing_(cls, value): + def _missing_(cls, value: object) -> object | None: if isinstance(value, str): s = value.strip() for member in cls: diff --git a/roles/importer/files/importer/models/fwconfig.py b/roles/importer/files/importer/models/fwconfig.py index eacb23fa59..be91e9647a 100644 --- a/roles/importer/files/importer/models/fwconfig.py +++ b/roles/importer/files/importer/models/fwconfig.py @@ -1,3 +1,4 @@ +from typing import Any from pydantic import BaseModel from fwo_base import ConfFormat @@ -8,4 +9,4 @@ """ class FwConfig(BaseModel): ConfigFormat: ConfFormat - FwConf: dict + FwConf: dict[str, Any] = {} diff --git a/roles/importer/files/importer/models/fwconfig_base.py b/roles/importer/files/importer/models/fwconfig_base.py index f078171e78..3e2095faff 100644 --- a/roles/importer/files/importer/models/fwconfig_base.py +++ b/roles/importer/files/importer/models/fwconfig_base.py @@ -5,7 +5,7 @@ class FwoEncoder(json.JSONEncoder): - def default(self, obj): + def default(self, obj: object) -> object: if isinstance(obj, ConfigAction) or isinstance(obj, ConfFormat): return obj.name diff --git a/roles/importer/files/importer/models/fwconfig_normalized.py b/roles/importer/files/importer/models/fwconfig_normalized.py index c6766c7281..a4dee51a2f 100644 --- a/roles/importer/files/importer/models/fwconfig_normalized.py +++ b/roles/importer/files/importer/models/fwconfig_normalized.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any from pydantic import BaseModel from fwo_base import ConfigAction, ConfFormat @@ -49,8 +49,8 @@ class FwConfigNormalized(FwConfig): action: ConfigAction = ConfigAction.INSERT network_objects: dict[str, NetworkObject] = {} service_objects: dict[str, ServiceObject] = {} - users: dict = {} - zone_objects: dict = {} + users: dict[str, Any] = {} + zone_objects: dict[str, Any] = {} rulebases: list[Rulebase] = [] gateways: list[Gateway] = [] ConfigFormat: ConfFormat = ConfFormat.NORMALIZED_LEGACY @@ -61,7 +61,19 @@ class FwConfigNormalized(FwConfig): } - def get_rulebase(self, rulebaseUid: str) -> Optional[Rulebase]: + def get_rulebase(self, rulebaseUid: str) -> Rulebase: + """ + get the policy with a specific uid + :param policyUid: The UID of the relevant policy. + :return: Returns the policy with a specific uid, otherwise returns None. + """ + rulebase = self.get_rulebase_or_none(rulebaseUid) + if rulebase is not None: + return rulebase + + raise KeyError(f"Rulebase with UID {rulebaseUid} not found.") + + def get_rulebase_or_none(self, rulebaseUid: str) -> Rulebase | None: """ get the policy with a specific uid :param policyUid: The UID of the relevant policy. diff --git a/roles/importer/files/importer/models/fwconfigmanagerlist.py b/roles/importer/files/importer/models/fwconfigmanagerlist.py index f43c72f986..131fbbe178 100644 --- a/roles/importer/files/importer/models/fwconfigmanagerlist.py +++ b/roles/importer/files/importer/models/fwconfigmanagerlist.py @@ -12,7 +12,7 @@ class FwConfigManagerList(BaseModel): ConfigFormat: ConfFormat = ConfFormat.NORMALIZED ManagerSet: list[FwConfigManager] = [] - native_config: dict[str,Any] = {} # native config as dict, if available + native_config: dict[str,Any] | None = {} # native config as dict, if available # TODO: change inital value to None? model_config = { "arbitrary_types_allowed": True diff --git a/roles/importer/files/importer/models/fworch_config.py b/roles/importer/files/importer/models/fworch_config.py index 84c78534ce..a2887c561d 100644 --- a/roles/importer/files/importer/models/fworch_config.py +++ b/roles/importer/files/importer/models/fworch_config.py @@ -4,6 +4,6 @@ """ class FworchConfig(): FwoApiUri: str - FwoUserMgmtApiUri: str + FwoUserMgmtApiUri: str | None ApiFetchSize: int - ImporterPassword: str + ImporterPassword: str | None diff --git a/roles/importer/files/importer/models/gateway.py b/roles/importer/files/importer/models/gateway.py index 4a112f995a..992fd3ccee 100644 --- a/roles/importer/files/importer/models/gateway.py +++ b/roles/importer/files/importer/models/gateway.py @@ -1,3 +1,4 @@ +from typing import Any from pydantic import BaseModel from models.rulebase_link import RulebaseLinkUidBased @@ -15,8 +16,8 @@ class Gateway(BaseModel): Uid: str|None = None Name: str|None = None - Routing: list[dict] = [] - Interfaces: list[dict] = [] + Routing: list[dict[str, Any]] = [] + Interfaces: list[dict[str, Any]] = [] RulebaseLinks: list[RulebaseLinkUidBased] = [] GlobalPolicyUid: str|None = None EnforcedPolicyUids: list[str]|None = [] diff --git a/roles/importer/files/importer/models/import_state.py b/roles/importer/files/importer/models/import_state.py index e7cfcdcd30..c990832993 100644 --- a/roles/importer/files/importer/models/import_state.py +++ b/roles/importer/files/importer/models/import_state.py @@ -13,7 +13,7 @@ class ImportState(): MgmDetails: ManagementController ImportId: int ImportFileName: str - ForceImport: str + ForceImport: bool ImportVersion: int DataRetentionDays: int DaysSinceLastFullImport: int diff --git a/roles/importer/files/importer/models/management.py b/roles/importer/files/importer/models/management.py index 60977cdcd8..3f359fa5e0 100644 --- a/roles/importer/files/importer/models/management.py +++ b/roles/importer/files/importer/models/management.py @@ -1,3 +1,4 @@ +from typing import Any from pydantic import Field class Management(): @@ -7,7 +8,7 @@ class Management(): IsSuperManager: bool = Field(description="Indicates if the management is a super manager") Hostname: str = Field(description="Hostname of the management server") ImportDisabled: bool = Field(description="Indicates if import is disabled for the management") - Devices: dict = Field(description="Dictionary of devices managed by this entity") + Devices: list[dict[str, Any]] = Field(description="Dictionary of devices managed by this entity") ImporterHostname: str = Field(description="Hostname of the machine running the importer") DeviceTypeName: str = Field(description="Name of the device type") DeviceTypeVersion: str = Field(description="Version of the device type") diff --git a/roles/importer/files/importer/models/networkobject.py b/roles/importer/files/importer/models/networkobject.py index cfac79a732..8e4bfee71e 100644 --- a/roles/importer/files/importer/models/networkobject.py +++ b/roles/importer/files/importer/models/networkobject.py @@ -1,3 +1,4 @@ +from typing import Any from pydantic import BaseModel, field_validator, field_serializer from netaddr import IPNetwork, AddrFormatError @@ -14,7 +15,7 @@ class NetworkObject(BaseModel): @field_validator('obj_ip', 'obj_ip_end', mode='before') - def convert_strings_to_ip_objects(cls, value, info): + def convert_strings_to_ip_objects(cls, value: object, info: Any) -> IPNetwork | None: """ Convert string values to IPNetwork objects, treating 'None' or empty as None. """ @@ -33,7 +34,7 @@ def convert_strings_to_ip_objects(cls, value, info): raise ValueError(f"Invalid {info.field_name} network format: {value}") from e @field_serializer('obj_ip', 'obj_ip_end') - def serialize_ipnetwork(self, value: IPNetwork | None, _info): + def serialize_ipnetwork(self, value: IPNetwork | None, _info: Any) -> str | None: """ Serialize IPNetwork objects to strings, keeping None as None. """ @@ -73,8 +74,8 @@ def __init__(self, nwObject: NetworkObject, mgmId: int, importId: int, colorId: self.obj_last_seen = importId self.obj_typ_id = typId - def toDict (self): - result = { + def toDict (self) -> dict[str, Any]: + result: dict[str, Any] = { 'obj_uid': self.obj_uid, 'obj_name': self.obj_name, 'obj_color_id': self.obj_color_id, diff --git a/roles/importer/files/importer/models/rule.py b/roles/importer/files/importer/models/rule.py index 214d21b4d8..479fb8b489 100644 --- a/roles/importer/files/importer/models/rule.py +++ b/roles/importer/files/importer/models/rule.py @@ -63,7 +63,7 @@ class RuleNormalized(BaseModel): rule_dst_zone: str|None = None rule_head_text: str|None = None - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, RuleNormalized): return NotImplemented # Compare all fields except 'last_hit' and 'rule_num' diff --git a/roles/importer/files/importer/models/rule_enforced_on_gateway.py b/roles/importer/files/importer/models/rule_enforced_on_gateway.py index 0fbba9627a..a75c4f992f 100644 --- a/roles/importer/files/importer/models/rule_enforced_on_gateway.py +++ b/roles/importer/files/importer/models/rule_enforced_on_gateway.py @@ -1,5 +1,5 @@ +from typing import Any from pydantic import BaseModel -from model_controllers.import_state_controller import ImportStateController # the model for a connection between a rule and a gateway @@ -23,7 +23,7 @@ def __init__(self, rule_id: int, dev_id: int, created: int|None = None, removed: self.removed=removed - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return { "rule_id": self.rule_id, "dev_id": self.dev_id, diff --git a/roles/importer/files/importer/models/rule_metadatum.py b/roles/importer/files/importer/models/rule_metadatum.py index 6a0536430f..292d136116 100644 --- a/roles/importer/files/importer/models/rule_metadatum.py +++ b/roles/importer/files/importer/models/rule_metadatum.py @@ -1,5 +1,4 @@ from pydantic import BaseModel -from models.caseinsensitiveenum import CaseInsensitiveEnum # Create table "rule_metadata" # ( diff --git a/roles/importer/files/importer/models/rulebase.py b/roles/importer/files/importer/models/rulebase.py index 2ee5f5507d..4c02e8c119 100644 --- a/roles/importer/files/importer/models/rulebase.py +++ b/roles/importer/files/importer/models/rulebase.py @@ -1,4 +1,4 @@ -from models.rule import RuleNormalized, Rule +from models.rule import Rule, RuleNormalized from pydantic import BaseModel # Rulebase is the model for a rulebase (containing no DB IDs) diff --git a/roles/importer/files/importer/models/rulebase_link.py b/roles/importer/files/importer/models/rulebase_link.py index 3431252f78..0c298c0db1 100644 --- a/roles/importer/files/importer/models/rulebase_link.py +++ b/roles/importer/files/importer/models/rulebase_link.py @@ -1,3 +1,4 @@ +from typing import Any from pydantic import BaseModel, TypeAdapter @@ -12,7 +13,7 @@ class RulebaseLinkUidBased(BaseModel): is_section: bool - def toDict(self): + def toDict(self) -> dict[str, object|str|bool|None]: return { "from_rule_uid": self.from_rule_uid, "from_rulebase_uid": self.from_rulebase_uid, @@ -42,7 +43,7 @@ class Config: populate_by_name = True - def toDict(self): + def toDict(self) -> dict[str, Any]: return { "gw_id": self.gw_id, "from_rule_id": self.from_rule_id, @@ -57,7 +58,7 @@ def toDict(self): } -def parse_rulebase_links(data: list[dict]) -> list[RulebaseLink]: +def parse_rulebase_links(data: list[dict[str, Any]]) -> list[RulebaseLink]: adapter = TypeAdapter(list[RulebaseLink]) return adapter.validate_python(data) diff --git a/roles/importer/files/importer/models/serviceobject.py b/roles/importer/files/importer/models/serviceobject.py index 59ed6ef346..226a4b562e 100644 --- a/roles/importer/files/importer/models/serviceobject.py +++ b/roles/importer/files/importer/models/serviceobject.py @@ -1,5 +1,5 @@ +from typing import Any from pydantic import BaseModel -import json class ServiceObject(BaseModel): @@ -54,7 +54,7 @@ def __init__(self, svcObject: ServiceObject, mgmId: int, importId: int, colorId: self.svc_last_seen = importId - def toDict (self): + def toDict(self) -> dict[str, Any]: return { 'svc_uid': self.svc_uid, 'svc_name': self.svc_name, diff --git a/roles/importer/files/importer/query_analyzer.py b/roles/importer/files/importer/query_analyzer.py index a554eabb7c..ca38b1b415 100644 --- a/roles/importer/files/importer/query_analyzer.py +++ b/roles/importer/files/importer/query_analyzer.py @@ -2,12 +2,12 @@ # GraphQL-core v3+ from graphql import parse, print_ast, visit from graphql.language import Visitor - from graphql.language.ast import DocumentNode as Document, VariableDefinitionNode as VariableDefinition, OperationDefinitionNode as OperationDefinition + from graphql.language.ast import DocumentNode as Document, VariableDefinitionNode as VariableDefinition, OperationDefinitionNode as OperationDefinition # type: ignore except ImportError: # GraphQL-core v2 from graphql import parse, print_ast, visit from graphql.language.visitor import Visitor - from graphql.language.ast import Document, VariableDefinition, OperationDefinition + from graphql.language.ast import Document, VariableDefinition, OperationDefinition # type: ignore from typing import Any @@ -31,9 +31,9 @@ def variable_definitions(self) -> dict[str, dict[str, Any]]: return self._variable_definitions @property - def ast(self) -> Document|None: + def ast(self) -> Document|None: # type: ignore """Returns the AST.""" - return self._ast + return self._ast # type: ignore @property def query_string(self) -> str: @@ -44,11 +44,6 @@ def query_string(self) -> str: def query_variables(self) -> dict[str, Any]: """Returns the provided query variables.""" return self._query_variables - - @property - def query_variables(self) -> dict[str, Any]: - """Returns a dictionary that provides all information about the query and the provided query variables.""" - return self._query_variables def __init__(self): @@ -73,7 +68,7 @@ def analyze_payload(self, query_string: str, query_variables: dict[str, Any]|Non # Apply visitor pattern (calls enter_* methods) - visit(self._ast, self) + visit(self._ast, self) # type: ignore # Analyze necessity of chunking and parameters that are necessary for the chunking process. @@ -88,7 +83,7 @@ def analyze_payload(self, query_string: str, query_variables: dict[str, Any]|Non return self._query_info - def get_adjusted_chunk_size(self, lists_in_query_variable: dict): + def get_adjusted_chunk_size(self, lists_in_query_variable: dict[str, Any]) -> int: """ Gets an adjusted chunk size. """ @@ -103,38 +98,38 @@ def get_adjusted_chunk_size(self, lists_in_query_variable: dict): ) or 1 - def enter_OperationDefinition(self, node: OperationDefinition, *_): + def enter_OperationDefinition(self, node: OperationDefinition, *_): # type: ignore """ Called by visit function for each variable definition in the AST. """ - - self.enter_operation_definition(node) + + self.enter_operation_definition(node) # type: ignore - def enter_VariableDefinition(self, node: VariableDefinition, *_): + def enter_VariableDefinition(self, node: VariableDefinition, *_): # type: ignore """ Called by visit function for each variable definition in the AST. """ - - self.enter_variable_definition(node) + + self.enter_variable_definition(node) # type: ignore - def enter_operation_definition(self, node: OperationDefinition, *_): + def enter_operation_definition(self, node: OperationDefinition, *_): # type: ignore """ Called by visit function for each variable definition in the AST. """ - self._query_info["query_type"] = node.operation - self._query_info["query_name"] = node.name.value if node.name else "" + self._query_info["query_type"] = node.operation # type: ignore + self._query_info["query_name"] = node.name.value if node.name else "" # type: ignore - def enter_variable_definition(self, node: VariableDefinition, *_): + def enter_variable_definition(self, node: VariableDefinition, *_): # type: ignore """ Called by visit function for each variable definition in the AST. """ - var_name = node.variable.name.value - type_str = print_ast(node.type) + var_name = node.variable.name.value # type: ignore + type_str = print_ast(node.type) # type: ignore # Store information about the variable definitions. @@ -155,13 +150,14 @@ def enter_variable_definition(self, node: VariableDefinition, *_): self._variable_definitions[var_name]["provided_value"] = self._query_variables[var_name] - def _get_chunking_info(self, query_variables: dict): + def _get_chunking_info(self, query_variables: dict[str, Any] | None) -> tuple[bool, int, int, list[str]]: # Get all query variables of type list. + query_vars = query_variables or {} - lists_in_query_variable = { + lists_in_query_variable: dict[str, Any] = { chunkable_variable_name: list_object - for chunkable_variable_name, list_object in query_variables.items() + for chunkable_variable_name, list_object in query_vars.items() if isinstance(list_object, list) } diff --git a/roles/importer/files/importer/services/group_flats_mapper.py b/roles/importer/files/importer/services/group_flats_mapper.py index 10623f299a..1af4004997 100644 --- a/roles/importer/files/importer/services/group_flats_mapper.py +++ b/roles/importer/files/importer/services/group_flats_mapper.py @@ -3,10 +3,9 @@ if TYPE_CHECKING: from models.fwconfig_normalized import FwConfigNormalized - + from model_controllers.import_state_controller import ImportStateController import fwo_const from fwo_log import getFwoLogger -from model_controllers.import_state_controller import ImportStateController from services.service_provider import ServiceProvider from services.enums import Services @@ -19,7 +18,7 @@ class GroupFlatsMapper: This class is responsible for mapping group objects to their fully resolved members. """ - import_state: ImportStateController + import_state: 'ImportStateController' normalized_config: FwConfigNormalized|None = None global_normalized_config: FwConfigNormalized|None = None diff --git a/roles/importer/files/importer/services/uid2id_mapper.py b/roles/importer/files/importer/services/uid2id_mapper.py index b2d621271d..d513126189 100644 --- a/roles/importer/files/importer/services/uid2id_mapper.py +++ b/roles/importer/files/importer/services/uid2id_mapper.py @@ -1,6 +1,8 @@ from logging import Logger +from typing import TYPE_CHECKING, Any from fwo_log import getFwoLogger -from model_controllers.import_state_controller import ImportStateController +if TYPE_CHECKING: + from model_controllers.import_state_controller import ImportStateController from fwo_exceptions import FwoImporterError from services.service_provider import ServiceProvider from services.enums import Services @@ -49,7 +51,7 @@ class Uid2IdMapper: This class is used to maintain a mapping between UID and relevant ID in the database. """ - import_state: ImportStateController + import_state: 'ImportStateController' logger: Logger nwobj_uid2id: Uid2IdMap @@ -59,11 +61,8 @@ class Uid2IdMapper: rule_uid2id: Uid2IdMap @property - def api_connection(self): - if self.import_state is None: - return None - else: - return self.import_state.api_connection + def api_connection(self) -> FwoApi: + return self.import_state.api_connection def __init__(self): """ @@ -179,7 +178,7 @@ def get_rule_id(self, uid: str, before_update: bool = False) -> int: raise KeyError(f"Rule UID '{uid}' not found in mapping.") return rule_id - def add_network_object_mappings(self, mappings: list[dict], is_global=False): + def add_network_object_mappings(self, mappings: list[dict[str, Any]], is_global: bool = False): """ Add network object mappings to the internal mapping dictionary. @@ -195,7 +194,7 @@ def add_network_object_mappings(self, mappings: list[dict], is_global=False): msg = f"Added {len(mappings)} {'global ' if is_global else ''}network object mappings." self.log_debug(msg) - def add_service_object_mappings(self, mappings: list[dict], is_global=False): + def add_service_object_mappings(self, mappings: list[dict[str, Any]], is_global: bool = False): """ Add service object mappings to the internal mapping dictionary. @@ -210,7 +209,7 @@ def add_service_object_mappings(self, mappings: list[dict], is_global=False): self.log_debug(f"Added {len(mappings)} {'global ' if is_global else ''}service object mappings.") - def add_user_mappings(self, mappings: list[dict], is_global=False): + def add_user_mappings(self, mappings: list[dict[str, Any]], is_global: bool = False): """ Add user object mappings to the internal mapping dictionary. @@ -225,7 +224,7 @@ def add_user_mappings(self, mappings: list[dict], is_global=False): self.log_debug(f"Added {len(mappings)} {'global ' if is_global else ''}user mappings.") - def add_zone_mappings(self, mappings: list[dict], is_global=False): + def add_zone_mappings(self, mappings: list[dict[str, Any]], is_global: bool = False): """ Add zone object mappings to the internal mapping dictionary. @@ -240,7 +239,7 @@ def add_zone_mappings(self, mappings: list[dict], is_global=False): self.log_debug(f"Added {len(mappings)} {'global ' if is_global else ''}zone mappings.") - def add_rule_mappings(self, mappings: list[dict]): + def add_rule_mappings(self, mappings: list[dict[str, Any]]): """ Add rule mappings to the internal mapping dictionary. @@ -255,7 +254,7 @@ def add_rule_mappings(self, mappings: list[dict]): self.log_debug(f"Added {len(mappings)} rule mappings.") - def update_network_object_mapping(self, uids: list[str]|None = None, is_global=False): + def update_network_object_mapping(self, uids: list[str]|None = None, is_global: bool = False): """ Update the mapping for network objects based on the provided UIDs. @@ -283,8 +282,8 @@ def update_network_object_mapping(self, uids: list[str]|None = None, is_global=F self.log_debug(f"Network object mapping updated for {len(response['data']['object'])} objects") except Exception as e: raise FwoImporterError(f"Error updating network object mapping: {e}") - - def update_service_object_mapping(self, uids: list[str]|None = None, is_global=False): + + def update_service_object_mapping(self, uids: list[str]|None = None, is_global: bool = False): """ Update the mapping for service objects based on the provided UIDs. @@ -311,8 +310,8 @@ def update_service_object_mapping(self, uids: list[str]|None = None, is_global=F self.log_debug(f"Service object mapping updated for {len(response['data']['service'])} objects") except Exception as e: raise FwoImporterError(f"Error updating service object mapping: {e}") - - def update_user_mapping(self, uids: list[str]|None = None, is_global=False): + + def update_user_mapping(self, uids: list[str]|None = None, is_global: bool = False): """ Update the mapping for users based on the provided UIDs. @@ -340,7 +339,7 @@ def update_user_mapping(self, uids: list[str]|None = None, is_global=False): except Exception as e: raise FwoImporterError(f"Error updating user mapping: {e}") - def update_zone_mapping(self, names: list[str]|None = None, is_global=False): + def update_zone_mapping(self, names: list[str]|None = None, is_global: bool = False): """ Update the mapping for zones based on the provided names.