@@ -23,7 +23,10 @@
 from torch._dynamo.utils import counters
 from torch.fx._symbolic_trace import is_fx_tracing
 from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.embedding_types import (
+    EmbeddingComputeKernel,
+    EmbeddingTableConfig,
+)
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
 from torchrec.distributed.fp_embeddingbag import (
     FeatureProcessedEmbeddingBagCollectionSharder,
@@ -32,8 +35,13 @@
 from torchrec.distributed.model_parallel import DMPCollection
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
+    row_wise,
     table_wise,
 )
+from torchrec.distributed.test_utils.multi_process import (
+    MultiProcessContext,
+    MultiProcessTestBase,
+)
 from torchrec.distributed.test_utils.test_model import (
     ModelInput,
     TestEBCSharder,
@@ -342,6 +350,161 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
         torch.testing.assert_close(pred_gpu.cpu(), pred)


+def fp_ebc(
+    rank: int,
+    world_size: int,
+    tables: List[EmbeddingTableConfig],
+    weighted_tables: List[EmbeddingTableConfig],
+    data: List[Tuple[ModelInput, List[ModelInput]]],
+    backend: str = "nccl",
+    local_size: Optional[int] = None,
+) -> None:
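+    # Per-rank entry point: build a feature-processed EBC, shard it row-wise,
+    # and check that a pipelined model and a manually stepped reference model
+    # produce matching predictions.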
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        assert ctx.pg is not None
+        sharder = cast(
+            ModuleSharder[nn.Module],
+            FeatureProcessedEmbeddingBagCollectionSharder(),
+        )
+
+        class DummyWrapper(nn.Module):
+            def __init__(self, sparse_arch):
+                super().__init__()
+                self.m = sparse_arch
+
+            def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
+                return self.m(model_input.idlist_features)
+
+        max_feature_lengths = [10, 10, 12, 12]
+        sparse_arch = DummyWrapper(
+            create_module_and_freeze(
+                tables=tables,  # pyre-ignore[6]
+                device=ctx.device,
+                use_fp_collection=False,
+                max_feature_lengths=max_feature_lengths,
+            )
+        )
+
+        # compute_kernel = EmbeddingComputeKernel.FUSED.value
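+        # Shard every embedding table row-wise, splitting rows across both
+        # ranks; the same plan is reused for both model copies below.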
+        module_sharding_plan = construct_module_sharding_plan(
+            sparse_arch.m._fp_ebc,
+            per_param_sharding={
+                "table_0": row_wise(),
+                "table_1": row_wise(),
+                "table_2": row_wise(),
+                "table_3": row_wise(),
+            },
+            world_size=2,
+            device_type=ctx.device.type,
+            sharder=sharder,
+        )
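+        # Two identically sharded copies of the model: one driven by the
+        # train pipeline, one stepped manually as the reference.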
+        sharded_sparse_arch_pipeline = DistributedModelParallel(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
+            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
+            sharders=[sharder],
+            device=ctx.device,
+        )
+        sharded_sparse_arch_no_pipeline = DistributedModelParallel(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
+            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
+            sharders=[sharder],
+            device=ctx.device,
+        )
+
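+        # Each element of `data` is a (global batch, per-rank batches) pair;
+        # every rank trains only on its own slice of the input.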
+        batches = []
+        for d in data:
+            batches.append(d[1][ctx.rank].to(ctx.device))
+        dataloader = iter(batches)
+
+        optimizer_no_pipeline = optim.SGD(
+            sharded_sparse_arch_no_pipeline.parameters(), lr=0.1
+        )
+        optimizer_pipeline = optim.SGD(
+            sharded_sparse_arch_pipeline.parameters(), lr=0.1
+        )
+
+        pipeline = TrainPipelineSparseDist(
+            sharded_sparse_arch_pipeline,
+            optimizer_pipeline,
+            ctx.device,
+        )
+
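+        # TrainPipelineSparseDist keeps two extra batches in flight (device
+        # copy and sparse input dist), so the loop runs two fewer iterations
+        # than there are batches, keeping both models on the same batch.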
+        for batch in batches[:-2]:
+            batch = batch.to(ctx.device)
+            optimizer_no_pipeline.zero_grad()
+            loss, pred = sharded_sparse_arch_no_pipeline(batch)
+            loss.backward()
+            optimizer_no_pipeline.step()
+
+            pred_pipeline = pipeline.progress(dataloader)
+            torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu())
+
+
+class TrainPipelineGPUTest(MultiProcessTestBase):
+    def setUp(self, backend: str = "nccl") -> None:
+        super().setUp()
+
+        self.pipeline_class = TrainPipelineSparseDist
+        num_features = 4
+        num_weighted_features = 4
+        self.tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 100,
+                embedding_dim=(i + 1) * 4,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(num_features)
+        ]
+        self.weighted_tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 100,
+                embedding_dim=(i + 1) * 4,
+                name="weighted_table_" + str(i),
+                feature_names=["weighted_feature_" + str(i)],
+            )
+            for i in range(num_weighted_features)
+        ]
+
+        self.backend = backend
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+
+        if self.backend == "nccl" and self.device == torch.device("cpu"):
+            self.skipTest("NCCL not supported on CPUs.")
+
+    def _generate_data(
+        self,
+        num_batches: int = 5,
+        batch_size: int = 1,
+        max_feature_lengths: Optional[List[int]] = None,
+    ) -> List[Tuple[ModelInput, List[ModelInput]]]:
+        return [
+            ModelInput.generate(
+                tables=self.tables,
+                weighted_tables=self.weighted_tables,
+                batch_size=batch_size,
+                world_size=2,
+                num_float_features=10,
+                max_feature_lengths=max_feature_lengths,
+            )
+            for _ in range(num_batches)
+        ]
+
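+    # The max_feature_lengths passed here must match the values hard-coded in
+    # fp_ebc, since the feature processors built there are sized by them.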
+    def test_fp_ebc_rw(self) -> None:
+        data = self._generate_data(max_feature_lengths=[10, 10, 12, 12])
+        self._run_multi_process_test(
+            callable=fp_ebc,
+            world_size=2,
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            data=data,
+        )
+
+
 class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(