Commit 82159fd

TroyGarden authored and meta-codesync[bot] committed
move sharding/planner and model input generation into config
Summary:

# context
* make the planner configurable via config
* make the model input configurable via config
* SSD offloading enabled {F1982627020}

> NOTE: currently there's no pipeline optimization for SSD prefetch, so the SSD lookup is expected to be very long. [trace](https://drive.google.com/file/d/1CshrtpIip_yd_gYbso_ddFOVVw42V8JR/view?usp=sharing)

Differential Revision: D84298150
1 parent 654811e commit 82159fd
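
In practice, the runner now builds its sharding planner and its benchmark input batches from two new config objects instead of ad-hoc RunOptions fields. Below is a minimal sketch, condensed from the benchmark_train_pipeline.py diff further down; the wrapper function name is illustrative only, and the exact config fields live in torchrec.distributed.test_utils.

```python
# Sketch only: condensed from the diff below, not the full runner.
from typing import List, Tuple

from torchrec.distributed.test_utils.input_config import ModelInputConfig
from torchrec.distributed.test_utils.model_input import ModelInput
from torchrec.distributed.test_utils.sharding_config import PlannerConfig
from torchrec.modules.embedding_configs import EmbeddingBagConfig


def build_planner_and_inputs(
    tables: List[EmbeddingBagConfig],
    weighted_tables: List[EmbeddingBagConfig],
    planner_config: PlannerConfig,
    input_config: ModelInputConfig,
) -> Tuple[object, List[ModelInput]]:
    # Planner construction is now driven by PlannerConfig instead of
    # RunOptions fields (planner_type, compute_kernel, pooling_factors, ...).
    planner = planner_config.generate_planner(tables=tables + weighted_tables)
    # Input batches are generated from ModelInputConfig instead of the
    # removed generate_data() helper.
    bench_inputs = input_config.generate_batches(
        tables=tables,
        weighted_tables=weighted_tables,
    )
    return planner, bench_inputs
```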

File tree

7 files changed: +268 −193 lines


torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 15 additions & 41 deletions

@@ -37,20 +37,17 @@
 from torchrec.distributed.benchmark.benchmark_utils import (
     BaseModelConfig,
     create_model_config,
-    generate_data,
-    generate_planner,
     generate_sharded_model_and_optimizer,
 )
-from torchrec.distributed.comm import get_local_size
-from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.planner import Topology
+from torchrec.distributed.test_utils.input_config import ModelInputConfig
 from torchrec.distributed.test_utils.model_input import ModelInput

 from torchrec.distributed.test_utils.multi_process import (
     MultiProcessContext,
     run_multi_process_func,
 )
 from torchrec.distributed.test_utils.pipeline_config import PipelineConfig
+from torchrec.distributed.test_utils.sharding_config import PlannerConfig
 from torchrec.distributed.test_utils.table_config import EmbeddingTablesConfig
 from torchrec.distributed.test_utils.test_model import TestOverArchLarge
 from torchrec.distributed.train_pipeline import TrainPipeline
@@ -99,14 +96,11 @@ class RunOptions(BenchFuncConfig):
     world_size: int = 2
     num_batches: int = 10
     sharding_type: ShardingType = ShardingType.TABLE_WISE
-    compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED
     input_type: str = "kjt"
     name: str = ""
     profile_dir: str = ""
     num_benchmarks: int = 5
     num_profiles: int = 2
-    planner_type: str = "embedding"
-    pooling_factors: Optional[List[float]] = None
     num_poolings: Optional[List[float]] = None
     dense_optimizer: str = "SGD"
     dense_lr: float = 0.1
@@ -124,7 +118,7 @@ class ModelSelectionConfig:
     model_name: str = "test_sparse_nn"

     # Common config for all model types
-    batch_size: int = 8192
+    batch_size: int = 1024 * 32
     batch_sizes: Optional[List[int]] = None
     num_float_features: int = 10
     feature_pooling_avg: int = 10
@@ -161,6 +155,8 @@ def runner(
     run_option: RunOptions,
     model_config: BaseModelConfig,
     pipeline_config: PipelineConfig,
+    input_config: ModelInputConfig,
+    planner_config: PlannerConfig,
 ) -> BenchmarkResult:
     # Ensure GPUs are available and we have enough of them
     assert (
@@ -180,39 +176,14 @@ def runner(
         dense_device=ctx.device,
     )

-    # Create a topology for sharding
-    topology = Topology(
-        local_world_size=get_local_size(world_size),
-        world_size=world_size,
-        compute_device=ctx.device.type,
-    )
-
-    batch_sizes = model_config.batch_sizes
-
-    if batch_sizes is None:
-        batch_sizes = [model_config.batch_size] * run_option.num_batches
-    else:
-        assert (
-            len(batch_sizes) == run_option.num_batches
-        ), "The length of batch_sizes must match the number of batches."
-
     # Create a planner for sharding based on the specified type
-    planner = generate_planner(
-        planner_type=run_option.planner_type,
-        topology=topology,
-        tables=tables,
-        weighted_tables=weighted_tables,
-        sharding_type=run_option.sharding_type,
-        compute_kernel=run_option.compute_kernel,
-        batch_sizes=batch_sizes,
-        pooling_factors=run_option.pooling_factors,
-        num_poolings=run_option.num_poolings,
+    planner = planner_config.generate_planner(
+        tables=tables + weighted_tables,
     )
-    bench_inputs = generate_data(
+
+    bench_inputs = input_config.generate_batches(
         tables=tables,
         weighted_tables=weighted_tables,
-        model_config=model_config,
-        batch_sizes=batch_sizes,
     )

     # Prepare fused_params for sparse optimizer
@@ -230,8 +201,6 @@ def runner(

     sharded_model, optimizer = generate_sharded_model_and_optimizer(
         model=unsharded_model,
-        sharding_type=run_option.sharding_type.value,
-        kernel_type=run_option.compute_kernel.value,
         # pyre-ignore
         pg=ctx.pg,
         device=ctx.device,
@@ -285,8 +254,8 @@ def run_pipeline(
     table_config: EmbeddingTablesConfig,
     pipeline_config: PipelineConfig,
     model_config: BaseModelConfig,
+    input_config: ModelInputConfig,
 ) -> BenchmarkResult:
-
     tables, weighted_tables, *_ = table_config.generate_tables()

     benchmark_res_per_rank = run_multi_process_func(
@@ -297,6 +266,7 @@ def run_pipeline(
         run_option=run_option,
         model_config=model_config,
         pipeline_config=pipeline_config,
+        input_config=input_config,
     )

     # Combine results from all ranks into a single BenchmarkResult
@@ -329,6 +299,8 @@ def main(
     table_config: EmbeddingTablesConfig,
     model_selection: ModelSelectionConfig,
     pipeline_config: PipelineConfig,
+    input_config: ModelInputConfig,
+    planner_config: PlannerConfig,
     model_config: Optional[BaseModelConfig] = None,
 ) -> None:
     tables, weighted_tables, *_ = table_config.generate_tables()
@@ -367,6 +339,8 @@ def main(
         run_option=run_option,
         model_config=model_config,
         pipeline_config=pipeline_config,
+        input_config=input_config,
+        planner_config=planner_config,
     )


torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 4 additions & 141 deletions

@@ -27,19 +27,16 @@
 from torch import nn, optim
 from torch.optim import Optimizer
 from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
-from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
+from torchrec.distributed.planner import EmbeddingShardingPlanner
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
-from torchrec.distributed.planner.types import ParameterConstraints
-from torchrec.distributed.test_utils.model_input import ModelInput
+from torchrec.distributed.sharding_plan import get_default_sharders
 from torchrec.distributed.test_utils.test_model import (
     TestEBCSharder,
     TestSparseNN,
     TestTowerCollectionSparseNN,
     TestTowerSparseNN,
 )
-from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
+from torchrec.distributed.types import ModuleSharder, ShardingEnv
 from torchrec.models.deepfm import SimpleDeepFMNNWrapper
 from torchrec.models.dlrm import DLRMWrapper
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -240,137 +237,8 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
     return model_class(**filtered_kwargs)


-def generate_data(
-    tables: List[EmbeddingBagConfig],
-    weighted_tables: List[EmbeddingBagConfig],
-    model_config: BaseModelConfig,
-    batch_sizes: List[int],
-) -> List[ModelInput]:
-    """
-    Generate model input data for benchmarking.
-
-    Args:
-        tables: List of unweighted embedding tables
-        weighted_tables: List of weighted embedding tables
-        model_config: Configuration for model generation
-        num_batches: Number of batches to generate
-
-    Returns:
-        A list of ModelInput objects representing the generated batches
-    """
-    device = torch.device(model_config.dev_str) if model_config.dev_str else None
-
-    return [
-        ModelInput.generate(
-            batch_size=batch_size,
-            tables=tables,
-            weighted_tables=weighted_tables,
-            num_float_features=model_config.num_float_features,
-            pooling_avg=model_config.feature_pooling_avg,
-            use_offsets=model_config.use_offsets,
-            device=device,
-            indices_dtype=(
-                torch.int64 if model_config.long_kjt_indices else torch.int32
-            ),
-            offsets_dtype=(
-                torch.int64 if model_config.long_kjt_offsets else torch.int32
-            ),
-            lengths_dtype=(
-                torch.int64 if model_config.long_kjt_lengths else torch.int32
-            ),
-            pin_memory=model_config.pin_memory,
-        )
-        for batch_size in batch_sizes
-    ]
-
-
-def generate_planner(
-    planner_type: str,
-    topology: Topology,
-    tables: Optional[List[EmbeddingBagConfig]],
-    weighted_tables: Optional[List[EmbeddingBagConfig]],
-    sharding_type: ShardingType,
-    compute_kernel: EmbeddingComputeKernel,
-    batch_sizes: List[int],
-    pooling_factors: Optional[List[float]] = None,
-    num_poolings: Optional[List[float]] = None,
-) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
-    """
-    Generate an embedding sharding planner based on the specified configuration.
-
-    Args:
-        planner_type: Type of planner to use ("embedding" or "hetero")
-        topology: Network topology for distributed training
-        tables: List of unweighted embedding tables
-        weighted_tables: List of weighted embedding tables
-        sharding_type: Strategy for sharding embedding tables
-        compute_kernel: Compute kernel to use for embedding tables
-        batch_sizes: Sizes of each batch
-        pooling_factors: Pooling factors for each feature of the table
-        num_poolings: Number of poolings for each feature of the table
-
-    Returns:
-        An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner
-
-    Raises:
-        RuntimeError: If an unknown planner type is specified
-    """
-    # Create parameter constraints for tables
-    constraints = {}
-    num_batches = len(batch_sizes)
-
-    if pooling_factors is None:
-        pooling_factors = [POOLING_FACTOR] * num_batches
-
-    if num_poolings is None:
-        num_poolings = [NUM_POOLINGS] * num_batches
-
-    assert (
-        len(pooling_factors) == num_batches and len(num_poolings) == num_batches
-    ), "The length of pooling_factors and num_poolings must match the number of batches."
-
-    if tables is not None:
-        for table in tables:
-            constraints[table.name] = ParameterConstraints(
-                sharding_types=[sharding_type.value],
-                compute_kernels=[compute_kernel.value],
-                device_group="cuda",
-                pooling_factors=pooling_factors,
-                num_poolings=num_poolings,
-                batch_sizes=batch_sizes,
-            )
-
-    if weighted_tables is not None:
-        for table in weighted_tables:
-            constraints[table.name] = ParameterConstraints(
-                sharding_types=[sharding_type.value],
-                compute_kernels=[compute_kernel.value],
-                device_group="cuda",
-                pooling_factors=pooling_factors,
-                num_poolings=num_poolings,
-                batch_sizes=batch_sizes,
-                is_weighted=True,
-            )
-
-    if planner_type == "embedding":
-        return EmbeddingShardingPlanner(
-            topology=topology,
-            constraints=constraints if constraints else None,
-        )
-    elif planner_type == "hetero":
-        topology_groups = {"cuda": topology}
-        return HeteroEmbeddingShardingPlanner(
-            topology_groups=topology_groups,
-            constraints=constraints if constraints else None,
-        )
-    else:
-        raise RuntimeError(f"Unknown planner type: {planner_type}")
-
-
 def generate_sharded_model_and_optimizer(
     model: nn.Module,
-    sharding_type: str,
-    kernel_type: str,
     pg: dist.ProcessGroup,
     device: torch.device,
     fused_params: Dict[str, Any],
@@ -404,12 +272,7 @@ def generate_sharded_model_and_optimizer(
     Returns:
         Tuple of sharded model and optimizer
     """
-    sharder = TestEBCSharder(
-        sharding_type=sharding_type,
-        kernel_type=kernel_type,
-        fused_params=fused_params,
-    )
-    sharders = [cast(ModuleSharder[nn.Module], sharder)]
+    sharders = get_default_sharders()

     # Use planner if provided
     plan = None
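
The block removed above is what the new PlannerConfig now has to cover: per-table ParameterConstraints handed to an EmbeddingShardingPlanner (or HeteroEmbeddingShardingPlanner). A minimal sketch of that mapping, assuming additional_constraints is a table-name-to-fields mapping as the YAML below suggests; the helper name is illustrative and the real implementation in torchrec.distributed.test_utils.sharding_config may differ.

```python
# Sketch only: mirrors the constraint-building logic deleted from
# benchmark_utils.py; PlannerConfig's actual implementation may differ.
from typing import Any, Dict

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints


def planner_from_constraints(
    topology: Topology,
    additional_constraints: Dict[str, Dict[str, Any]],
) -> EmbeddingShardingPlanner:
    # additional_constraints maps a table name to ParameterConstraints fields,
    # e.g. {"large_table": {"compute_kernels": ["key_value"],
    #                       "sharding_types": ["row_wise"]}}
    constraints = {
        table_name: ParameterConstraints(**fields)
        for table_name, fields in additional_constraints.items()
    }
    return EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints or None,
    )
```

With that shape, the large_table entry in the YAML below becomes a single ParameterConstraints for that table while every other table keeps the planner defaults.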

torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml

Lines changed: 14 additions & 8 deletions

@@ -2,7 +2,8 @@
 # runs on 2 ranks, showing traces with reasonable workloads
 RunOptions:
   world_size: 2
-  num_batches: 10
+  num_batches: 5
+  num_benchmarks: 2
   sharding_type: table_wise
   profile_dir: "."
   name: "sparse_data_dist_base"
@@ -12,18 +13,23 @@ PipelineConfig:
 EmbeddingTablesConfig:
   num_unweighted_features: 100
   num_weighted_features: 100
-  embedding_feature_dim: 128
+  embedding_feature_dim: 256
   additional_tables:
-    - - name: additional_tables_0_0
-        embedding_dim: 128
+    - - name: FP16_table
+        embedding_dim: 512
         num_embeddings: 100_000
         feature_names: ["additional_0_0"]
-      - name: additional_tables_0_1
-        embedding_dim: 128
-        num_embeddings: 100_000
+        data_type: FP16
+      - name: large_table
+        embedding_dim: 2048
+        num_embeddings: 1_000_000
         feature_names: ["additional_0_1"]
     - []
-    - - name: additional_tables_2_1
+    - - name: skipped_table
         embedding_dim: 128
         num_embeddings: 100_000
         feature_names: ["additional_2_1"]
+PlannerConfig:
+  additional_constraints:
+    large_table:
+      sharding_types: [column_wise]

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+# this is a very basic sparse data dist config
+# runs on 2 ranks, showing traces with reasonable workloads
+RunOptions:
+  world_size: 2
+  num_batches: 5
+  num_benchmarks: 2
+  sharding_type: table_wise
+  profile_dir: "."
+  name: "sparse_data_dist_base"
+  # export_stacks: True # enable this to export stack traces
+PipelineConfig:
+  pipeline: "sparse"
+EmbeddingTablesConfig:
+  num_unweighted_features: 100
+  num_weighted_features: 100
+  embedding_feature_dim: 256
+  additional_tables:
+    - - name: FP16_table
+        embedding_dim: 512
+        num_embeddings: 100_000
+        feature_names: ["additional_0_0"]
+        data_type: FP16
+      - name: large_table
+        embedding_dim: 256
+        num_embeddings: 1_000_000
+        feature_names: ["additional_0_1"]
+    - []
+    - - name: skipped_table
+        embedding_dim: 128
+        num_embeddings: 100_000
+        feature_names: ["additional_2_1"]
+PlannerConfig:
+  additional_constraints:
+    large_table:
+      compute_kernels: [key_value]
+      sharding_types: [row_wise]
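
Relative to the base YAML above, this new config shrinks large_table's embedding_dim from 2048 to 256 and constrains the planner to the key_value compute kernel with row_wise sharding, which is what exercises the SSD offloading path called out in the commit summary. As a hedged sketch, the constraint corresponds roughly to the following; EmbeddingComputeKernel.KEY_VALUE is the existing torchrec enum member for the key-value (SSD-backed) kernel, and the actual translation from YAML is done by PlannerConfig, not by hand.

```python
# Sketch: how the YAML constraint for large_table roughly translates,
# assuming PlannerConfig builds ParameterConstraints from its YAML fields.
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.types import ParameterConstraints

large_table_constraint = ParameterConstraints(
    sharding_types=["row_wise"],
    # "key_value" selects the SSD/key-value-backed embedding kernel
    compute_kernels=[EmbeddingComputeKernel.KEY_VALUE.value],
)
```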
