@@ -887,6 +887,241 @@ def test_column_wise(self, data_type: DataType) -> None:
         }
         self.assertDictEqual(expected, module_sharding_plan)
 
+    # pyre-fixme[56]
+    @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
+    @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
+    def test_column_wise_size_per_rank(self, data_type: DataType) -> None:
+        """Test column_wise sharding with custom size_per_rank parameter."""
+
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=100,  # Total columns that will be split as [30, 40, 30]
+                num_embeddings=1024,
+                data_type=data_type,
+            )
+        ]
+
+        # Test uneven column distribution: rank 0 gets 30 cols, rank 1 gets 40 cols, rank 2 gets 30 cols
+        module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingBagCollection(tables=embedding_bag_config),
+            per_param_sharding={
+                "table_0": column_wise(size_per_rank=[30, 40, 30]),
+            },
+            local_size=3,
+            world_size=3,
+            device_type="cuda",
+        )
+
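+        # The expected shard offsets are the running sum of size_per_rank (columns start
+        # at 0, 30, 70), and each shard spans all 1024 rows.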
+        expected = {
+            "table_0": ParameterSharding(
+                sharding_type="column_wise",
+                compute_kernel="dense",
+                ranks=[0, 1, 2],
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[1024, 30],
+                            placement="rank:0/cuda:0",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 30],
+                            shard_sizes=[1024, 40],
+                            placement="rank:1/cuda:1",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 70],
+                            shard_sizes=[1024, 30],
+                            placement="rank:2/cuda:2",
+                        ),
+                    ]
+                ),
+            ),
+        }
+        self.assertDictEqual(expected, module_sharding_plan)
+
+    # pyre-fixme[56]
+    @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
+    @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
+    def test_column_wise_device_types(self, data_type: DataType) -> None:
+        """Test column_wise sharding with mixed device types."""
+
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=64,
+                num_embeddings=1024,
+                data_type=data_type,
+            )
+        ]
+
+        # Test mixed device types: cpu, cuda, cpu, cuda
+        module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingBagCollection(tables=embedding_bag_config),
+            per_param_sharding={
+                "table_0": column_wise(
+                    ranks=[0, 1, 2, 3],
+                    device_types=["cpu", "cuda", "cpu", "cuda"],
+                ),
+            },
+            local_size=4,
+            world_size=4,
+            device_type="cuda",
+        )
+
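+        # 64 columns over 4 ranks gives 16 columns per shard; placements in the expected plan
+        # follow device_types: CPU shards are placed as "rank:0/cpu", while CUDA shards are
+        # numbered per CUDA rank ("rank:0/cuda:0", "rank:1/cuda:1").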
+        expected = {
+            "table_0": ParameterSharding(
+                sharding_type="column_wise",
+                compute_kernel="dense",
+                ranks=[0, 1, 2, 3],
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[1024, 16],
+                            placement="rank:0/cpu",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 16],
+                            shard_sizes=[1024, 16],
+                            placement="rank:0/cuda:0",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 32],
+                            shard_sizes=[1024, 16],
+                            placement="rank:0/cpu",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 48],
+                            shard_sizes=[1024, 16],
+                            placement="rank:1/cuda:1",
+                        ),
+                    ]
+                ),
+            ),
+        }
+        self.assertDictEqual(expected, module_sharding_plan)
+
+    def test_column_wise_size_per_rank_insufficient_columns(self) -> None:
+        """Test that column_wise raises error when size_per_rank doesn't cover all columns."""
+
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=100,
+                num_embeddings=1024,
+                data_type=DataType.FP32,
+            )
+        ]
+
+        with self.assertRaises(ValueError) as context:
+            construct_module_sharding_plan(
+                EmbeddingBagCollection(tables=embedding_bag_config),
+                per_param_sharding={
+                    "table_0": column_wise(
+                        size_per_rank=[30, 40]
+                    ),  # Only covers 70/100 columns
+                },
+                local_size=2,
+                world_size=2,
+                device_type="cuda",
+            )
+
+        self.assertIn(
+            "Cannot fit tensor of (1024, 100) into sizes_ranks_placements = [30, 40]",
+            str(context.exception),
+        )
+
+    def test_column_wise_size_per_rank_with_device_types(self) -> None:
+        """Test column_wise sharding with both size_per_rank and device_types parameters."""
+
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=80,  # Total columns that will be split as [20, 30, 30]
+                num_embeddings=512,
+                data_type=DataType.FP32,
+            )
+        ]
+
+        # Test combining custom column sizes with mixed device types
+        module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingBagCollection(tables=embedding_bag_config),
+            per_param_sharding={
+                "table_0": column_wise(
+                    size_per_rank=[20, 30, 30],
+                    device_types=["cpu", "cuda", "cpu"],
+                ),
+            },
+            local_size=3,
+            world_size=3,
+            device_type="cuda",
+        )
+
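+        # Column offsets again follow the running sum of size_per_rank (0, 20, 50), and
+        # placements alternate per device_types in the expected plan below.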
+        expected = {
+            "table_0": ParameterSharding(
+                sharding_type="column_wise",
+                compute_kernel="dense",
+                ranks=[0, 1, 2],
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[512, 20],
+                            placement="rank:0/cpu",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 20],
+                            shard_sizes=[512, 30],
+                            placement="rank:0/cuda:0",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 50],
+                            shard_sizes=[512, 30],
+                            placement="rank:0/cpu",
+                        ),
+                    ]
+                ),
+            ),
+        }
+        self.assertDictEqual(expected, module_sharding_plan)
+
+    def test_column_wise_uneven_division_error(self) -> None:
+        """Test that column_wise raises error when columns can't be evenly divided across ranks."""
+
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=65,  # Cannot be evenly divided by 2
+                num_embeddings=1024,
+                data_type=DataType.FP32,
+            )
+        ]
+
+        with self.assertRaises(ValueError) as context:
+            construct_module_sharding_plan(
+                EmbeddingBagCollection(tables=embedding_bag_config),
+                per_param_sharding={
+                    "table_0": column_wise(
+                        ranks=[0, 1]
+                    ),  # 65 columns cannot be evenly divided by 2 ranks
+                },
+                local_size=2,
+                world_size=2,
+                device_type="cuda",
+            )
+
+        self.assertIn(
+            "column dim of 65 cannot be evenly divided across [0, 1]",
+            str(context.exception),
+        )
+
 
 class ShardingPlanTest(unittest.TestCase):
     def test_str(self) -> None:
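
For reference, the shard offsets asserted in the tests above follow from a running sum over size_per_rank. A minimal sketch of that arithmetic (expected_col_shards is a hypothetical illustration helper, not a TorchRec API):

    from typing import List, Tuple

    def expected_col_shards(
        num_rows: int, size_per_rank: List[int]
    ) -> List[Tuple[List[int], List[int]]]:
        # Each entry is (shard_offsets, shard_sizes): a shard spans all rows and its
        # column offset is the running sum of the preceding sizes.
        shards = []
        offset = 0
        for size in size_per_rank:
            shards.append(([0, offset], [num_rows, size]))
            offset += size
        return shards

    # expected_col_shards(1024, [30, 40, 30])
    # -> [([0, 0], [1024, 30]), ([0, 30], [1024, 40]), ([0, 70], [1024, 30])]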