58 changes: 17 additions & 41 deletions torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -37,20 +37,17 @@
from torchrec.distributed.benchmark.benchmark_utils import (
BaseModelConfig,
create_model_config,
generate_data,
generate_planner,
generate_sharded_model_and_optimizer,
)
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import Topology
from torchrec.distributed.test_utils.input_config import ModelInputConfig
from torchrec.distributed.test_utils.model_input import ModelInput

from torchrec.distributed.test_utils.multi_process import (
MultiProcessContext,
run_multi_process_func,
)
from torchrec.distributed.test_utils.pipeline_config import PipelineConfig
from torchrec.distributed.test_utils.sharding_config import PlannerConfig
from torchrec.distributed.test_utils.table_config import EmbeddingTablesConfig
from torchrec.distributed.test_utils.test_model import TestOverArchLarge
from torchrec.distributed.train_pipeline import TrainPipeline
@@ -99,14 +96,11 @@ class RunOptions(BenchFuncConfig):
world_size: int = 2
num_batches: int = 10
sharding_type: ShardingType = ShardingType.TABLE_WISE
compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED
input_type: str = "kjt"
name: str = ""
profile_dir: str = ""
num_benchmarks: int = 5
num_profiles: int = 2
planner_type: str = "embedding"
pooling_factors: Optional[List[float]] = None
num_poolings: Optional[List[float]] = None
dense_optimizer: str = "SGD"
dense_lr: float = 0.1
@@ -124,7 +118,7 @@ class ModelSelectionConfig:
model_name: str = "test_sparse_nn"

# Common config for all model types
batch_size: int = 8192
batch_size: int = 1024 * 32
batch_sizes: Optional[List[int]] = None
num_float_features: int = 10
feature_pooling_avg: int = 10
@@ -161,6 +155,8 @@ def runner(
run_option: RunOptions,
model_config: BaseModelConfig,
pipeline_config: PipelineConfig,
input_config: ModelInputConfig,
planner_config: PlannerConfig,
) -> BenchmarkResult:
# Ensure GPUs are available and we have enough of them
assert (
@@ -180,39 +176,14 @@
dense_device=ctx.device,
)

# Create a topology for sharding
topology = Topology(
local_world_size=get_local_size(world_size),
world_size=world_size,
compute_device=ctx.device.type,
)

batch_sizes = model_config.batch_sizes

if batch_sizes is None:
batch_sizes = [model_config.batch_size] * run_option.num_batches
else:
assert (
len(batch_sizes) == run_option.num_batches
), "The length of batch_sizes must match the number of batches."

# Create a planner for sharding based on the specified type
planner = generate_planner(
planner_type=run_option.planner_type,
topology=topology,
tables=tables,
weighted_tables=weighted_tables,
sharding_type=run_option.sharding_type,
compute_kernel=run_option.compute_kernel,
batch_sizes=batch_sizes,
pooling_factors=run_option.pooling_factors,
num_poolings=run_option.num_poolings,
planner = planner_config.generate_planner(
tables=tables + weighted_tables,
)
bench_inputs = generate_data(

bench_inputs = input_config.generate_batches(
tables=tables,
weighted_tables=weighted_tables,
model_config=model_config,
batch_sizes=batch_sizes,
)

# Prepare fused_params for sparse optimizer
@@ -230,8 +201,6 @@

sharded_model, optimizer = generate_sharded_model_and_optimizer(
model=unsharded_model,
sharding_type=run_option.sharding_type.value,
kernel_type=run_option.compute_kernel.value,
# pyre-ignore
pg=ctx.pg,
device=ctx.device,
@@ -285,8 +254,9 @@ def run_pipeline(
table_config: EmbeddingTablesConfig,
pipeline_config: PipelineConfig,
model_config: BaseModelConfig,
input_config: ModelInputConfig,
planner_config: PlannerConfig,
) -> BenchmarkResult:

tables, weighted_tables, *_ = table_config.generate_tables()

benchmark_res_per_rank = run_multi_process_func(
@@ -297,6 +267,8 @@
run_option=run_option,
model_config=model_config,
pipeline_config=pipeline_config,
input_config=input_config,
planner_config=planner_config,
)

# Combine results from all ranks into a single BenchmarkResult
@@ -329,6 +301,8 @@ def main(
table_config: EmbeddingTablesConfig,
model_selection: ModelSelectionConfig,
pipeline_config: PipelineConfig,
input_config: ModelInputConfig,
planner_config: PlannerConfig,
model_config: Optional[BaseModelConfig] = None,
) -> None:
tables, weighted_tables, *_ = table_config.generate_tables()
@@ -367,6 +341,8 @@
run_option=run_option,
model_config=model_config,
pipeline_config=pipeline_config,
input_config=input_config,
planner_config=planner_config,
)


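For context: the data-generation logic that runner previously pulled from generate_data (removed from benchmark_utils.py below) now lives behind ModelInputConfig.generate_batches. The following is a minimal sketch, not part of the PR diff, of what such a config-driven generator could look like; the dataclass name, fields, and defaults are assumptions borrowed from the deleted helper, not the actual ModelInputConfig schema.

# Hedged sketch only: shows how the removed generate_data() logic could sit on a
# config object. Class name, field names, and defaults are assumptions, not the
# torchrec API.
from dataclasses import dataclass
from typing import List

import torch
from torchrec.distributed.test_utils.model_input import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig


@dataclass
class SketchModelInputConfig:
    batch_size: int = 8192
    num_batches: int = 10
    num_float_features: int = 10
    feature_pooling_avg: int = 10
    use_offsets: bool = False
    long_kjt_indices: bool = True
    long_kjt_offsets: bool = True
    long_kjt_lengths: bool = True
    pin_memory: bool = False
    dev_str: str = ""

    def generate_batches(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
    ) -> List[ModelInput]:
        # Mirrors the deleted generate_data(): one ModelInput per batch, with
        # dtypes and pinning driven by the config flags.
        device = torch.device(self.dev_str) if self.dev_str else None
        return [
            ModelInput.generate(
                batch_size=self.batch_size,
                tables=tables,
                weighted_tables=weighted_tables,
                num_float_features=self.num_float_features,
                pooling_avg=self.feature_pooling_avg,
                use_offsets=self.use_offsets,
                device=device,
                indices_dtype=torch.int64 if self.long_kjt_indices else torch.int32,
                offsets_dtype=torch.int64 if self.long_kjt_offsets else torch.int32,
                lengths_dtype=torch.int64 if self.long_kjt_lengths else torch.int32,
                pin_memory=self.pin_memory,
            )
            for _ in range(self.num_batches)
        ]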
148 changes: 5 additions & 143 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -19,27 +19,23 @@
import copy
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.distributed as dist

from torch import nn, optim
from torch.optim import Optimizer
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
from torchrec.distributed.planner import EmbeddingShardingPlanner
from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.test_utils.model_input import ModelInput
from torchrec.distributed.sharding_plan import get_default_sharders
from torchrec.distributed.test_utils.test_model import (
TestEBCSharder,
TestSparseNN,
TestTowerCollectionSparseNN,
TestTowerSparseNN,
)
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
from torchrec.distributed.types import ShardingEnv
from torchrec.models.deepfm import SimpleDeepFMNNWrapper
from torchrec.models.dlrm import DLRMWrapper
from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -240,137 +236,8 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
return model_class(**filtered_kwargs)


def generate_data(
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
model_config: BaseModelConfig,
batch_sizes: List[int],
) -> List[ModelInput]:
"""
Generate model input data for benchmarking.

Args:
tables: List of unweighted embedding tables
weighted_tables: List of weighted embedding tables
model_config: Configuration for model generation
num_batches: Number of batches to generate

Returns:
A list of ModelInput objects representing the generated batches
"""
device = torch.device(model_config.dev_str) if model_config.dev_str else None

return [
ModelInput.generate(
batch_size=batch_size,
tables=tables,
weighted_tables=weighted_tables,
num_float_features=model_config.num_float_features,
pooling_avg=model_config.feature_pooling_avg,
use_offsets=model_config.use_offsets,
device=device,
indices_dtype=(
torch.int64 if model_config.long_kjt_indices else torch.int32
),
offsets_dtype=(
torch.int64 if model_config.long_kjt_offsets else torch.int32
),
lengths_dtype=(
torch.int64 if model_config.long_kjt_lengths else torch.int32
),
pin_memory=model_config.pin_memory,
)
for batch_size in batch_sizes
]


def generate_planner(
planner_type: str,
topology: Topology,
tables: Optional[List[EmbeddingBagConfig]],
weighted_tables: Optional[List[EmbeddingBagConfig]],
sharding_type: ShardingType,
compute_kernel: EmbeddingComputeKernel,
batch_sizes: List[int],
pooling_factors: Optional[List[float]] = None,
num_poolings: Optional[List[float]] = None,
) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
"""
Generate an embedding sharding planner based on the specified configuration.

Args:
planner_type: Type of planner to use ("embedding" or "hetero")
topology: Network topology for distributed training
tables: List of unweighted embedding tables
weighted_tables: List of weighted embedding tables
sharding_type: Strategy for sharding embedding tables
compute_kernel: Compute kernel to use for embedding tables
batch_sizes: Sizes of each batch
pooling_factors: Pooling factors for each feature of the table
num_poolings: Number of poolings for each feature of the table

Returns:
An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner

Raises:
RuntimeError: If an unknown planner type is specified
"""
# Create parameter constraints for tables
constraints = {}
num_batches = len(batch_sizes)

if pooling_factors is None:
pooling_factors = [POOLING_FACTOR] * num_batches

if num_poolings is None:
num_poolings = [NUM_POOLINGS] * num_batches

assert (
len(pooling_factors) == num_batches and len(num_poolings) == num_batches
), "The length of pooling_factors and num_poolings must match the number of batches."

if tables is not None:
for table in tables:
constraints[table.name] = ParameterConstraints(
sharding_types=[sharding_type.value],
compute_kernels=[compute_kernel.value],
device_group="cuda",
pooling_factors=pooling_factors,
num_poolings=num_poolings,
batch_sizes=batch_sizes,
)

if weighted_tables is not None:
for table in weighted_tables:
constraints[table.name] = ParameterConstraints(
sharding_types=[sharding_type.value],
compute_kernels=[compute_kernel.value],
device_group="cuda",
pooling_factors=pooling_factors,
num_poolings=num_poolings,
batch_sizes=batch_sizes,
is_weighted=True,
)

if planner_type == "embedding":
return EmbeddingShardingPlanner(
topology=topology,
constraints=constraints if constraints else None,
)
elif planner_type == "hetero":
topology_groups = {"cuda": topology}
return HeteroEmbeddingShardingPlanner(
topology_groups=topology_groups,
constraints=constraints if constraints else None,
)
else:
raise RuntimeError(f"Unknown planner type: {planner_type}")


def generate_sharded_model_and_optimizer(
model: nn.Module,
sharding_type: str,
kernel_type: str,
pg: dist.ProcessGroup,
device: torch.device,
fused_params: Dict[str, Any],
@@ -404,12 +271,7 @@ def generate_sharded_model_and_optimizer(
Returns:
Tuple of sharded model and optimizer
"""
sharder = TestEBCSharder(
sharding_type=sharding_type,
kernel_type=kernel_type,
fused_params=fused_params,
)
sharders = [cast(ModuleSharder[nn.Module], sharder)]
sharders = get_default_sharders()

# Use planner if provided
plan = None
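With TestEBCSharder gone, sharding in the benchmark goes through get_default_sharders() and whatever plan the supplied planner produces. A minimal sketch of that flow, not the verbatim torchrec helper: the function name here is hypothetical and the surrounding optimizer wiring is omitted.

# Hedged sketch: model sharding with default sharders and a collective plan.
# Condensed from the benchmark helper; optimizer setup is omitted.
import torch
import torch.distributed as dist
from torch import nn
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner
from torchrec.distributed.sharding_plan import get_default_sharders
from torchrec.distributed.types import ShardingEnv


def shard_model_sketch(
    model: nn.Module,
    pg: dist.ProcessGroup,
    device: torch.device,
    planner: EmbeddingShardingPlanner,
) -> DistributedModelParallel:
    # Default sharders cover the standard embedding modules, so the benchmark no
    # longer hard-codes a single sharding_type/kernel_type pair here; those
    # choices now flow in through the planner's constraints.
    sharders = get_default_sharders()

    # Collective planning keeps the sharding plan identical across ranks.
    plan = planner.collective_plan(model, sharders, pg)

    return DistributedModelParallel(
        module=model,
        env=ShardingEnv.from_process_group(pg),
        device=device,
        plan=plan,
        sharders=sharders,
    )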
22 changes: 14 additions & 8 deletions torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml
@@ -2,7 +2,8 @@
# runs on 2 ranks, showing traces with reasonable workloads
RunOptions:
world_size: 2
num_batches: 10
num_batches: 5
num_benchmarks: 2
sharding_type: table_wise
profile_dir: "."
name: "sparse_data_dist_base"
@@ -12,18 +13,23 @@ PipelineConfig:
EmbeddingTablesConfig:
num_unweighted_features: 100
num_weighted_features: 100
embedding_feature_dim: 128
embedding_feature_dim: 256
additional_tables:
- - name: additional_tables_0_0
embedding_dim: 128
- - name: FP16_table
embedding_dim: 512
num_embeddings: 100_000
feature_names: ["additional_0_0"]
- name: additional_tables_0_1
embedding_dim: 128
num_embeddings: 100_000
data_type: FP16
- name: large_table
embedding_dim: 2048
num_embeddings: 1_000_000
feature_names: ["additional_0_1"]
- []
- - name: additional_tables_2_1
- - name: skipped_table
embedding_dim: 128
num_embeddings: 100_000
feature_names: ["additional_2_1"]
PlannerConfig:
additional_constraints:
large_table:
sharding_types: [column_wise]
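The additional_constraints block above is the YAML-facing replacement for the per-table ParameterConstraints that the deleted generate_planner built by hand: large_table is pinned to column_wise sharding while every other table keeps the planner defaults. Below is a minimal sketch of that mapping, assuming constraints arrive as a plain name-to-kwargs dict; the real PlannerConfig schema and its generate_planner signature are not shown in this diff.

# Hedged sketch: turning a constraints mapping like
#   {"large_table": {"sharding_types": ["column_wise"]}}
# into planner constraints. ParameterConstraints and EmbeddingShardingPlanner are
# real torchrec APIs; the function and argument names here are illustrative.
from typing import Any, Dict, Optional

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints


def build_planner_sketch(
    topology: Topology,
    additional_constraints: Optional[Dict[str, Dict[str, Any]]] = None,
) -> EmbeddingShardingPlanner:
    # Tables without an entry keep the planner's default search space.
    constraints = {
        table_name: ParameterConstraints(**kwargs)
        for table_name, kwargs in (additional_constraints or {}).items()
    }
    return EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints or None,
    )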