@@ -19,27 +19,23 @@
 import copy
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.distributed as dist
 
 from torch import nn, optim
 from torch.optim import Optimizer
 from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
-from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
+from torchrec.distributed.planner import EmbeddingShardingPlanner
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
-from torchrec.distributed.planner.types import ParameterConstraints
-from torchrec.distributed.test_utils.model_input import ModelInput
+from torchrec.distributed.sharding_plan import get_default_sharders
 from torchrec.distributed.test_utils.test_model import (
-    TestEBCSharder,
     TestSparseNN,
     TestTowerCollectionSparseNN,
     TestTowerSparseNN,
 )
-from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
+from torchrec.distributed.types import ShardingEnv
 from torchrec.models.deepfm import SimpleDeepFMNNWrapper
 from torchrec.models.dlrm import DLRMWrapper
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -240,137 +236,8 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
     return model_class(**filtered_kwargs)
 
 
-def generate_data(
-    tables: List[EmbeddingBagConfig],
-    weighted_tables: List[EmbeddingBagConfig],
-    model_config: BaseModelConfig,
-    batch_sizes: List[int],
-) -> List[ModelInput]:
-    """
-    Generate model input data for benchmarking.
-
-    Args:
-        tables: List of unweighted embedding tables
-        weighted_tables: List of weighted embedding tables
-        model_config: Configuration for model generation
-        num_batches: Number of batches to generate
-
-    Returns:
-        A list of ModelInput objects representing the generated batches
-    """
-    device = torch.device(model_config.dev_str) if model_config.dev_str else None
-
-    return [
-        ModelInput.generate(
-            batch_size=batch_size,
-            tables=tables,
-            weighted_tables=weighted_tables,
-            num_float_features=model_config.num_float_features,
-            pooling_avg=model_config.feature_pooling_avg,
-            use_offsets=model_config.use_offsets,
-            device=device,
-            indices_dtype=(
-                torch.int64 if model_config.long_kjt_indices else torch.int32
-            ),
-            offsets_dtype=(
-                torch.int64 if model_config.long_kjt_offsets else torch.int32
-            ),
-            lengths_dtype=(
-                torch.int64 if model_config.long_kjt_lengths else torch.int32
-            ),
-            pin_memory=model_config.pin_memory,
-        )
-        for batch_size in batch_sizes
-    ]
-
-
-def generate_planner(
-    planner_type: str,
-    topology: Topology,
-    tables: Optional[List[EmbeddingBagConfig]],
-    weighted_tables: Optional[List[EmbeddingBagConfig]],
-    sharding_type: ShardingType,
-    compute_kernel: EmbeddingComputeKernel,
-    batch_sizes: List[int],
-    pooling_factors: Optional[List[float]] = None,
-    num_poolings: Optional[List[float]] = None,
-) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
-    """
-    Generate an embedding sharding planner based on the specified configuration.
-
-    Args:
-        planner_type: Type of planner to use ("embedding" or "hetero")
-        topology: Network topology for distributed training
-        tables: List of unweighted embedding tables
-        weighted_tables: List of weighted embedding tables
-        sharding_type: Strategy for sharding embedding tables
-        compute_kernel: Compute kernel to use for embedding tables
-        batch_sizes: Sizes of each batch
-        pooling_factors: Pooling factors for each feature of the table
-        num_poolings: Number of poolings for each feature of the table
-
-    Returns:
-        An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner
-
-    Raises:
-        RuntimeError: If an unknown planner type is specified
-    """
-    # Create parameter constraints for tables
-    constraints = {}
-    num_batches = len(batch_sizes)
-
-    if pooling_factors is None:
-        pooling_factors = [POOLING_FACTOR] * num_batches
-
-    if num_poolings is None:
-        num_poolings = [NUM_POOLINGS] * num_batches
-
-    assert (
-        len(pooling_factors) == num_batches and len(num_poolings) == num_batches
-    ), "The length of pooling_factors and num_poolings must match the number of batches."
-
-    if tables is not None:
-        for table in tables:
-            constraints[table.name] = ParameterConstraints(
-                sharding_types=[sharding_type.value],
-                compute_kernels=[compute_kernel.value],
-                device_group="cuda",
-                pooling_factors=pooling_factors,
-                num_poolings=num_poolings,
-                batch_sizes=batch_sizes,
-            )
-
-    if weighted_tables is not None:
-        for table in weighted_tables:
-            constraints[table.name] = ParameterConstraints(
-                sharding_types=[sharding_type.value],
-                compute_kernels=[compute_kernel.value],
-                device_group="cuda",
-                pooling_factors=pooling_factors,
-                num_poolings=num_poolings,
-                batch_sizes=batch_sizes,
-                is_weighted=True,
-            )
-
-    if planner_type == "embedding":
-        return EmbeddingShardingPlanner(
-            topology=topology,
-            constraints=constraints if constraints else None,
-        )
-    elif planner_type == "hetero":
-        topology_groups = {"cuda": topology}
-        return HeteroEmbeddingShardingPlanner(
-            topology_groups=topology_groups,
-            constraints=constraints if constraints else None,
-        )
-    else:
-        raise RuntimeError(f"Unknown planner type: {planner_type}")
-
-
 def generate_sharded_model_and_optimizer(
     model: nn.Module,
-    sharding_type: str,
-    kernel_type: str,
     pg: dist.ProcessGroup,
     device: torch.device,
     fused_params: Dict[str, Any],
@@ -404,12 +271,7 @@ def generate_sharded_model_and_optimizer(
     Returns:
         Tuple of sharded model and optimizer
     """
-    sharder = TestEBCSharder(
-        sharding_type=sharding_type,
-        kernel_type=kernel_type,
-        fused_params=fused_params,
-    )
-    sharders = [cast(ModuleSharder[nn.Module], sharder)]
+    sharders = get_default_sharders()
 
     # Use planner if provided
     plan = None
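
For context, here is a minimal sketch of how the simplified path can be wired together: the default sharders replace the hand-built TestEBCSharder and per-table ParameterConstraints removed above. This is not part of the commit; the helper name shard_with_default_sharders and its arguments are illustrative, and it assumes a typical TorchRec setup with an initialized process group.

# Illustrative sketch only (not part of this diff): shard a model using
# get_default_sharders() instead of a hand-constructed test sharder.
import torch
import torch.distributed as dist
from torch import nn
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.sharding_plan import get_default_sharders
from torchrec.distributed.types import ShardingEnv


def shard_with_default_sharders(  # hypothetical helper for illustration
    model: nn.Module,
    pg: dist.ProcessGroup,
    device: torch.device,
) -> DistributedModelParallel:
    # The default sharders cover the standard TorchRec embedding modules,
    # so no per-table ParameterConstraints are needed for a basic benchmark.
    sharders = get_default_sharders()
    planner = EmbeddingShardingPlanner(
        topology=Topology(world_size=pg.size(), compute_device=device.type)
    )
    # collective_plan() agrees on a single sharding plan across all ranks in pg.
    plan = planner.collective_plan(model, sharders, pg)
    return DistributedModelParallel(
        module=model,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=device,
    )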