diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py index 471057424..8248e4225 100644 --- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py +++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py @@ -37,13 +37,9 @@ from torchrec.distributed.benchmark.benchmark_utils import ( BaseModelConfig, create_model_config, - generate_data, - generate_planner, generate_sharded_model_and_optimizer, ) -from torchrec.distributed.comm import get_local_size -from torchrec.distributed.embedding_types import EmbeddingComputeKernel -from torchrec.distributed.planner import Topology +from torchrec.distributed.test_utils.input_config import ModelInputConfig from torchrec.distributed.test_utils.model_input import ModelInput from torchrec.distributed.test_utils.multi_process import ( @@ -51,6 +47,7 @@ run_multi_process_func, ) from torchrec.distributed.test_utils.pipeline_config import PipelineConfig +from torchrec.distributed.test_utils.sharding_config import PlannerConfig from torchrec.distributed.test_utils.table_config import EmbeddingTablesConfig from torchrec.distributed.test_utils.test_model import TestOverArchLarge from torchrec.distributed.train_pipeline import TrainPipeline @@ -99,14 +96,11 @@ class RunOptions(BenchFuncConfig): world_size: int = 2 num_batches: int = 10 sharding_type: ShardingType = ShardingType.TABLE_WISE - compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED input_type: str = "kjt" name: str = "" profile_dir: str = "" num_benchmarks: int = 5 num_profiles: int = 2 - planner_type: str = "embedding" - pooling_factors: Optional[List[float]] = None num_poolings: Optional[List[float]] = None dense_optimizer: str = "SGD" dense_lr: float = 0.1 @@ -124,7 +118,7 @@ class ModelSelectionConfig: model_name: str = "test_sparse_nn" # Common config for all model types - batch_size: int = 8192 + batch_size: int = 1024 * 32 batch_sizes: Optional[List[int]] = None num_float_features: int = 10 feature_pooling_avg: int = 10 @@ -161,6 +155,8 @@ def runner( run_option: RunOptions, model_config: BaseModelConfig, pipeline_config: PipelineConfig, + input_config: ModelInputConfig, + planner_config: PlannerConfig, ) -> BenchmarkResult: # Ensure GPUs are available and we have enough of them assert ( @@ -180,39 +176,14 @@ def runner( dense_device=ctx.device, ) - # Create a topology for sharding - topology = Topology( - local_world_size=get_local_size(world_size), - world_size=world_size, - compute_device=ctx.device.type, - ) - - batch_sizes = model_config.batch_sizes - - if batch_sizes is None: - batch_sizes = [model_config.batch_size] * run_option.num_batches - else: - assert ( - len(batch_sizes) == run_option.num_batches - ), "The length of batch_sizes must match the number of batches." 
- # Create a planner for sharding based on the specified type - planner = generate_planner( - planner_type=run_option.planner_type, - topology=topology, - tables=tables, - weighted_tables=weighted_tables, - sharding_type=run_option.sharding_type, - compute_kernel=run_option.compute_kernel, - batch_sizes=batch_sizes, - pooling_factors=run_option.pooling_factors, - num_poolings=run_option.num_poolings, + planner = planner_config.generate_planner( + tables=tables + weighted_tables, ) - bench_inputs = generate_data( + + bench_inputs = input_config.generate_batches( tables=tables, weighted_tables=weighted_tables, - model_config=model_config, - batch_sizes=batch_sizes, ) # Prepare fused_params for sparse optimizer @@ -230,8 +201,6 @@ def runner( sharded_model, optimizer = generate_sharded_model_and_optimizer( model=unsharded_model, - sharding_type=run_option.sharding_type.value, - kernel_type=run_option.compute_kernel.value, # pyre-ignore pg=ctx.pg, device=ctx.device, @@ -285,8 +254,9 @@ def run_pipeline( table_config: EmbeddingTablesConfig, pipeline_config: PipelineConfig, model_config: BaseModelConfig, + input_config: ModelInputConfig, + planner_config: PlannerConfig, ) -> BenchmarkResult: - tables, weighted_tables, *_ = table_config.generate_tables() benchmark_res_per_rank = run_multi_process_func( @@ -297,6 +267,8 @@ def run_pipeline( run_option=run_option, model_config=model_config, pipeline_config=pipeline_config, + input_config=input_config, + planner_config=planner_config, ) # Combine results from all ranks into a single BenchmarkResult @@ -329,6 +301,8 @@ def main( table_config: EmbeddingTablesConfig, model_selection: ModelSelectionConfig, pipeline_config: PipelineConfig, + input_config: ModelInputConfig, + planner_config: PlannerConfig, model_config: Optional[BaseModelConfig] = None, ) -> None: tables, weighted_tables, *_ = table_config.generate_tables() @@ -367,6 +341,8 @@ def main( run_option=run_option, model_config=model_config, pipeline_config=pipeline_config, + input_config=input_config, + planner_config=planner_config, ) diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index c3af00305..dee9a9263 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -19,7 +19,7 @@ import copy from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import torch.distributed as dist @@ -27,19 +27,15 @@ from torch import nn, optim from torch.optim import Optimizer from torchrec.distributed import DistributedModelParallel -from torchrec.distributed.embedding_types import EmbeddingComputeKernel -from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology -from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR +from torchrec.distributed.planner import EmbeddingShardingPlanner from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner -from torchrec.distributed.planner.types import ParameterConstraints -from torchrec.distributed.test_utils.model_input import ModelInput +from torchrec.distributed.sharding_plan import get_default_sharders from torchrec.distributed.test_utils.test_model import ( - TestEBCSharder, TestSparseNN, TestTowerCollectionSparseNN, TestTowerSparseNN, ) -from torchrec.distributed.types import ModuleSharder, ShardingEnv, 
ShardingType +from torchrec.distributed.types import ShardingEnv from torchrec.models.deepfm import SimpleDeepFMNNWrapper from torchrec.models.dlrm import DLRMWrapper from torchrec.modules.embedding_configs import EmbeddingBagConfig @@ -240,137 +236,8 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig: return model_class(**filtered_kwargs) -def generate_data( - tables: List[EmbeddingBagConfig], - weighted_tables: List[EmbeddingBagConfig], - model_config: BaseModelConfig, - batch_sizes: List[int], -) -> List[ModelInput]: - """ - Generate model input data for benchmarking. - - Args: - tables: List of unweighted embedding tables - weighted_tables: List of weighted embedding tables - model_config: Configuration for model generation - num_batches: Number of batches to generate - - Returns: - A list of ModelInput objects representing the generated batches - """ - device = torch.device(model_config.dev_str) if model_config.dev_str else None - - return [ - ModelInput.generate( - batch_size=batch_size, - tables=tables, - weighted_tables=weighted_tables, - num_float_features=model_config.num_float_features, - pooling_avg=model_config.feature_pooling_avg, - use_offsets=model_config.use_offsets, - device=device, - indices_dtype=( - torch.int64 if model_config.long_kjt_indices else torch.int32 - ), - offsets_dtype=( - torch.int64 if model_config.long_kjt_offsets else torch.int32 - ), - lengths_dtype=( - torch.int64 if model_config.long_kjt_lengths else torch.int32 - ), - pin_memory=model_config.pin_memory, - ) - for batch_size in batch_sizes - ] - - -def generate_planner( - planner_type: str, - topology: Topology, - tables: Optional[List[EmbeddingBagConfig]], - weighted_tables: Optional[List[EmbeddingBagConfig]], - sharding_type: ShardingType, - compute_kernel: EmbeddingComputeKernel, - batch_sizes: List[int], - pooling_factors: Optional[List[float]] = None, - num_poolings: Optional[List[float]] = None, -) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]: - """ - Generate an embedding sharding planner based on the specified configuration. - - Args: - planner_type: Type of planner to use ("embedding" or "hetero") - topology: Network topology for distributed training - tables: List of unweighted embedding tables - weighted_tables: List of weighted embedding tables - sharding_type: Strategy for sharding embedding tables - compute_kernel: Compute kernel to use for embedding tables - batch_sizes: Sizes of each batch - pooling_factors: Pooling factors for each feature of the table - num_poolings: Number of poolings for each feature of the table - - Returns: - An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner - - Raises: - RuntimeError: If an unknown planner type is specified - """ - # Create parameter constraints for tables - constraints = {} - num_batches = len(batch_sizes) - - if pooling_factors is None: - pooling_factors = [POOLING_FACTOR] * num_batches - - if num_poolings is None: - num_poolings = [NUM_POOLINGS] * num_batches - - assert ( - len(pooling_factors) == num_batches and len(num_poolings) == num_batches - ), "The length of pooling_factors and num_poolings must match the number of batches." 
- - if tables is not None: - for table in tables: - constraints[table.name] = ParameterConstraints( - sharding_types=[sharding_type.value], - compute_kernels=[compute_kernel.value], - device_group="cuda", - pooling_factors=pooling_factors, - num_poolings=num_poolings, - batch_sizes=batch_sizes, - ) - - if weighted_tables is not None: - for table in weighted_tables: - constraints[table.name] = ParameterConstraints( - sharding_types=[sharding_type.value], - compute_kernels=[compute_kernel.value], - device_group="cuda", - pooling_factors=pooling_factors, - num_poolings=num_poolings, - batch_sizes=batch_sizes, - is_weighted=True, - ) - - if planner_type == "embedding": - return EmbeddingShardingPlanner( - topology=topology, - constraints=constraints if constraints else None, - ) - elif planner_type == "hetero": - topology_groups = {"cuda": topology} - return HeteroEmbeddingShardingPlanner( - topology_groups=topology_groups, - constraints=constraints if constraints else None, - ) - else: - raise RuntimeError(f"Unknown planner type: {planner_type}") - - def generate_sharded_model_and_optimizer( model: nn.Module, - sharding_type: str, - kernel_type: str, pg: dist.ProcessGroup, device: torch.device, fused_params: Dict[str, Any], @@ -404,12 +271,7 @@ def generate_sharded_model_and_optimizer( Returns: Tuple of sharded model and optimizer """ - sharder = TestEBCSharder( - sharding_type=sharding_type, - kernel_type=kernel_type, - fused_params=fused_params, - ) - sharders = [cast(ModuleSharder[nn.Module], sharder)] + sharders = get_default_sharders() # Use planner if provided plan = None diff --git a/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml b/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml index e59a91522..ac2a90a1d 100644 --- a/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml +++ b/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml @@ -2,7 +2,8 @@ # runs on 2 ranks, showing traces with reasonable workloads RunOptions: world_size: 2 - num_batches: 10 + num_batches: 5 + num_benchmarks: 2 sharding_type: table_wise profile_dir: "." name: "sparse_data_dist_base" @@ -12,18 +13,23 @@ PipelineConfig: EmbeddingTablesConfig: num_unweighted_features: 100 num_weighted_features: 100 - embedding_feature_dim: 128 + embedding_feature_dim: 256 additional_tables: - - - name: additional_tables_0_0 - embedding_dim: 128 + - - name: FP16_table + embedding_dim: 512 num_embeddings: 100_000 feature_names: ["additional_0_0"] - - name: additional_tables_0_1 - embedding_dim: 128 - num_embeddings: 100_000 + data_type: FP16 + - name: large_table + embedding_dim: 2048 + num_embeddings: 1_000_000 feature_names: ["additional_0_1"] - [] - - - name: additional_tables_2_1 + - - name: skipped_table embedding_dim: 128 num_embeddings: 100_000 feature_names: ["additional_2_1"] +PlannerConfig: + additional_constraints: + large_table: + sharding_types: [column_wise] diff --git a/torchrec/distributed/benchmark/yaml/sparse_data_dist_ssd.yml b/torchrec/distributed/benchmark/yaml/sparse_data_dist_ssd.yml new file mode 100644 index 000000000..fa1b39484 --- /dev/null +++ b/torchrec/distributed/benchmark/yaml/sparse_data_dist_ssd.yml @@ -0,0 +1,36 @@ +# this is a very basic sparse data dist config +# runs on 2 ranks, showing traces with reasonable workloads +RunOptions: + world_size: 2 + num_batches: 5 + num_benchmarks: 2 + sharding_type: table_wise + profile_dir: "." 
+  name: "sparse_data_dist_ssd"
+  # export_stacks: True # enable this to export stack traces
+PipelineConfig:
+  pipeline: "sparse"
+EmbeddingTablesConfig:
+  num_unweighted_features: 100
+  num_weighted_features: 100
+  embedding_feature_dim: 256
+  additional_tables:
+    - - name: FP16_table
+        embedding_dim: 512
+        num_embeddings: 100_000
+        feature_names: ["additional_0_0"]
+        data_type: FP16
+      - name: large_table
+        embedding_dim: 256
+        num_embeddings: 1_000_000
+        feature_names: ["additional_0_1"]
+    - []
+    - - name: skipped_table
+        embedding_dim: 128
+        num_embeddings: 100_000
+        feature_names: ["additional_2_1"]
+PlannerConfig:
+  additional_constraints:
+    large_table:
+      compute_kernels: [key_value]
+      sharding_types: [row_wise]
diff --git a/torchrec/distributed/test_utils/input_config.py b/torchrec/distributed/test_utils/input_config.py
new file mode 100644
index 000000000..5beb5e21a
--- /dev/null
+++ b/torchrec/distributed/test_utils/input_config.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
+
+from .model_input import ModelInput
+
+
+@dataclass
+class ModelInputConfig:
+    """Configuration for generating fixed-size ModelInput batches for benchmarking."""
+
+    num_batches: int
+    batch_size: int
+    num_float_features: int
+    feature_pooling_avg: int
+    device: Optional[str] = None
+    use_offsets: bool = False
+    long_kjt_indices: bool = True
+    long_kjt_offsets: bool = True
+    long_kjt_lengths: bool = True
+    pin_memory: bool = True
+
+    def generate_batches(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+    ) -> List[ModelInput]:
+        """
+        Generate model input data for benchmarking.
+
+        Args:
+            tables: List of unweighted embedding tables
+            weighted_tables: List of weighted embedding tables
+
+        Returns:
+            A list of num_batches ModelInput objects, each of size batch_size
+        """
+        device = torch.device(self.device) if self.device is not None else None
+
+        return [
+            ModelInput.generate(
+                batch_size=self.batch_size,
+                tables=tables,
+                weighted_tables=weighted_tables,
+                num_float_features=self.num_float_features,
+                pooling_avg=self.feature_pooling_avg,
+                use_offsets=self.use_offsets,
+                device=device,
+                indices_dtype=(torch.int64 if self.long_kjt_indices else torch.int32),
+                offsets_dtype=(torch.int64 if self.long_kjt_offsets else torch.int32),
+                lengths_dtype=(torch.int64 if self.long_kjt_lengths else torch.int32),
+                pin_memory=self.pin_memory,
+            )
+            for _ in range(self.num_batches)
+        ]
diff --git a/torchrec/distributed/test_utils/sharding_config.py b/torchrec/distributed/test_utils/sharding_config.py
new file mode 100644
index 000000000..c7ac11df0
--- /dev/null
+++ b/torchrec/distributed/test_utils/sharding_config.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +# pyre-strict +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union + +from torchrec.distributed.comm import get_local_size + +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.planner.constants import POOLING_FACTOR +from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner +from torchrec.distributed.planner.types import ParameterConstraints +from torchrec.distributed.types import ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig + + +@dataclass +class PlannerConfig: + planner_type: str = "embedding" + world_size: int = 2 + device_group: str = "cuda" + pooling_factors: List[float] = field(default_factory=lambda: [POOLING_FACTOR]) + num_poolings: Optional[List[float]] = None + batch_sizes: Optional[List[int]] = None + compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED + sharding_type: ShardingType = ShardingType.TABLE_WISE + additional_constraints: Dict[str, Any] = field(default_factory=dict) + + def generate_topology(self, device_type: str) -> Topology: + """ + Generate a topology for distributed training. + + Returns: + A Topology object representing the network topology for distributed training + """ + local_world_size = get_local_size(self.world_size) + return Topology( + world_size=self.world_size, + local_world_size=local_world_size, + compute_device=device_type, + ) + + def table_to_constraint( + self, + table: Union[EmbeddingConfig, EmbeddingBagConfig], + kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[str, ParameterConstraints]: + default_kwargs = dict( + sharding_types=[self.sharding_type.value], + compute_kernels=[self.compute_kernel.value], + device_group=self.device_group, + pooling_factors=self.pooling_factors, + num_poolings=self.num_poolings, + batch_sizes=self.batch_sizes, + ) + if kwargs is None: + kwargs = default_kwargs + else: + kwargs = default_kwargs | kwargs + + constraint = ParameterConstraints(**kwargs) # pyre-ignore [6] + return table.name, constraint + + def generate_planner( + self, + tables: List[EmbeddingBagConfig], + ) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]: + """ + Generate an embedding sharding planner based on the specified configuration. 
+
+        Args:
+            tables: Embedding table configs (both unweighted and weighted)
+                to build parameter constraints for.
+
+        Constraint defaults come from this config's fields (sharding_type,
+        compute_kernel, device_group, pooling_factors, num_poolings, and
+        batch_sizes); per-table overrides are read from
+        additional_constraints, keyed by table name.
+
+        The topology is built from world_size and device_group.
+
+        Returns:
+            An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner
+
+        Raises:
+            RuntimeError: If an unknown planner type is specified
+        """
+        # Create parameter constraints for tables
+        constraints = {}
+
+        topology = self.generate_topology(self.device_group)
+
+        for table in tables:
+            name, cons = self.table_to_constraint(
+                table, self.additional_constraints.get(table.name, None)
+            )
+            constraints[name] = cons
+
+        if self.planner_type == "embedding":
+            return EmbeddingShardingPlanner(
+                topology=topology,
+                constraints=constraints if constraints else None,
+            )
+        elif self.planner_type == "hetero":
+            topology_groups = {self.device_group: topology}
+            return HeteroEmbeddingShardingPlanner(
+                topology_groups=topology_groups,
+                constraints=constraints if constraints else None,
+            )
+        else:
+            raise RuntimeError(f"Unknown planner type: {self.planner_type}")
diff --git a/torchrec/distributed/test_utils/table_config.py b/torchrec/distributed/test_utils/table_config.py
index 2954ed085..764e74370 100644
--- a/torchrec/distributed/test_utils/table_config.py
+++ b/torchrec/distributed/test_utils/table_config.py
@@ -11,6 +11,7 @@
 from typing import Any, Dict, List
 
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.types import DataType
 
 
 @dataclass
@@ -35,8 +36,16 @@ class EmbeddingTablesConfig:
     num_unweighted_features: int = 100
     num_weighted_features: int = 100
     embedding_feature_dim: int = 128
+    table_data_type: DataType = DataType.FP32
     additional_tables: List[List[Dict[str, Any]]] = field(default_factory=list)
 
+    def convert_to_ebconf(self, kwargs: Dict[str, Any]) -> EmbeddingBagConfig:
+        if "data_type" in kwargs:
+            kwargs["data_type"] = DataType[kwargs["data_type"]]
+        else:
+            kwargs["data_type"] = self.table_data_type
+        return EmbeddingBagConfig(**kwargs)
+
     def generate_tables(
         self,
     ) -> List[List[EmbeddingBagConfig]]:
@@ -62,19 +71,21 @@ def generate_tables(
         """
         unweighted_tables = [
             EmbeddingBagConfig(
-                num_embeddings=max(i + 1, 100) * 1000,
+                num_embeddings=max(i + 1, 100) * 2000,
                 embedding_dim=self.embedding_feature_dim,
                 name="table_" + str(i),
                 feature_names=["feature_" + str(i)],
+                data_type=self.table_data_type,
             )
             for i in range(self.num_unweighted_features)
         ]
         weighted_tables = [
             EmbeddingBagConfig(
-                num_embeddings=max(i + 1, 100) * 1000,
+                num_embeddings=max(i + 1, 100) * 2000,
                 embedding_dim=self.embedding_feature_dim,
                 name="weighted_table_" + str(i),
                 feature_names=["weighted_feature_" + str(i)],
+                data_type=self.table_data_type,
            )
             for i in range(self.num_weighted_features)
         ]
@@ -87,7 +98,7 @@ def generate_tables(
         else:
             tables = []
             for adt in adts:
-                tables.append(EmbeddingBagConfig(**adt))
+                tables.append(self.convert_to_ebconf(adt))
 
         if len(tables_list) == 0:
             tables_list.append(unweighted_tables)
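
For reference, below is a minimal sketch of how the new config dataclasses introduced in this patch fit together when driven directly from Python rather than from the YAML files. The field values and the "table_0" override are illustrative only, and the small CPU-side settings are assumptions for the sketch (the benchmarks themselves run on CUDA with the dataclass defaults).

# Minimal usage sketch; values are illustrative, not benchmark defaults.
from torchrec.distributed.test_utils.input_config import ModelInputConfig
from torchrec.distributed.test_utils.sharding_config import PlannerConfig
from torchrec.distributed.test_utils.table_config import EmbeddingTablesConfig

# Generate small unweighted/weighted table configs.
table_config = EmbeddingTablesConfig(
    num_unweighted_features=4,
    num_weighted_features=4,
    embedding_feature_dim=128,
)
tables, weighted_tables, *_ = table_config.generate_tables()

# Fixed-size input batches, replacing the removed generate_data() helper.
input_config = ModelInputConfig(
    num_batches=5,
    batch_size=1024,
    num_float_features=10,
    feature_pooling_avg=10,
    device="cpu",  # assumption for the sketch; the benchmarks place inputs on CUDA
    pin_memory=False,
)
bench_inputs = input_config.generate_batches(
    tables=tables,
    weighted_tables=weighted_tables,
)

# Planner with a per-table override, mirroring the PlannerConfig YAML section.
planner_config = PlannerConfig(
    world_size=2,
    additional_constraints={
        # keyed by table name; merged over the config-wide defaults
        "table_0": {"sharding_types": ["column_wise"]},
    },
)
planner = planner_config.generate_planner(tables=tables + weighted_tables)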