@@ -1376,84 +1376,114 @@ def test_resource_combination_name(
13761376 assert_that (combination_name ).is_equal_to (expected_combination_name )
13771377
13781378
1379- def test_storage_security_group_deduplication (mocker , test_datadir ):
1380- """
1381- Test that storage security group rules are deduplicated when head, compute, and login nodes share the same SG.
1382-
1383- When head node, compute nodes, and login nodes all use the same security group (sg-12345678),
1384- only one set of ingress/egress rules should be created for that security group, not three separate sets.
1385- """
1386- mock_aws_api (mocker )
1387- mock_bucket (mocker )
1388- mock_bucket_object_utils (mocker )
1389-
1390- input_yaml = load_yaml_dict (test_datadir / "config-shared-sg.yaml" )
1391- cluster_config = ClusterSchema (cluster_name = "clustername" ).load (input_yaml )
1392-
1393- generated_template , _ = CDKTemplateBuilder ().build_cluster_template (
1394- cluster_config = cluster_config , bucket = dummy_cluster_bucket (), stack_name = "clustername"
1395- )
1396-
1397- # The EFS storage security group must have 2 ingress rules:
1398- # * allow traffic from cluster nodes (port 2049)
1399- # * allow traffic from storage nodes (all traffic)
1400- efs_sg_ingress_rules = [
1401- (name , resource )
1402- for name , resource in generated_template ["Resources" ].items ()
1403- if resource ["Type" ] == "AWS::EC2::SecurityGroupIngress" and name .startswith ("EFS" ) and "SecurityGroup" in name
1404- ]
1405- assert_that (len (efs_sg_ingress_rules )).is_equal_to (2 )
1406-
1407- # The FSx Lustre storage security group must have 3 rules:
1408- # * allow traffic from cluster nodes (port 2049)
1409- # * allow traffic from cluster nodes (ports 1018-1023)
1410- # * allow traffic from storage nodes (all traffic)
1411- fsx_sg_ingress_rules = [
1412- (name , resource )
1413- for name , resource in generated_template ["Resources" ].items ()
1414- if resource ["Type" ] == "AWS::EC2::SecurityGroupIngress" and name .startswith ("FSX" ) and "SecurityGroup" in name
1415- ]
1416- assert_that (len (fsx_sg_ingress_rules )).is_equal_to (3 )
1417-
1418- # Verify each storage type has the expected unique source security groups
1419- for _storage_type , rules in [("EFS" , efs_sg_ingress_rules ), ("FSX" , fsx_sg_ingress_rules )]:
1420- source_sgs = {
1421- str (rule ["Properties" ].get ("SourceSecurityGroupId" ))
1422- for _ , rule in rules
1423- if rule ["Properties" ].get ("SourceSecurityGroupId" )
1424- }
1425- # Should have 2 unique source SGs (shared SG + storage SG)
1426- assert_that (len (source_sgs )).is_equal_to (2 )
1427-
1428-
1429- def test_storage_security_group_port_restrictions (mocker , test_datadir ):
1379+ @pytest .mark .parametrize (
1380+ "head_sg, q1_sg, q2_sg, login1_sg, login2_sg, expected_efs_ingress_rules, expected_fsx_ingress_rules" ,
1381+ [
1382+ # Case 1: All nodes use the default (managed) security group
1383+ # EFS: rule(2049) * (head, compute, login1, login2) + rule(all) * (storage) = 5 rules
1384+ # FSx: rule(988,1018-1023) * (head, compute, login1, login2) + rule(all) * (storage) = 9 rules
1385+ (None , None , None , None , None , 5 , 9 ),
1386+ # Case 2: All nodes use the same custom security group (deduplication)
1387+ # EFS: rule(2049) * (customSG) + rule(all) * (storage) = 2 rules
1388+ # FSx: rule(988,1018-1023) * (customSG) + rule(all) * (storage) = 3 rules
1389+ (
1390+ "sg-1234567891234567a" ,
1391+ "sg-1234567891234567a" ,
1392+ "sg-1234567891234567a" ,
1393+ "sg-1234567891234567a" ,
1394+ "sg-1234567891234567a" ,
1395+ 2 ,
1396+ 3 ,
1397+ ),
1398+ # Case 3: All nodes use different custom security groups
1399+ # EFS: rule(2049) * (head_sg, q1_sg, q2_sg, login1_sg, login2_sg) + rule(all) * (storage) = 6 rules
1400+ # FSx: rule(988,1018-1023) * (head_sg, q1_sg, q2_sg, login1_sg, login2_sg) + rule(all) * (storage) = 11 rules
1401+ (
1402+ "sg-1234567891234567a" ,
1403+ "sg-1234567891234567b" ,
1404+ "sg-1234567891234567c" ,
1405+ "sg-1234567891234567d" ,
1406+ "sg-1234567891234567e" ,
1407+ 6 ,
1408+ 11 ,
1409+ ),
1410+ ],
1411+ ids = [
1412+ "all_default_sg" ,
1413+ "all_same_custom_sg" ,
1414+ "all_different_custom_sg" ,
1415+ ],
1416+ )
1417+ def test_storage_security_group_port_restrictions (
1418+ mocker ,
1419+ test_datadir ,
1420+ pcluster_config_reader ,
1421+ head_sg ,
1422+ q1_sg ,
1423+ q2_sg ,
1424+ login1_sg ,
1425+ login2_sg ,
1426+ expected_efs_ingress_rules ,
1427+ expected_fsx_ingress_rules ,
1428+ ):
14301429 """
14311430 Test that storage security group rules use restricted ports for head/compute/login nodes.
14321431
14331432 Security group rules should follow these principles:
14341433 1. Storage-to-Storage: Allow all traffic (protocol -1)
14351434 2. Head/Compute/Login nodes to EFS: Allow only TCP port 2049
14361435 3. Head/Compute/Login nodes to FSx Lustre: Allow only TCP ports 988 and 1018-1023
1436+ 4. Rules are deduplicated when nodes share the same security group
14371437 """
14381438 mock_aws_api (mocker )
14391439 mock_bucket (mocker )
14401440 mock_bucket_object_utils (mocker )
14411441
1442- input_yaml = load_yaml_dict (test_datadir / "config-shared-sg.yaml" )
1443- cluster_config = ClusterSchema (cluster_name = "clustername" ).load (input_yaml )
1442+ rendered_config_path = pcluster_config_reader (
1443+ "config.yaml" ,
1444+ head_sg = head_sg ,
1445+ q1_sg = q1_sg ,
1446+ q2_sg = q2_sg ,
1447+ login1_sg = login1_sg ,
1448+ login2_sg = login2_sg ,
1449+ )
1450+
1451+ rendered_config = load_yaml_dict (rendered_config_path )
1452+ cluster_config = ClusterSchema (cluster_name = "clustername" ).load (rendered_config )
14441453
14451454 generated_template , _ = CDKTemplateBuilder ().build_cluster_template (
14461455 cluster_config = cluster_config , bucket = dummy_cluster_bucket (), stack_name = "clustername"
14471456 )
14481457
1458+ # Build expected source security groups based on config
1459+ # Custom SG is used directly as string, managed SG is referenced via {"Ref": "SgName"}
1460+ expected_source = set ()
1461+ expected_source .add (head_sg ) if head_sg else expected_source .add (("Ref" , "HeadNodeSecurityGroup" ))
1462+ expected_source .add (q1_sg ) if q1_sg else expected_source .add (("Ref" , "ComputeSecurityGroup" ))
1463+ expected_source .add (q2_sg ) if q2_sg else expected_source .add (("Ref" , "ComputeSecurityGroup" ))
1464+ expected_source .add (login1_sg ) if login1_sg else expected_source .add (("Ref" , "login1LoginNodesSecurityGroup" ))
1465+ expected_source .add (login2_sg ) if login2_sg else expected_source .add (("Ref" , "login2LoginNodesSecurityGroup" ))
1466+
1467+ def _normalize_source_sg (source_sg ):
1468+ """Convert source SG to a comparable format."""
1469+ if isinstance (source_sg , str ):
1470+ return source_sg
1471+ elif isinstance (source_sg , dict ) and "Ref" in source_sg :
1472+ return ("Ref" , source_sg ["Ref" ])
1473+ return source_sg
1474+
14491475 # Test EFS storage - should only allow port 2049
1450- efs_ingress_rules = [
1451- (name , resource )
1452- for name , resource in generated_template ["Resources" ].items ()
1453- if resource ["Type" ] == "AWS::EC2::SecurityGroupIngress" and name .startswith ("EFS" ) and "SecurityGroup" in name
1454- ]
1476+ efs_ingress_rules = get_resources (
1477+ generated_template ,
1478+ type = "AWS::EC2::SecurityGroupIngress" ,
1479+ name_regex = r"^EFS.*SecurityGroup" ,
1480+ )
1481+
1482+ # Verify rule count (deduplication check)
1483+ assert_that (len (efs_ingress_rules )).is_equal_to (expected_efs_ingress_rules )
14551484
1456- for name , rule in efs_ingress_rules :
1485+ efs_source_sgs = set ()
1486+ for name , rule in efs_ingress_rules .items ():
14571487 props = rule ["Properties" ]
14581488 if "Storage" in name :
14591489 # Storage-to-Storage: all traffic allowed
@@ -1465,31 +1495,47 @@ def test_storage_security_group_port_restrictions(mocker, test_datadir):
14651495 assert_that (props ["IpProtocol" ]).is_equal_to ("tcp" )
14661496 assert_that (props ["FromPort" ]).is_equal_to (2049 )
14671497 assert_that (props ["ToPort" ]).is_equal_to (2049 )
1498+ # Collect source security group
1499+ source_sg = _normalize_source_sg (props .get ("SourceSecurityGroupId" ))
1500+ efs_source_sgs .add (source_sg )
1501+
1502+ # Verify source SGs match expected
1503+ assert_that (efs_source_sgs ).is_equal_to (expected_source )
14681504
14691505 # Test FSx Lustre storage - should only allow ports 988 and 1018-1023
1470- fsx_ingress_rules = [
1471- (name , resource )
1472- for name , resource in generated_template ["Resources" ].items ()
1473- if resource ["Type" ] == "AWS::EC2::SecurityGroupIngress" and name .startswith ("FSX" ) and "SecurityGroup" in name
1474- ]
1506+ fsx_ingress_rules = get_resources (
1507+ generated_template ,
1508+ type = "AWS::EC2::SecurityGroupIngress" ,
1509+ name_regex = r"^FSX.*SecurityGroup" ,
1510+ )
1511+
1512+ # Verify rule count (deduplication check)
1513+ assert_that (len (fsx_ingress_rules )).is_equal_to (expected_fsx_ingress_rules )
14751514
14761515 # Collect non-storage rules to verify FSx ports
1477- fsx_node_rules = [( name , rule ) for name , rule in fsx_ingress_rules if "Storage" not in name ]
1478- fsx_storage_rules = [( name , rule ) for name , rule in fsx_ingress_rules if "Storage" in name ]
1516+ fsx_node_rules = { name : rule for name , rule in fsx_ingress_rules . items () if "Storage" not in name }
1517+ fsx_storage_rules = { name : rule for name , rule in fsx_ingress_rules . items () if "Storage" in name }
14791518
14801519 # Verify Storage-to-Storage rule allows all traffic
1481- for _name , rule in fsx_storage_rules :
1520+ for _name , rule in fsx_storage_rules . items () :
14821521 props = rule ["Properties" ]
14831522 assert_that (props ["IpProtocol" ]).is_equal_to ("-1" )
14841523 assert_that (props ["FromPort" ]).is_equal_to (0 )
14851524 assert_that (props ["ToPort" ]).is_equal_to (65535 )
14861525
14871526 # Verify Head/Compute/Login rules use TCP and FSx Lustre ports (988, 1018-1023)
14881527 fsx_ports_found = set ()
1489- for _name , rule in fsx_node_rules :
1528+ fsx_source_sgs = set ()
1529+ for _name , rule in fsx_node_rules .items ():
14901530 props = rule ["Properties" ]
14911531 assert_that (props ["IpProtocol" ]).is_equal_to ("tcp" )
14921532 fsx_ports_found .add ((props ["FromPort" ], props ["ToPort" ]))
1533+ # Collect source security group
1534+ source_sg = _normalize_source_sg (props .get ("SourceSecurityGroupId" ))
1535+ fsx_source_sgs .add (source_sg )
14931536
14941537 # Should have rules for port 988 and port range 1018-1023
14951538 assert_that (fsx_ports_found ).contains ((988 , 988 ), (1018 , 1023 ))
1539+
1540+ # Verify source SGs match expected
1541+ assert_that (fsx_source_sgs ).is_equal_to (expected_source )
0 commit comments