
Commit 6556a0f

faran928 authored and facebook-github-bot committed
Adding MC EBC quant embedding modules for inference in TorchRec
Summary: We may need MC EBC bucket-aware sharding during inference for some models. This change adds a Quant Sharded MC EBC version of the module to support that.

Differential Revision: D82994116
1 parent 217889e commit 6556a0f
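
For context on how the pieces introduced below fit together, here is a minimal usage sketch (not part of this commit). The `quant_mc_ebc`, `per_table_sharding`, and `env` names are placeholders for a quantized QuantManagedCollisionEmbeddingBagCollection, its table-name-to-ParameterSharding plan, and a ShardingEnv; constructing both component sharders with default arguments is an assumption, not something this diff shows.

# Hypothetical usage sketch; placeholder inputs, default sharder construction assumed.
import torch

from torchrec.distributed.mc_modules import InferManagedCollisionCollectionSharder
from torchrec.distributed.quant_embeddingbag import (
    QuantEmbeddingBagCollectionSharder,
    QuantManagedCollisionEmbeddingBagCollectionSharder,
)

# Compose the new sharder from an EBC sharder and an inference MC collection sharder.
sharder = QuantManagedCollisionEmbeddingBagCollectionSharder(
    e_sharder=QuantEmbeddingBagCollectionSharder(),
    mc_sharder=InferManagedCollisionCollectionSharder(),
)

# quant_mc_ebc: QuantManagedCollisionEmbeddingBagCollection (placeholder)
# per_table_sharding: Dict[str, ParameterSharding] (placeholder)
# env: ShardingEnv for the inference topology (placeholder)
sharded_mc_ebc = sharder.shard(
    module=quant_mc_ebc,
    params=per_table_sharding,
    env=env,
    device=torch.device("cpu"),
)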

4 files changed: +448, -20 lines changed


torchrec/distributed/mc_modules.py

Lines changed: 16 additions & 13 deletions
@@ -60,6 +60,7 @@
 from torchrec.distributed.types import (
     Awaitable,
     LazyAwaitable,
+    NullShardedModuleContext,
     ParameterSharding,
     QuantizedCommCodecs,
     ShardedModule,
@@ -1292,8 +1293,9 @@ def _create_input_dists(
     # pyre-ignore
     def input_dist(
         self,
-        ctx: ManagedCollisionCollectionContext,
+        ctx: Union[ManagedCollisionCollectionContext, NullShardedModuleContext],
         features: KeyedJaggedTensor,
+        is_sequence_embedding: bool = True,
     ) -> ListOfKJTList:
         if self._has_uninitialized_input_dists:
             self._create_input_dists(
@@ -1345,19 +1347,20 @@ def input_dist(
         for feature_split, input_dist in zip(feature_splits, self._input_dists):
             out = input_dist(feature_split)
             input_dist_result_list.append(out.features)
-            ctx.sharding_contexts.append(
-                InferSequenceShardingContext(
-                    features=out.features,
-                    features_before_input_dist=features,
-                    unbucketize_permute_tensor=(
-                        out.unbucketize_permute_tensor
-                        if isinstance(input_dist, InferRwSparseFeaturesDist)
-                        else None
-                    ),
-                    bucket_mapping_tensor=out.bucket_mapping_tensor,
-                    bucketized_length=out.bucketized_length,
+            if is_sequence_embedding:
+                ctx.sharding_contexts.append(
+                    InferSequenceShardingContext(
+                        features=out.features,
+                        features_before_input_dist=features,
+                        unbucketize_permute_tensor=(
+                            out.unbucketize_permute_tensor
+                            if isinstance(input_dist, InferRwSparseFeaturesDist)
+                            else None
+                        ),
+                        bucket_mapping_tensor=out.bucket_mapping_tensor,
+                        bucketized_length=out.bucketized_length,
+                    )
                 )
-            )

         return ListOfKJTList(input_dist_result_list)
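
The new `is_sequence_embedding` flag lets a pooled (EBC) caller reuse this `input_dist` without recording `InferSequenceShardingContext` entries; the sharded quant MC EBC added in quant_embeddingbag.py below passes `is_sequence_embedding=False` for exactly that reason. A rough sketch of the two call shapes, where `sharded_mcc`, `seq_ctx`, `null_ctx`, and `features` are placeholders rather than objects built here:

# Sequence (EC) path: default, sharding contexts are appended onto the context.
dist_features = sharded_mcc.input_dist(seq_ctx, features)

# Pooled (EBC) path: NullShardedModuleContext is now accepted by the Union type,
# and the InferSequenceShardingContext bookkeeping is skipped.
dist_features = sharded_mcc.input_dist(
    null_ctx,
    features,
    is_sequence_embedding=False,
)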

torchrec/distributed/quant_embeddingbag.py

Lines changed: 267 additions & 3 deletions
@@ -8,7 +8,7 @@
 # pyre-strict

 import copy
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Type, Union

 import torch
 from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
@@ -43,6 +43,11 @@
     is_fused_param_register_tbe,
 )
 from torchrec.distributed.global_settings import get_propogate_device
+from torchrec.distributed.mc_modules import (
+    InferManagedCollisionCollectionSharder,
+    ShardedMCCRemapper,
+    ShardedQuantManagedCollisionCollection,
+)
 from torchrec.distributed.quant_state import ShardedQuantEmbeddingModuleState
 from torchrec.distributed.sharding.cw_sharding import InferCwPooledEmbeddingSharding
 from torchrec.distributed.sharding.rw_sharding import InferRwPooledEmbeddingSharding
@@ -54,7 +59,7 @@
     ShardingEnv,
     ShardingType,
 )
-from torchrec.distributed.utils import copy_to_device
+from torchrec.distributed.utils import append_prefix, copy_to_device
 from torchrec.modules.embedding_configs import (
     data_type_to_sparse_type,
     dtype_to_data_type,
@@ -67,8 +72,9 @@
     EmbeddingBagCollection as QuantEmbeddingBagCollection,
     FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection,
     MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
+    QuantManagedCollisionEmbeddingBagCollection,
 )
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor


 def get_device_from_parameter_sharding(
@@ -722,3 +728,261 @@ def forward(self, features: KeyedJaggedTensor) -> ListOfKJTList:
                 for i in range(len(self._input_dists))
             ]
         )
+
+
+class ShardedMCEBCLookup(torch.nn.Module):
+    """
+    This module implements distributed compute of a ShardedQuantManagedCollisionEmbeddingBagCollection.
+
+    Args:
+        sharding (int): sharding index
+        rank (int): rank index
+        mcc_remapper (ShardedMCCRemapper): managed collision collection remapper
+        ebc_lookup (nn.Module): embedding bag collection lookup
+
+    Example::
+
+    """
+
+    def __init__(
+        self,
+        sharding: int,
+        rank: int,
+        mcc_remapper: ShardedMCCRemapper,
+        ebc_lookup: nn.Module,
+    ) -> None:
+        super().__init__()
+        self._sharding = sharding
+        self._rank = rank
+        self._mcc_remapper = mcc_remapper
+        self._ebc_lookup = ebc_lookup
+
+    def forward(
+        self,
+        features: KeyedJaggedTensor,
+    ) -> Tuple[
+        Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
+    ]:
+        """
+        Applies managed collision collection remapping and performs embedding lookup.
+
+        Args:
+            features (KeyedJaggedTensor): input features
+
+        Returns:
+            Tuple[Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]]: embedding lookup result
+        """
+        remapped_kjt = self._mcc_remapper(features)
+        return self._ebc_lookup(remapped_kjt)
+
+
+class ShardedQuantManagedCollisionEmbeddingBagCollection(
+    ShardedQuantEmbeddingBagCollection
+):
+    def __init__(
+        self,
+        module: QuantManagedCollisionEmbeddingBagCollection,
+        table_name_to_parameter_sharding: Dict[str, ParameterSharding],
+        mc_sharder: InferManagedCollisionCollectionSharder,
+        # TODO - maybe we need this to manage unsharded/sharded consistency/state consistency
+        env: Union[ShardingEnv, Dict[str, ShardingEnv]],
+        fused_params: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__(
+            module, table_name_to_parameter_sharding, env, fused_params, device
+        )
+
+        self._device = device
+        self._env = env
+
+        # TODO: This is a hack since _embedding_module doesn't need input
+        # dist, so eliminating it so all fused a2a will ignore it.
+        # we're using ec input_dist directly, so this cannot be escaped.
+        # self._has_uninitialized_input_dist = False
+        embedding_shardings = list(
+            self._sharding_type_device_group_to_sharding.values()
+        )
+
+        self._managed_collision_collection: ShardedQuantManagedCollisionCollection = (
+            mc_sharder.shard(
+                module._managed_collision_collection,
+                table_name_to_parameter_sharding,
+                env=env,
+                device=device,
+                # pyre-ignore
+                embedding_shardings=embedding_shardings,
+            )
+        )
+        self._return_remapped_features: bool = module._return_remapped_features
+        self._create_mcebc_lookups()
+
+    def _create_mcebc_lookups(self) -> None:
+        mcebc_lookups: List[nn.ModuleList] = []
+        mcc_remappers: List[List[ShardedMCCRemapper]] = (
+            self._managed_collision_collection.create_mcc_remappers()
+        )
+        for sharding in range(
+            len(self._managed_collision_collection._embedding_shardings)
+        ):
+            ebc_sharding_lookups = self._lookups[sharding]
+            sharding_mcebc_lookups: List[ShardedMCEBCLookup] = []
+            for j, ec_lookup in enumerate(
+                ebc_sharding_lookups._embedding_lookups_per_rank  # pyre-ignore
+            ):
+                sharding_mcebc_lookups.append(
+                    ShardedMCEBCLookup(
+                        sharding,
+                        j,
+                        mcc_remappers[sharding][j],
+                        ec_lookup,
+                    )
+                )
+            mcebc_lookups.append(nn.ModuleList(sharding_mcebc_lookups))
+        self._mcebc_lookup: nn.ModuleList = nn.ModuleList(mcebc_lookups)
+
+    def input_dist(
+        self,
+        ctx: NullShardedModuleContext,
+        features: KeyedJaggedTensor,
+    ) -> ListOfKJTList:
+        # TODO: resolve incompatiblity with different contexts
+        if self._has_uninitialized_output_dist:
+            self._create_output_dist(features.device())
+            self._has_uninitialized_output_dist = False
+
+        return self._managed_collision_collection.input_dist(
+            # pyre-fixme [6]
+            ctx,
+            features,
+            is_sequence_embedding=False,
+        )
+
+    def compute(
+        self,
+        ctx: NullShardedModuleContext,
+        dist_input: ListOfKJTList,
+    ) -> List[List[torch.Tensor]]:
+        ret: List[List[torch.Tensor]] = []
+        for i in range(len(self._managed_collision_collection._embedding_shardings)):
+            dist_input_i = dist_input[i]
+            lookups = self._mcebc_lookup[i]
+            sharding_ret: List[torch.Tensor] = []
+            for j, lookup in enumerate(lookups):
+                rank_ret = lookup(
+                    features=dist_input_i[j],
+                )
+                sharding_ret.append(rank_ret)
+            ret.append(sharding_ret)
+        return ret
+
+    # pyre-ignore
+    def output_dist(
+        self,
+        ctx: NullShardedModuleContext,
+        output: List[List[torch.Tensor]],
+    ) -> Tuple[
+        Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
+    ]:
+
+        # pyre-ignore [6]
+        ebc_out = super().output_dist(ctx, output)
+
+        kjt_out: Optional[KeyedJaggedTensor] = None
+
+        return ebc_out, kjt_out
+
+    def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
+        for fqn, _ in self.named_parameters():
+            yield append_prefix(prefix, fqn)
+        for fqn, _ in self.named_buffers():
+            yield append_prefix(prefix, fqn)
+
+
+class QuantManagedCollisionEmbeddingBagCollectionSharder(
+    BaseQuantEmbeddingSharder[QuantManagedCollisionEmbeddingBagCollection]
+):
+    """
+    Sharder for QuantManagedCollisionEmbeddingBagCollection.
+
+    This implementation uses non-fused EmbeddingBagCollection and manages both
+    embedding bag collection sharding and managed collision collection sharding.
+
+    Args:
+        e_sharder (QuantEmbeddingBagCollectionSharder): sharder for embedding bag collection
+        mc_sharder (InferManagedCollisionCollectionSharder): sharder for managed collision collection
+
+    Example::

+    """
+
+    def __init__(
+        self,
+        e_sharder: QuantEmbeddingBagCollectionSharder,
+        mc_sharder: InferManagedCollisionCollectionSharder,
+    ) -> None:
+        super().__init__()
+        self._e_sharder: QuantEmbeddingBagCollectionSharder = e_sharder
+        self._mc_sharder: InferManagedCollisionCollectionSharder = mc_sharder
+
+    def shardable_parameters(
+        self, module: QuantManagedCollisionEmbeddingBagCollection
+    ) -> Dict[str, torch.nn.Parameter]:
+        return self._e_sharder.shardable_parameters(module)
+
+    def compute_kernels(
+        self,
+        sharding_type: str,
+        compute_device_type: str,
+    ) -> List[str]:
+        return [
+            EmbeddingComputeKernel.QUANT.value,
+        ]
+
+    def sharding_types(self, compute_device_type: str) -> List[str]:
+        return list(
+            set.intersection(
+                set(self._e_sharder.sharding_types(compute_device_type)),
+                set(self._mc_sharder.sharding_types(compute_device_type)),
+            )
+        )
+
+    @property
+    def fused_params(self) -> Optional[Dict[str, Any]]:
+        # TODO: to be deprecate after planner get cache_load_factor from ParameterConstraints
+        return self._e_sharder.fused_params
+
+    def shard(
+        self,
+        module: QuantManagedCollisionEmbeddingBagCollection,
+        params: Dict[str, ParameterSharding],
+        env: Union[ShardingEnv, Dict[str, ShardingEnv]],
+        device: Optional[torch.device] = None,
+        module_fqn: Optional[str] = None,
+    ) -> ShardedQuantManagedCollisionEmbeddingBagCollection:
+        fused_params = self.fused_params if self.fused_params else {}
+        fused_params["output_dtype"] = data_type_to_sparse_type(
+            dtype_to_data_type(module.output_dtype())
+        )
+        if FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS not in fused_params:
+            fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr(
+                module,
+                MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
+                False,
+            )
+        if FUSED_PARAM_REGISTER_TBE_BOOL not in fused_params:
+            fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr(
+                module, FUSED_PARAM_REGISTER_TBE_BOOL, False
+            )
+        return ShardedQuantManagedCollisionEmbeddingBagCollection(
+            module,
+            params,
+            self._mc_sharder,
+            env,
+            fused_params,
+            device,
+        )
+
+    @property
+    def module_type(self) -> Type[QuantManagedCollisionEmbeddingBagCollection]:
+        return QuantManagedCollisionEmbeddingBagCollection
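
The new docstrings leave their `Example::` sections empty; as a self-contained illustration of the remap-then-lookup composition that `ShardedMCEBCLookup.forward` expresses, the toy below mirrors the structure with plain torch modules (the `ToyRemapper`/`ToyMCEBCLookup` names and the modulo remap are inventions for illustration, not TorchRec APIs):

import torch
import torch.nn as nn


class ToyRemapper(nn.Module):
    """Stand-in for ShardedMCCRemapper: maps raw ids into a bounded local range."""

    def __init__(self, num_embeddings: int) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # Trivial collision handling: fold ids into [0, num_embeddings).
        return ids % self.num_embeddings


class ToyMCEBCLookup(nn.Module):
    """Mirrors ShardedMCEBCLookup.forward: remap first, then embedding-bag lookup."""

    def __init__(self, remapper: ToyRemapper, lookup: nn.EmbeddingBag) -> None:
        super().__init__()
        self._remapper = remapper
        self._lookup = lookup

    def forward(self, ids: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        remapped = self._remapper(ids)
        return self._lookup(remapped, offsets)


toy = ToyMCEBCLookup(ToyRemapper(num_embeddings=16), nn.EmbeddingBag(16, 4))
ids = torch.tensor([3, 103, 7, 2003])  # raw ids, possibly out of table range
offsets = torch.tensor([0, 2])         # two pooled bags
print(toy(ids, offsets).shape)         # torch.Size([2, 4])

The real module additionally spreads these per-rank lookups across shardings, which is what `_create_mcebc_lookups` builds and `compute` iterates over.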

torchrec/distributed/quant_state.py

Lines changed: 6 additions & 3 deletions
@@ -470,11 +470,14 @@ def sharded_tbes_weights_spec(
     is_sfpebc: bool = (
         "ShardedQuantFeatureProcessedEmbeddingBagCollection" in type_name
     )
+    is_sqmcebc: bool = (
+        "ShardedQuantManagedCollisionEmbeddingBagCollection" in type_name
+    )

-    if is_sqebc or is_sqec or is_sqmcec or is_sfpebc:
+    if is_sqebc or is_sqec or is_sqmcec or is_sqebc or is_sqmcebc:
         assert (
-            is_sqec + is_sqebc + is_sqmcec + is_sfpebc == 1
-        ), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection, ShardedQuantManagedCollisionEmbeddingCollection and ShardedQuantFeatureProcessedEmbeddingBagCollection are true"
+            is_sqec + is_sqebc + is_sqmcec + is_sfpebc + is_sqmcebc == 1
+        ), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection, ShardedQuantManagedCollisionEmbeddingCollection, ShardedQuantFeatureProcessedEmbeddingBagCollection and ShardedQuantManagedCollisionEmbeddingBagCollection are true"
         tbes_configs: Dict[
             IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig
         ] = module.tbes_configs()
