diff --git a/torchrec/distributed/fp_embeddingbag.py b/torchrec/distributed/fp_embeddingbag.py
index 4b069437f..3d7fd4140 100644
--- a/torchrec/distributed/fp_embeddingbag.py
+++ b/torchrec/distributed/fp_embeddingbag.py
@@ -8,7 +8,18 @@
 # pyre-strict
 
 from functools import partial
-from typing import Any, Dict, Iterator, List, Optional, Type, Union
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import torch
 from torch import nn
@@ -31,7 +42,11 @@
     ShardingEnv,
     ShardingType,
 )
-from torchrec.distributed.utils import append_prefix, init_parameters
+from torchrec.distributed.utils import (
+    append_prefix,
+    init_parameters,
+    modify_input_for_feature_processor,
+)
 from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
 from torchrec.modules.fp_embedding_modules import (
     apply_feature_processors_to_kjt,
@@ -39,6 +54,8 @@
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
 
+_T = TypeVar("_T")
+
 
 def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor:
     kt._values.add_(no_op_tensor)
@@ -74,6 +91,16 @@ def __init__(
             )
         )
 
+        self._row_wise_sharded: bool = False
+        for param_sharding in table_name_to_parameter_sharding.values():
+            if param_sharding.sharding_type in [
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.GRID_SHARD.value,
+            ]:
+                self._row_wise_sharded = True
+                break
+
         self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups
 
         self._is_collection: bool = False
@@ -96,6 +123,11 @@ def __init__(
     def input_dist(
         self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
+        if not self.is_pipelined and self._row_wise_sharded:
+            # transform input to support row-based sharding when not pipelined
+            modify_input_for_feature_processor(
+                features, self._feature_processors, self._is_collection
+            )
         return self._embedding_bag_collection.input_dist(ctx, features)
 
     def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
@@ -105,10 +137,7 @@ def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
                 kjt_list.append(self._feature_processors(features))
             else:
                 kjt_list.append(
-                    apply_feature_processors_to_kjt(
-                        features,
-                        self._feature_processors,
-                    )
+                    apply_feature_processors_to_kjt(features, self._feature_processors)
                 )
         return KJTList(kjt_list)
 
@@ -117,7 +146,6 @@ def compute(
         ctx: EmbeddingBagCollectionContext,
         dist_input: KJTList,
     ) -> List[torch.Tensor]:
-
         fp_features = self.apply_feature_processors_to_kjt_list(dist_input)
         return self._embedding_bag_collection.compute(ctx, fp_features)
 
@@ -166,6 +194,18 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
     def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         self._embedding_bag_collection._initialize_torch_state(skip_registering)
 
+    def preprocess_input(
+        self, args: List[_T], kwargs: Mapping[str, _T]
+    ) -> Tuple[List[_T], Mapping[str, _T]]:
+        for x in args + list(kwargs.values()):
+            if isinstance(x, KeyedJaggedTensor):
+                modify_input_for_feature_processor(
+                    features=x,
+                    feature_processors=self._feature_processors,
+                    is_collection=self._is_collection,
+                )
+        return args, kwargs
+
 
 class FeatureProcessedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]
@@ -191,7 +231,6 @@ def shard(
         device: Optional[torch.device] = None,
         module_fqn: Optional[str] = None,
     ) -> ShardedFeatureProcessedEmbeddingBagCollection:
-
        if device is None:
            device = torch.device("cuda")
 
@@ -228,12 +267,14 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
         if compute_device_type in {"mtia"}:
             return [ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value]
 
-        # No row wise because position weighted FP and RW don't play well together.
         types = [
             ShardingType.DATA_PARALLEL.value,
             ShardingType.TABLE_WISE.value,
             ShardingType.COLUMN_WISE.value,
             ShardingType.TABLE_COLUMN_WISE.value,
+            ShardingType.TABLE_ROW_WISE.value,
+            ShardingType.ROW_WISE.value,
+            ShardingType.GRID_SHARD.value,
         ]
 
         return types
diff --git a/torchrec/distributed/tests/test_fp_embeddingbag.py b/torchrec/distributed/tests/test_fp_embeddingbag.py
index 130776919..08f5dfdbb 100644
--- a/torchrec/distributed/tests/test_fp_embeddingbag.py
+++ b/torchrec/distributed/tests/test_fp_embeddingbag.py
@@ -231,7 +231,6 @@ class ShardedEmbeddingBagCollectionParallelTest(MultiProcessTestBase):
     def test_sharding_ebc(
         self, set_gradient_division: bool, use_dmp: bool, use_fp_collection: bool
     ) -> None:
-
         import hypothesis
 
         # don't need to test entire matrix
diff --git a/torchrec/distributed/tests/test_fp_embeddingbag_utils.py b/torchrec/distributed/tests/test_fp_embeddingbag_utils.py
index 8efacdbb8..f7027b198 100644
--- a/torchrec/distributed/tests/test_fp_embeddingbag_utils.py
+++ b/torchrec/distributed/tests/test_fp_embeddingbag_utils.py
@@ -86,7 +86,12 @@ def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]:
         pred = torch.cat(
             [
                 fp_ebc_out[key]
-                for key in ["feature_0", "feature_1", "feature_2", "feature_3"]
+                for key in [
+                    "feature_0",
+                    "feature_1",
+                    "feature_2",
+                    "feature_3",
+                ]
             ],
             dim=1,
         )
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index a0ea00132..4728e831e 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -22,7 +22,10 @@
 from torch._dynamo.testing import reduce_to_scalar_loss
 from torch._dynamo.utils import counters
 from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.embedding_types import (
+    EmbeddingComputeKernel,
+    EmbeddingTableConfig,
+)
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
 from torchrec.distributed.fp_embeddingbag import (
     FeatureProcessedEmbeddingBagCollectionSharder,
@@ -31,8 +34,13 @@
 from torchrec.distributed.model_parallel import DMPCollection
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
+    row_wise,
     table_wise,
 )
+from torchrec.distributed.test_utils.multi_process import (
+    MultiProcessContext,
+    MultiProcessTestBase,
+)
 from torchrec.distributed.test_utils.test_model import (
     ModelInput,
     TestEBCSharder,
@@ -331,6 +339,161 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
         torch.testing.assert_close(pred_gpu.cpu(), pred)
 
 
+def fp_ebc(
+    rank: int,
+    world_size: int,
+    tables: List[EmbeddingTableConfig],
+    weighted_tables: List[EmbeddingTableConfig],
+    data: List[Tuple[ModelInput, List[ModelInput]]],
+    backend: str = "nccl",
+    local_size: Optional[int] = None,
+) -> None:
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        assert ctx.pg is not None
+        sharder = cast(
+            ModuleSharder[nn.Module],
+            FeatureProcessedEmbeddingBagCollectionSharder(),
+        )
+
+        class DummyWrapper(nn.Module):
+            def __init__(self, sparse_arch):
+                super().__init__()
+                self.m = sparse_arch
+
+            def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
+                return self.m(model_input.idlist_features)
+
+        max_feature_lengths = [10, 10, 12, 12]
+        sparse_arch = DummyWrapper(
+            create_module_and_freeze(
+                tables=tables,  # pyre-ignore[6]
+                device=ctx.device,
+                use_fp_collection=False,
+                max_feature_lengths=max_feature_lengths,
+            )
+        )
+
+        # compute_kernel = EmbeddingComputeKernel.FUSED.value
+        module_sharding_plan = construct_module_sharding_plan(
+            sparse_arch.m._fp_ebc,
+            per_param_sharding={
+                "table_0": row_wise(),
+                "table_1": row_wise(),
+                "table_2": row_wise(),
+                "table_3": row_wise(),
+            },
+            world_size=2,
+            device_type=ctx.device.type,
+            sharder=sharder,
+        )
+        sharded_sparse_arch_pipeline = DistributedModelParallel(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
+            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
+            sharders=[sharder],
+            device=ctx.device,
+        )
+        sharded_sparse_arch_no_pipeline = DistributedModelParallel(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
+            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
+            sharders=[sharder],
+            device=ctx.device,
+        )
+
+        batches = []
+        for d in data:
+            batches.append(d[1][ctx.rank].to(ctx.device))
+        dataloader = iter(batches)
+
+        optimizer_no_pipeline = optim.SGD(
+            sharded_sparse_arch_no_pipeline.parameters(), lr=0.1
+        )
+        optimizer_pipeline = optim.SGD(
+            sharded_sparse_arch_pipeline.parameters(), lr=0.1
+        )
+
+        pipeline = TrainPipelineSparseDist(
+            sharded_sparse_arch_pipeline,
+            optimizer_pipeline,
+            ctx.device,
+        )
+
+        for batch in batches[:-2]:
+            batch = batch.to(ctx.device)
+            optimizer_no_pipeline.zero_grad()
+            loss, pred = sharded_sparse_arch_no_pipeline(batch)
+            loss.backward()
+            optimizer_no_pipeline.step()
+
+            pred_pipeline = pipeline.progress(dataloader)
+            torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu())
+
+
+class TrainPipelineGPUTest(MultiProcessTestBase):
+    def setUp(self, backend: str = "nccl") -> None:
+        super().setUp()
+
+        self.pipeline_class = TrainPipelineSparseDist
+        num_features = 4
+        num_weighted_features = 4
+        self.tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 100,
+                embedding_dim=(i + 1) * 4,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(num_features)
+        ]
+        self.weighted_tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 100,
+                embedding_dim=(i + 1) * 4,
+                name="weighted_table_" + str(i),
+                feature_names=["weighted_feature_" + str(i)],
+            )
+            for i in range(num_weighted_features)
+        ]
+
+        self.backend = backend
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+
+        if self.backend == "nccl" and self.device == torch.device("cpu"):
+            self.skipTest("NCCL not supported on CPUs.")
+
+    def _generate_data(
+        self,
+        num_batches: int = 5,
+        batch_size: int = 1,
+        max_feature_lengths: Optional[List[int]] = None,
+    ) -> List[Tuple[ModelInput, List[ModelInput]]]:
+        return [
+            ModelInput.generate(
+                tables=self.tables,
+                weighted_tables=self.weighted_tables,
+                batch_size=batch_size,
+                world_size=2,
+                num_float_features=10,
+                max_feature_lengths=max_feature_lengths,
+            )
+            for i in range(num_batches)
+        ]
+
+    def test_fp_ebc_rw(self) -> None:
+        data = self._generate_data(max_feature_lengths=[10, 10, 12, 12])
+        self._run_multi_process_test(
+            callable=fp_ebc,
+            world_size=2,
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            data=data,
+        )
+
+
 class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
index 56e6ac636..85148a480 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
@@ -40,7 +40,7 @@ def setUp(self) -> None:
         self.pg = init_distributed_single_host(backend=backend, rank=0, world_size=1)
 
         num_features = 4
-        num_weighted_features = 2
+        num_weighted_features = 4
         self.tables = [
             EmbeddingBagConfig(
                 num_embeddings=(i + 1) * 100,
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index 08e5c2aab..a40356e30 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -147,6 +147,7 @@ def _start_data_dist(
         # and this info was done in the _rewrite_model by tracing the
         # entire model to get the arg_info_list
         args, kwargs = forward.args.build_args_kwargs(batch)
+        args, kwargs = module.preprocess_input(args, kwargs)
 
         # Start input distribution.
         module_ctx = module.create_context()
@@ -382,6 +383,8 @@ def _rewrite_model(  # noqa C901
             logger.info(f"Module '{node.target}' will be pipelined")
             child = sharded_modules[node.target]
             original_forwards.append(child.forward)
+            # Set pipelining flag on the child module
+            child.is_pipelined = True  # pyre-ignore[8] Incompatible attribute type
             child.forward = pipelined_forward(
                 node.target,
diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py
index 46521ca6c..82b528130 100644
--- a/torchrec/distributed/types.py
+++ b/torchrec/distributed/types.py
@@ -19,7 +19,10 @@
     Generic,
     Iterator,
     List,
+    Mapping,
     Optional,
+    ParamSpec,
+    Sequence,
     Tuple,
     Type,
     TypeVar,
@@ -78,6 +81,8 @@ class GenericMeta(type):
     )
 from torchrec.streamable import Multistreamable
 
+_T = TypeVar("_T")
+
 
 def _tabulate(
     table: List[List[Union[str, int]]], headers: Optional[List[str]] = None
@@ -1015,6 +1020,8 @@ def __init__(
         if qcomm_codecs_registry is None:
             qcomm_codecs_registry = {}
         self._qcomm_codecs_registry = qcomm_codecs_registry
+        # In pipelining, this flag is flipped in _rewrite_model when the forward is replaced with the pipelined forward
+        self.is_pipelined = False
 
     @abc.abstractmethod
     def create_context(self) -> ShrdCtx:
@@ -1117,6 +1124,19 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         for key, _ in self.named_parameters(prefix):
             yield key
 
+    def preprocess_input(
+        self,
+        args: List[_T],
+        kwargs: Mapping[str, _T],
+    ) -> Tuple[List[_T], Mapping[str, _T]]:
+        """
+        This function can be used to preprocess the input arguments prior to the module forward call.
+
+        For example, it is used in ShardedFeatureProcessedEmbeddingBagCollection to transform the input data
+        prior to the forward call.
+ """ + return args, kwargs + @property @abc.abstractmethod def unsharded_module_type(self) -> Type[nn.Module]: diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py index b12660e97..e69a88371 100644 --- a/torchrec/distributed/utils.py +++ b/torchrec/distributed/utils.py @@ -26,8 +26,10 @@ from torch import nn from torch.autograd.profiler import record_function from torchrec import optim as trec_optim -from torchrec.distributed.embedding_types import EmbeddingComputeKernel - +from torchrec.distributed.embedding_types import ( + EmbeddingComputeKernel, + KeyedJaggedTensor, +) from torchrec.distributed.types import ( DataType, EmbeddingEvent, @@ -38,6 +40,7 @@ ShardMetadata, ) from torchrec.modules.embedding_configs import data_type_to_sparse_type +from torchrec.modules.feature_processor_ import FeatureProcessorsCollection from torchrec.types import CopyMixIn logger: logging.Logger = logging.getLogger(__name__) @@ -758,3 +761,47 @@ def _recalculate_torch_state_helper( _recalculate_torch_state_helper(child) _recalculate_torch_state_helper(module) + emb_kernel.weights_precision = converted_sparse_dtype # pyre-ignore [16] + + +def modify_input_for_feature_processor( + features: KeyedJaggedTensor, + feature_processors: Union[nn.ModuleDict, FeatureProcessorsCollection], + is_collection: bool, +) -> None: + """ + This function applies the feature processor pre input dist. This way we + can support row wise based sharding mechanisms. + + This is an inplace modifcation of the input KJT. + """ + with torch.no_grad(): + if features.weights_or_none() is None: + # force creation of weights, this way the feature jagged tensor weights are tied to the original KJT + features._weights = torch.zeros_like(features.values(), dtype=torch.float32) + + if is_collection: + if hasattr(feature_processors, "pre_process_pipeline_input"): + feature_processors.pre_process_pipeline_input(features) # pyre-ignore[29] + else: + logging.info( + f"[Feature Processor Pipeline] Skipping pre_process_pipeline_input for feature processor {feature_processors=}" + ) + else: + # per feature process + for feature in features.keys(): + if feature in feature_processors: # pyre-ignore[58] + feature_processor = feature_processors[feature] # pyre-ignore[29] + if hasattr(feature_processor, "pre_process_pipeline_input"): + feature_processor.pre_process_pipeline_input(features[feature]) + else: + logging.info( + f"[Feature Processor Pipeline] Skipping pre_process_pipeline_input for feature processor {feature_processor=}" + ) + else: + features[feature].weights().copy_( + torch.ones( + features[feature].values().shape[0], + device=features[feature].values().device, + ) + ) diff --git a/torchrec/modules/feature_processor_.py b/torchrec/modules/feature_processor_.py index 707f5bd2b..f064ad5e3 100644 --- a/torchrec/modules/feature_processor_.py +++ b/torchrec/modules/feature_processor_.py @@ -14,7 +14,7 @@ import torch -from torch import nn +from torch import distributed as dist, nn from torch.nn.modules.module import _IncompatibleKeys from torchrec.pt2.checks import is_non_strict_exporting @@ -72,6 +72,7 @@ def __init__( torch.empty([max_feature_length], device=device), requires_grad=True, ) + self.pipelined = False self.reset_parameters() @@ -85,15 +86,18 @@ def forward( ) -> JaggedTensor: """ Args: - features (JaggedTensor]): feature representation + features (JaggedTensor): feature representation Returns: JaggedTensor: same as input features with `weights` field being populated. 
""" - - seq = torch.ops.fbgemm.offsets_range( - features.offsets().long(), torch.numel(features.values()) - ) + if self.pipelined: + # position is embedded as weights + seq = features.weights().clone().to(torch.int64) + else: + seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) weighted_features = JaggedTensor( values=features.values(), lengths=features.lengths(), @@ -102,6 +106,20 @@ def forward( ) return weighted_features + def pre_process_pipeline_input(self, features: JaggedTensor) -> None: + """ + Args: + features (JaggedTensor]): feature representation + + Returns: + torch.Tensor: position weights + """ + self.pipelined = True + cat_seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) + features.weights().copy_(cat_seq.to(torch.float32)) + class FeatureProcessorsCollection(nn.Module): """ @@ -169,7 +187,7 @@ def __init__( for length in self.max_feature_lengths.values(): if length <= 0: raise - + self.pipelined = False # if pipelined, input dist has performed part of input feature processing self.position_weights: nn.ParameterDict = nn.ParameterDict() # needed since nn.ParameterDict isn't torchscriptable (get_items) self.position_weights_dict: Dict[str, nn.Parameter] = {} @@ -191,7 +209,6 @@ def reset_parameters(self) -> None: self.position_weights_dict[key] = self.position_weights[key] def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: - # TODO unflattener doesnt work well with aten.to at submodule boundaries if is_non_strict_exporting(): offsets = features.offsets() if offsets.dtype == torch.int64: @@ -203,9 +220,12 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: features.offsets().long(), torch.numel(features.values()) ) else: - cat_seq = torch.ops.fbgemm.offsets_range( - features.offsets().long(), torch.numel(features.values()) - ) + if self.pipelined: + cat_seq = features.weights().clone().to(torch.int64) + else: + cat_seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) return KeyedJaggedTensor( keys=features.keys(), @@ -245,3 +265,10 @@ def load_state_dict( for k, param in self.position_weights.items(): self.position_weights_dict[k] = param return result + + def pre_process_pipeline_input(self, features: KeyedJaggedTensor) -> None: + self.pipelined = True + cat_seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) + features.weights().copy_(cat_seq.to(torch.float32))