From 3051b6b5748ccc8b2466e58b0e4b5d126c48dec6 Mon Sep 17 00:00:00 2001 From: jinyoungmoonDEV Date: Mon, 2 Dec 2024 16:53:11 +0900 Subject: [PATCH] fix: fix vulnerable_ports method --- .../manager/ec2/security_group_manager.py | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/plugin/manager/ec2/security_group_manager.py b/src/plugin/manager/ec2/security_group_manager.py index b31e710..063fac1 100644 --- a/src/plugin/manager/ec2/security_group_manager.py +++ b/src/plugin/manager/ec2/security_group_manager.py @@ -65,7 +65,7 @@ def create_cloud_service(self, region, options, secret_data, schema): for _ip_range in in_rule.get("IpRanges", []): in_rule_copy = copy.deepcopy(in_rule) inbound_rules.append( - self.custom_security_group_rule_info( + self.custom_security_group_inbound_rule_info( in_rule_copy, _ip_range, "ip_ranges",vulnerable_ports ) ) @@ -73,7 +73,7 @@ def create_cloud_service(self, region, options, secret_data, schema): for _user_group_pairs in in_rule.get("UserIdGroupPairs", []): in_rule_copy = copy.deepcopy(in_rule) inbound_rules.append( - self.custom_security_group_rule_info( + self.custom_security_group_inbound_rule_info( in_rule_copy, _user_group_pairs, "user_id_group_pairs", @@ -84,7 +84,7 @@ def create_cloud_service(self, region, options, secret_data, schema): for _ip_v6_range in in_rule.get("Ipv6Ranges", []): in_rule_copy = copy.deepcopy(in_rule) inbound_rules.append( - self.custom_security_group_rule_info( + self.custom_security_group_inbound_rule_info( in_rule_copy, _ip_v6_range, "ipv6_ranges",vulnerable_ports ) ) @@ -96,7 +96,7 @@ def create_cloud_service(self, region, options, secret_data, schema): out_rule_copy = copy.deepcopy(out_rule) outbound_rules.append( self.custom_security_group_rule_info( - out_rule_copy, _ip_range, "ip_ranges",vulnerable_ports + out_rule_copy, _ip_range, "ip_ranges" ) ) @@ -106,7 +106,7 @@ def create_cloud_service(self, region, options, secret_data, schema): self.custom_security_group_rule_info( out_rule_copy, _user_group_pairs, - "user_id_group_pairs",vulnerable_ports, + "user_id_group_pairs", ) ) @@ -114,7 +114,7 @@ def create_cloud_service(self, region, options, secret_data, schema): out_rule_copy = copy.deepcopy(out_rule) outbound_rules.append( self.custom_security_group_rule_info( - out_rule_copy, _ip_v6_range, "ipv6_ranges",vulnerable_ports + out_rule_copy, _ip_v6_range, "ipv6_ranges" ) ) @@ -165,7 +165,16 @@ def create_cloud_service(self, region, options, secret_data, schema): region_name=region, ) - def custom_security_group_rule_info(self, raw_rule, remote, remote_type, vulnerable_ports): + def custom_security_group_inbound_rule_info(self, raw_rule, remote, remote_type, vulnerable_ports): + raw_rule = self.custom_security_group_rule_info(raw_rule, remote, remote_type) + + protocol_display = raw_rule.get("protocol_display") + + raw_rule.update({"vulnerable_ports": self._get_vulnerable_ports(protocol_display, raw_rule, vulnerable_ports)}) + + return raw_rule + + def custom_security_group_rule_info(self, raw_rule, remote, remote_type): protocol_display = self._get_protocol_display(raw_rule.get("IpProtocol")) raw_rule.update( { @@ -174,7 +183,6 @@ def custom_security_group_rule_info(self, raw_rule, remote, remote_type, vulnera "source_display": self._get_source_display(remote), "description_display": self._get_description_display(remote), remote_type: remote, - "vulnerable_ports": self._get_vulnerable_ports(protocol_display, raw_rule, vulnerable_ports) } ) @@ -296,8 +304,10 @@ def get_instance_name_from_tags(instance): @staticmethod def _get_vulnerable_ports(protocol_display: str, raw_rule: dict, vulnerable_ports: str): try: + ports = [int(port.strip()) for port in vulnerable_ports.split(',')] + if protocol_display == "ALL": - return [int(port.strip()) for port in vulnerable_ports.split(',')] + return ports to_port = raw_rule.get("ToPort") from_port = raw_rule.get("FromPort") @@ -305,10 +315,6 @@ def _get_vulnerable_ports(protocol_display: str, raw_rule: dict, vulnerable_port if to_port is None or from_port is None: return [] - return [ - int(port.strip()) - for port in vulnerable_ports.split(',') - if from_port <= int(port.strip()) <= to_port - ] + return [port for port in ports if from_port <= port <= to_port] except ValueError: raise ERROR_VULNERABLE_PORTS(vulnerable_ports)