|
@@ -8,7 +8,7 @@
 # pyre-strict
 
 import copy
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Type, Union
 
 import torch
 from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
@@ -43,6 +43,11 @@
     is_fused_param_register_tbe,
 )
 from torchrec.distributed.global_settings import get_propogate_device
+from torchrec.distributed.mc_modules import (
+    InferManagedCollisionCollectionSharder,
+    ShardedMCCRemapper,
+    ShardedQuantManagedCollisionCollection,
+)
 from torchrec.distributed.quant_state import ShardedQuantEmbeddingModuleState
 from torchrec.distributed.sharding.cw_sharding import InferCwPooledEmbeddingSharding
 from torchrec.distributed.sharding.rw_sharding import InferRwPooledEmbeddingSharding
@@ -54,7 +59,7 @@
     ShardingEnv,
     ShardingType,
 )
-from torchrec.distributed.utils import copy_to_device
+from torchrec.distributed.utils import append_prefix, copy_to_device
 from torchrec.modules.embedding_configs import (
     data_type_to_sparse_type,
     dtype_to_data_type,
@@ -67,8 +72,9 @@
     EmbeddingBagCollection as QuantEmbeddingBagCollection,
     FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection,
     MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
+    QuantManagedCollisionEmbeddingBagCollection,
 )
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
 
 
 def get_device_from_parameter_sharding(
@@ -722,3 +728,261 @@ def forward(self, features: KeyedJaggedTensor) -> ListOfKJTList: |
                 for i in range(len(self._input_dists))
             ]
         )
+
+
+class ShardedMCEBCLookup(torch.nn.Module):
+    """
+    This module implements distributed compute of a
+    ShardedQuantManagedCollisionEmbeddingBagCollection.
+
+    Args:
+        sharding (int): sharding index
+        rank (int): rank index
+        mcc_remapper (ShardedMCCRemapper): managed collision collection remapper
+        ebc_lookup (nn.Module): embedding bag collection lookup
+
+    Example::
+
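+        A minimal sketch, assuming ``mcc_remapper`` and ``ebc_lookup`` have
+        already been built for this sharding/rank pair (e.g. from
+        ``ShardedQuantManagedCollisionCollection.create_mcc_remappers()`` and
+        the sharded EBC's per-rank lookups), and ``features`` is the
+        rank-local ``KeyedJaggedTensor``; illustrative only, not the
+        definitive construction path::
+
+            lookup = ShardedMCEBCLookup(
+                sharding=0,
+                rank=0,
+                mcc_remapper=mcc_remapper,
+                ebc_lookup=ebc_lookup,
+            )
+            output = lookup(features)
+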
+    """
+
+    def __init__(
+        self,
+        sharding: int,
+        rank: int,
+        mcc_remapper: ShardedMCCRemapper,
+        ebc_lookup: nn.Module,
+    ) -> None:
+        super().__init__()
+        self._sharding = sharding
+        self._rank = rank
+        self._mcc_remapper = mcc_remapper
+        self._ebc_lookup = ebc_lookup
+
+    def forward(
+        self,
+        features: KeyedJaggedTensor,
+    ) -> Tuple[
+        Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
+    ]:
+        """
+        Applies managed collision collection remapping and performs embedding lookup.
+
+        Args:
+            features (KeyedJaggedTensor): input features
+
+        Returns:
+            Tuple[Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]]: embedding lookup result
+        """
+        remapped_kjt = self._mcc_remapper(features)
+        return self._ebc_lookup(remapped_kjt)
+
+
+class ShardedQuantManagedCollisionEmbeddingBagCollection(
+    ShardedQuantEmbeddingBagCollection
+):
+    def __init__(
+        self,
+        module: QuantManagedCollisionEmbeddingBagCollection,
+        table_name_to_parameter_sharding: Dict[str, ParameterSharding],
+        mc_sharder: InferManagedCollisionCollectionSharder,
+        # TODO - we may need this to keep unsharded/sharded state consistent
+        env: Union[ShardingEnv, Dict[str, ShardingEnv]],
+        fused_params: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__(
+            module, table_name_to_parameter_sharding, env, fused_params, device
+        )
+
+        self._device = device
+        self._env = env
+
+        # TODO: This is a hack. _embedding_module doesn't need an input dist,
+        # so we would like to eliminate it so that all fused a2a calls ignore
+        # it; however, we use the ec input_dist directly, so it cannot be
+        # removed yet.
+        # self._has_uninitialized_input_dist = False
+        embedding_shardings = list(
+            self._sharding_type_device_group_to_sharding.values()
+        )
+
+        self._managed_collision_collection: ShardedQuantManagedCollisionCollection = (
+            mc_sharder.shard(
+                module._managed_collision_collection,
+                table_name_to_parameter_sharding,
+                env=env,
+                device=device,
+                # pyre-ignore
+                embedding_shardings=embedding_shardings,
+            )
+        )
+        self._return_remapped_features: bool = module._return_remapped_features
+        self._create_mcebc_lookups()
+
+    def _create_mcebc_lookups(self) -> None:
+        # Pair each sharding's per-rank EBC lookups with its MCC remappers so
+        # that each rank remaps ids before running the embedding lookup.
+        mcebc_lookups: List[nn.ModuleList] = []
+        mcc_remappers: List[List[ShardedMCCRemapper]] = (
+            self._managed_collision_collection.create_mcc_remappers()
+        )
+        for sharding in range(
+            len(self._managed_collision_collection._embedding_shardings)
+        ):
+            ebc_sharding_lookups = self._lookups[sharding]
+            sharding_mcebc_lookups: List[ShardedMCEBCLookup] = []
+            for j, ebc_lookup in enumerate(
+                ebc_sharding_lookups._embedding_lookups_per_rank  # pyre-ignore
+            ):
+                sharding_mcebc_lookups.append(
+                    ShardedMCEBCLookup(
+                        sharding,
+                        j,
+                        mcc_remappers[sharding][j],
+                        ebc_lookup,
+                    )
+                )
+            mcebc_lookups.append(nn.ModuleList(sharding_mcebc_lookups))
+        self._mcebc_lookup: nn.ModuleList = nn.ModuleList(mcebc_lookups)
+
+    def input_dist(
+        self,
+        ctx: NullShardedModuleContext,
+        features: KeyedJaggedTensor,
+    ) -> ListOfKJTList:
+        # TODO: resolve incompatibility with different contexts
+        if self._has_uninitialized_output_dist:
+            self._create_output_dist(features.device())
+            self._has_uninitialized_output_dist = False
+
+        return self._managed_collision_collection.input_dist(
+            # pyre-fixme [6]
+            ctx,
+            features,
+            is_sequence_embedding=False,
+        )
+
+    def compute(
+        self,
+        ctx: NullShardedModuleContext,
+        dist_input: ListOfKJTList,
+    ) -> List[List[torch.Tensor]]:
+        # For each (sharding, rank) pair, remap ids and run the embedding lookup.
+        ret: List[List[torch.Tensor]] = []
+        for i in range(len(self._managed_collision_collection._embedding_shardings)):
+            dist_input_i = dist_input[i]
+            lookups = self._mcebc_lookup[i]
+            sharding_ret: List[torch.Tensor] = []
+            for j, lookup in enumerate(lookups):
+                rank_ret = lookup(
+                    features=dist_input_i[j],
+                )
+                sharding_ret.append(rank_ret)
+            ret.append(sharding_ret)
+        return ret
+
+    # pyre-ignore
+    def output_dist(
+        self,
+        ctx: NullShardedModuleContext,
+        output: List[List[torch.Tensor]],
+    ) -> Tuple[
+        Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
+    ]:
+        # pyre-ignore [6]
+        ebc_out = super().output_dist(ctx, output)
+
+        # Remapped features are currently not returned on the inference path.
+        kjt_out: Optional[KeyedJaggedTensor] = None
+
+        return ebc_out, kjt_out
+
+    def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
+        for fqn, _ in self.named_parameters():
+            yield append_prefix(prefix, fqn)
+        for fqn, _ in self.named_buffers():
+            yield append_prefix(prefix, fqn)
+
+
+class QuantManagedCollisionEmbeddingBagCollectionSharder(
+    BaseQuantEmbeddingSharder[QuantManagedCollisionEmbeddingBagCollection]
+):
+    """
+    Sharder for QuantManagedCollisionEmbeddingBagCollection.
+
+    This implementation uses the non-fused EmbeddingBagCollection and manages
+    both the embedding bag collection sharding and the managed collision
+    collection sharding.
+
+    Args:
+        e_sharder (QuantEmbeddingBagCollectionSharder): sharder for embedding bag collection
+        mc_sharder (InferManagedCollisionCollectionSharder): sharder for managed collision collection
+
+    Example::
+
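+        A minimal sketch, assuming an unsharded
+        ``QuantManagedCollisionEmbeddingBagCollection`` ``module`` and a
+        ``table_name_to_parameter_sharding`` plan produced elsewhere (both
+        hypothetical inputs, for illustration only)::
+
+            sharder = QuantManagedCollisionEmbeddingBagCollectionSharder(
+                QuantEmbeddingBagCollectionSharder(),
+                InferManagedCollisionCollectionSharder(),
+            )
+            sharded_module = sharder.shard(
+                module,
+                table_name_to_parameter_sharding,
+                env=ShardingEnv.from_local(world_size=1, rank=0),
+                device=torch.device("cpu"),
+            )
+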
+    """
+
+    def __init__(
+        self,
+        e_sharder: QuantEmbeddingBagCollectionSharder,
+        mc_sharder: InferManagedCollisionCollectionSharder,
+    ) -> None:
+        super().__init__()
+        self._e_sharder: QuantEmbeddingBagCollectionSharder = e_sharder
+        self._mc_sharder: InferManagedCollisionCollectionSharder = mc_sharder
+
+    def shardable_parameters(
+        self, module: QuantManagedCollisionEmbeddingBagCollection
+    ) -> Dict[str, torch.nn.Parameter]:
+        return self._e_sharder.shardable_parameters(module)
+
+    def compute_kernels(
+        self,
+        sharding_type: str,
+        compute_device_type: str,
+    ) -> List[str]:
+        return [
+            EmbeddingComputeKernel.QUANT.value,
+        ]
+
+    def sharding_types(self, compute_device_type: str) -> List[str]:
+        return list(
+            set.intersection(
+                set(self._e_sharder.sharding_types(compute_device_type)),
+                set(self._mc_sharder.sharding_types(compute_device_type)),
+            )
+        )
+
+    @property
+    def fused_params(self) -> Optional[Dict[str, Any]]:
+        # TODO: to be deprecated once the planner gets cache_load_factor from
+        # ParameterConstraints
+        return self._e_sharder.fused_params
+
+    def shard(
+        self,
+        module: QuantManagedCollisionEmbeddingBagCollection,
+        params: Dict[str, ParameterSharding],
+        env: Union[ShardingEnv, Dict[str, ShardingEnv]],
+        device: Optional[torch.device] = None,
+        module_fqn: Optional[str] = None,
+    ) -> ShardedQuantManagedCollisionEmbeddingBagCollection:
+        fused_params = self.fused_params if self.fused_params else {}
+        fused_params["output_dtype"] = data_type_to_sparse_type(
+            dtype_to_data_type(module.output_dtype())
+        )
+        if FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS not in fused_params:
+            fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr(
+                module,
+                MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
+                False,
+            )
+        if FUSED_PARAM_REGISTER_TBE_BOOL not in fused_params:
+            fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr(
+                module, FUSED_PARAM_REGISTER_TBE_BOOL, False
+            )
+        return ShardedQuantManagedCollisionEmbeddingBagCollection(
+            module,
+            params,
+            self._mc_sharder,
+            env,
+            fused_params,
+            device,
+        )
+
+    @property
+    def module_type(self) -> Type[QuantManagedCollisionEmbeddingBagCollection]:
+        return QuantManagedCollisionEmbeddingBagCollection