
Commit 1025341

TroyGarden authored and meta-codesync[bot] committed
move sharding/planner and model input generation into config (#3448)
Summary:
Pull Request resolved: #3448

# context
* make the planner configurable via config
* make the model input configurable via config
* SSD offloading enabled {F1982627020}

> NOTE: currently there's no pipeline optimization for SSD prefetch, so the SSD lookup is expected to be very long. [trace](https://drive.google.com/file/d/1CshrtpIip_yd_gYbso_ddFOVVw42V8JR/view?usp=sharing)

Differential Revision: D84298150
1 parent 654811e commit 1025341
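
The interface change in a nutshell: planner construction and model-input generation now hang off two dedicated config objects (`PlannerConfig`, `ModelInputConfig`) instead of loose fields on `RunOptions` and `ModelSelectionConfig`. A minimal sketch of the resulting call pattern, pieced together from the diffs below — the bare `PlannerConfig()` / `ModelInputConfig()` default constructions are an assumption, since the new `sharding_config.py` / `input_config.py` files are among the changed files not shown on this page:

```python
# Sketch only: the generate_tables()/generate_planner()/generate_batches()
# call sites are taken from the diff; default-constructed configs are assumed.
from torchrec.distributed.test_utils.input_config import ModelInputConfig
from torchrec.distributed.test_utils.sharding_config import PlannerConfig
from torchrec.distributed.test_utils.table_config import EmbeddingTablesConfig

tables, weighted_tables, *_ = EmbeddingTablesConfig().generate_tables()

# Planner choice and per-table constraints now live on PlannerConfig ...
planner = PlannerConfig().generate_planner(tables=tables + weighted_tables)

# ... and batch generation (batch sizes, dtypes, pooling) on ModelInputConfig.
bench_inputs = ModelInputConfig().generate_batches(
    tables=tables,
    weighted_tables=weighted_tables,
)
```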

File tree

7 files changed (+269, -195 lines)


torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 15 additions & 41 deletions
```diff
@@ -37,20 +37,17 @@
 from torchrec.distributed.benchmark.benchmark_utils import (
     BaseModelConfig,
     create_model_config,
-    generate_data,
-    generate_planner,
     generate_sharded_model_and_optimizer,
 )
-from torchrec.distributed.comm import get_local_size
-from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.planner import Topology
+from torchrec.distributed.test_utils.input_config import ModelInputConfig
 from torchrec.distributed.test_utils.model_input import ModelInput

 from torchrec.distributed.test_utils.multi_process import (
     MultiProcessContext,
     run_multi_process_func,
 )
 from torchrec.distributed.test_utils.pipeline_config import PipelineConfig
+from torchrec.distributed.test_utils.sharding_config import PlannerConfig
 from torchrec.distributed.test_utils.table_config import EmbeddingTablesConfig
 from torchrec.distributed.test_utils.test_model import TestOverArchLarge
 from torchrec.distributed.train_pipeline import TrainPipeline
@@ -99,14 +96,11 @@ class RunOptions(BenchFuncConfig):
     world_size: int = 2
     num_batches: int = 10
     sharding_type: ShardingType = ShardingType.TABLE_WISE
-    compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED
     input_type: str = "kjt"
     name: str = ""
     profile_dir: str = ""
     num_benchmarks: int = 5
     num_profiles: int = 2
-    planner_type: str = "embedding"
-    pooling_factors: Optional[List[float]] = None
     num_poolings: Optional[List[float]] = None
     dense_optimizer: str = "SGD"
     dense_lr: float = 0.1
@@ -124,7 +118,7 @@ class ModelSelectionConfig:
     model_name: str = "test_sparse_nn"

     # Common config for all model types
-    batch_size: int = 8192
+    batch_size: int = 1024 * 32
     batch_sizes: Optional[List[int]] = None
     num_float_features: int = 10
     feature_pooling_avg: int = 10
@@ -161,6 +155,8 @@ def runner(
     run_option: RunOptions,
     model_config: BaseModelConfig,
     pipeline_config: PipelineConfig,
+    input_config: ModelInputConfig,
+    planner_config: PlannerConfig,
 ) -> BenchmarkResult:
     # Ensure GPUs are available and we have enough of them
     assert (
@@ -180,39 +176,14 @@ def runner(
             dense_device=ctx.device,
         )

-        # Create a topology for sharding
-        topology = Topology(
-            local_world_size=get_local_size(world_size),
-            world_size=world_size,
-            compute_device=ctx.device.type,
-        )
-
-        batch_sizes = model_config.batch_sizes
-
-        if batch_sizes is None:
-            batch_sizes = [model_config.batch_size] * run_option.num_batches
-        else:
-            assert (
-                len(batch_sizes) == run_option.num_batches
-            ), "The length of batch_sizes must match the number of batches."
-
         # Create a planner for sharding based on the specified type
-        planner = generate_planner(
-            planner_type=run_option.planner_type,
-            topology=topology,
-            tables=tables,
-            weighted_tables=weighted_tables,
-            sharding_type=run_option.sharding_type,
-            compute_kernel=run_option.compute_kernel,
-            batch_sizes=batch_sizes,
-            pooling_factors=run_option.pooling_factors,
-            num_poolings=run_option.num_poolings,
+        planner = planner_config.generate_planner(
+            tables=tables + weighted_tables,
         )
-        bench_inputs = generate_data(
+
+        bench_inputs = input_config.generate_batches(
             tables=tables,
             weighted_tables=weighted_tables,
-            model_config=model_config,
-            batch_sizes=batch_sizes,
         )

         # Prepare fused_params for sparse optimizer
@@ -230,8 +201,6 @@ def runner(

         sharded_model, optimizer = generate_sharded_model_and_optimizer(
             model=unsharded_model,
-            sharding_type=run_option.sharding_type.value,
-            kernel_type=run_option.compute_kernel.value,
             # pyre-ignore
             pg=ctx.pg,
             device=ctx.device,
@@ -285,8 +254,8 @@ def run_pipeline(
     table_config: EmbeddingTablesConfig,
     pipeline_config: PipelineConfig,
     model_config: BaseModelConfig,
+    input_config: ModelInputConfig,
 ) -> BenchmarkResult:
-
     tables, weighted_tables, *_ = table_config.generate_tables()

     benchmark_res_per_rank = run_multi_process_func(
@@ -297,6 +266,7 @@
         run_option=run_option,
         model_config=model_config,
         pipeline_config=pipeline_config,
+        input_config=input_config,
     )

     # Combine results from all ranks into a single BenchmarkResult
@@ -329,6 +299,8 @@ def main(
     table_config: EmbeddingTablesConfig,
     model_selection: ModelSelectionConfig,
     pipeline_config: PipelineConfig,
+    input_config: ModelInputConfig,
+    planner_config: PlannerConfig,
     model_config: Optional[BaseModelConfig] = None,
 ) -> None:
     tables, weighted_tables, *_ = table_config.generate_tables()
@@ -367,6 +339,8 @@ def main(
         run_option=run_option,
         model_config=model_config,
         pipeline_config=pipeline_config,
+        input_config=input_config,
+        planner_config=planner_config,
     )
```
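Call sites have to thread the two new configs through `run_pipeline`/`main`. A hedged example of invoking the updated entry point — only the parameter names come from the signatures above (`run_option` is inferred from the inner `run_multi_process_func` call); the concrete values are made up for the example:

```python
# Illustrative invocation; values are placeholders, not defaults from the PR.
result = run_pipeline(
    run_option=RunOptions(world_size=2, num_batches=5),
    table_config=EmbeddingTablesConfig(),
    pipeline_config=PipelineConfig(),
    model_config=create_model_config(model_name="test_sparse_nn"),
    input_config=ModelInputConfig(),
)
```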
torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 5 additions & 143 deletions
```diff
@@ -19,27 +19,23 @@
 import copy
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.distributed as dist

 from torch import nn, optim
 from torch.optim import Optimizer
 from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
-from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
+from torchrec.distributed.planner import EmbeddingShardingPlanner
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
-from torchrec.distributed.planner.types import ParameterConstraints
-from torchrec.distributed.test_utils.model_input import ModelInput
+from torchrec.distributed.sharding_plan import get_default_sharders
 from torchrec.distributed.test_utils.test_model import (
-    TestEBCSharder,
     TestSparseNN,
     TestTowerCollectionSparseNN,
     TestTowerSparseNN,
 )
-from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
+from torchrec.distributed.types import ShardingEnv
 from torchrec.models.deepfm import SimpleDeepFMNNWrapper
 from torchrec.models.dlrm import DLRMWrapper
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -240,137 +236,8 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
     return model_class(**filtered_kwargs)


-def generate_data(
-    tables: List[EmbeddingBagConfig],
-    weighted_tables: List[EmbeddingBagConfig],
-    model_config: BaseModelConfig,
-    batch_sizes: List[int],
-) -> List[ModelInput]:
-    """
-    Generate model input data for benchmarking.
-
-    Args:
-        tables: List of unweighted embedding tables
-        weighted_tables: List of weighted embedding tables
-        model_config: Configuration for model generation
-        num_batches: Number of batches to generate
-
-    Returns:
-        A list of ModelInput objects representing the generated batches
-    """
-    device = torch.device(model_config.dev_str) if model_config.dev_str else None
-
-    return [
-        ModelInput.generate(
-            batch_size=batch_size,
-            tables=tables,
-            weighted_tables=weighted_tables,
-            num_float_features=model_config.num_float_features,
-            pooling_avg=model_config.feature_pooling_avg,
-            use_offsets=model_config.use_offsets,
-            device=device,
-            indices_dtype=(
-                torch.int64 if model_config.long_kjt_indices else torch.int32
-            ),
-            offsets_dtype=(
-                torch.int64 if model_config.long_kjt_offsets else torch.int32
-            ),
-            lengths_dtype=(
-                torch.int64 if model_config.long_kjt_lengths else torch.int32
-            ),
-            pin_memory=model_config.pin_memory,
-        )
-        for batch_size in batch_sizes
-    ]
-
-
-def generate_planner(
-    planner_type: str,
-    topology: Topology,
-    tables: Optional[List[EmbeddingBagConfig]],
-    weighted_tables: Optional[List[EmbeddingBagConfig]],
-    sharding_type: ShardingType,
-    compute_kernel: EmbeddingComputeKernel,
-    batch_sizes: List[int],
-    pooling_factors: Optional[List[float]] = None,
-    num_poolings: Optional[List[float]] = None,
-) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
-    """
-    Generate an embedding sharding planner based on the specified configuration.
-
-    Args:
-        planner_type: Type of planner to use ("embedding" or "hetero")
-        topology: Network topology for distributed training
-        tables: List of unweighted embedding tables
-        weighted_tables: List of weighted embedding tables
-        sharding_type: Strategy for sharding embedding tables
-        compute_kernel: Compute kernel to use for embedding tables
-        batch_sizes: Sizes of each batch
-        pooling_factors: Pooling factors for each feature of the table
-        num_poolings: Number of poolings for each feature of the table
-
-    Returns:
-        An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner
-
-    Raises:
-        RuntimeError: If an unknown planner type is specified
-    """
-    # Create parameter constraints for tables
-    constraints = {}
-    num_batches = len(batch_sizes)
-
-    if pooling_factors is None:
-        pooling_factors = [POOLING_FACTOR] * num_batches
-
-    if num_poolings is None:
-        num_poolings = [NUM_POOLINGS] * num_batches
-
-    assert (
-        len(pooling_factors) == num_batches and len(num_poolings) == num_batches
-    ), "The length of pooling_factors and num_poolings must match the number of batches."
-
-    if tables is not None:
-        for table in tables:
-            constraints[table.name] = ParameterConstraints(
-                sharding_types=[sharding_type.value],
-                compute_kernels=[compute_kernel.value],
-                device_group="cuda",
-                pooling_factors=pooling_factors,
-                num_poolings=num_poolings,
-                batch_sizes=batch_sizes,
-            )
-
-    if weighted_tables is not None:
-        for table in weighted_tables:
-            constraints[table.name] = ParameterConstraints(
-                sharding_types=[sharding_type.value],
-                compute_kernels=[compute_kernel.value],
-                device_group="cuda",
-                pooling_factors=pooling_factors,
-                num_poolings=num_poolings,
-                batch_sizes=batch_sizes,
-                is_weighted=True,
-            )
-
-    if planner_type == "embedding":
-        return EmbeddingShardingPlanner(
-            topology=topology,
-            constraints=constraints if constraints else None,
-        )
-    elif planner_type == "hetero":
-        topology_groups = {"cuda": topology}
-        return HeteroEmbeddingShardingPlanner(
-            topology_groups=topology_groups,
-            constraints=constraints if constraints else None,
-        )
-    else:
-        raise RuntimeError(f"Unknown planner type: {planner_type}")
-
-
 def generate_sharded_model_and_optimizer(
     model: nn.Module,
-    sharding_type: str,
-    kernel_type: str,
     pg: dist.ProcessGroup,
     device: torch.device,
     fused_params: Dict[str, Any],
@@ -404,12 +271,7 @@ def generate_sharded_model_and_optimizer(
     Returns:
         Tuple of sharded model and optimizer
     """
-    sharder = TestEBCSharder(
-        sharding_type=sharding_type,
-        kernel_type=kernel_type,
-        fused_params=fused_params,
-    )
-    sharders = [cast(ModuleSharder[nn.Module], sharder)]
+    sharders = get_default_sharders()

     # Use planner if provided
     plan = None
```
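With `TestEBCSharder` gone, sharding type and compute kernel can no longer be forced through the sharder; under `get_default_sharders()` they are steered through planner constraints instead. A rough sketch of the constraint-driven equivalent of the deleted `generate_planner` — the planner APIs are exactly those used in the removed code, while the exact shape of `PlannerConfig` lives in the new `sharding_config.py`, which is not shown on this page:

```python
# Sketch: same planner APIs the deleted generate_planner() used, with the
# sharding decision expressed as a per-table constraint.
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints

topology = Topology(world_size=2, compute_device="cuda")
constraints = {
    # mirrors the `additional_constraints` entry in the YAML diff below
    "large_table": ParameterConstraints(sharding_types=["column_wise"]),
}
planner = EmbeddingShardingPlanner(topology=topology, constraints=constraints)
```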
torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml

Lines changed: 14 additions & 8 deletions
```diff
@@ -2,7 +2,8 @@
 # runs on 2 ranks, showing traces with reasonable workloads
 RunOptions:
   world_size: 2
-  num_batches: 10
+  num_batches: 5
+  num_benchmarks: 2
   sharding_type: table_wise
   profile_dir: "."
   name: "sparse_data_dist_base"
@@ -12,18 +13,23 @@ PipelineConfig:
 EmbeddingTablesConfig:
   num_unweighted_features: 100
   num_weighted_features: 100
-  embedding_feature_dim: 128
+  embedding_feature_dim: 256
   additional_tables:
-    - - name: additional_tables_0_0
-        embedding_dim: 128
+    - - name: FP16_table
+        embedding_dim: 512
         num_embeddings: 100_000
         feature_names: ["additional_0_0"]
-      - name: additional_tables_0_1
-        embedding_dim: 128
-        num_embeddings: 100_000
+        data_type: FP16
+      - name: large_table
+        embedding_dim: 2048
+        num_embeddings: 1_000_000
         feature_names: ["additional_0_1"]
     - []
-    - - name: additional_tables_2_1
+    - - name: skipped_table
         embedding_dim: 128
         num_embeddings: 100_000
         feature_names: ["additional_2_1"]
+PlannerConfig:
+  additional_constraints:
+    large_table:
+      sharding_types: [column_wise]
```
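The new tables are sized to make the constraint meaningful. A back-of-the-envelope check for `large_table` from the config above (FP32 weights assumed; the YAML does not pin a data type for this table):

```python
# 1M rows x 2048 dims at 4 bytes/param (FP32 assumed):
size_gb = 1_000_000 * 2048 * 4 / 1024**3
print(f"{size_gb:.2f} GB")  # ~7.63 GB
# Large enough to be a poor fit for a single device in a 2-rank run, hence
# the column_wise constraint splitting the table across both ranks.
```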