@@ -23,25 +23,17 @@
 from torch._dynamo.utils import counters
 from torch.fx._symbolic_trace import is_fx_tracing
 from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding_types import (
-    EmbeddingComputeKernel,
-    EmbeddingTableConfig,
-)
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
 from torchrec.distributed.fp_embeddingbag import (
     FeatureProcessedEmbeddingBagCollectionSharder,
     ShardedFeatureProcessedEmbeddingBagCollection,
 )
 from torchrec.distributed.model_parallel import DMPCollection
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
-    row_wise,
     table_wise,
 )
-from torchrec.distributed.test_utils.multi_process import (
-    MultiProcessContext,
-    MultiProcessTestBase,
-)
 from torchrec.distributed.test_utils.test_model import (
     ModelInput,
     TestEBCSharder,
@@ -350,161 +342,6 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
         torch.testing.assert_close(pred_gpu.cpu(), pred)
 
 
-def fp_ebc(
-    rank: int,
-    world_size: int,
-    tables: List[EmbeddingTableConfig],
-    weighted_tables: List[EmbeddingTableConfig],
-    data: List[Tuple[ModelInput, List[ModelInput]]],
-    backend: str = "nccl",
-    local_size: Optional[int] = None,
-) -> None:
-    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
-        assert ctx.pg is not None
-        sharder = cast(
-            ModuleSharder[nn.Module],
-            FeatureProcessedEmbeddingBagCollectionSharder(),
-        )
-
-        class DummyWrapper(nn.Module):
-            def __init__(self, sparse_arch):
-                super().__init__()
-                self.m = sparse_arch
-
-            def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
-                return self.m(model_input.idlist_features)
-
-        max_feature_lengths = [10, 10, 12, 12]
-        sparse_arch = DummyWrapper(
-            create_module_and_freeze(
-                tables=tables,  # pyre-ignore[6]
-                device=ctx.device,
-                use_fp_collection=False,
-                max_feature_lengths=max_feature_lengths,
-            )
-        )
-
-        # compute_kernel = EmbeddingComputeKernel.FUSED.value
-        module_sharding_plan = construct_module_sharding_plan(
-            sparse_arch.m._fp_ebc,
-            per_param_sharding={
-                "table_0": row_wise(),
-                "table_1": row_wise(),
-                "table_2": row_wise(),
-                "table_3": row_wise(),
-            },
-            world_size=2,
-            device_type=ctx.device.type,
-            sharder=sharder,
-        )
-        sharded_sparse_arch_pipeline = DistributedModelParallel(
-            module=copy.deepcopy(sparse_arch),
-            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
-            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
-            sharders=[sharder],
-            device=ctx.device,
-        )
-        sharded_sparse_arch_no_pipeline = DistributedModelParallel(
-            module=copy.deepcopy(sparse_arch),
-            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
-            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
-            sharders=[sharder],
-            device=ctx.device,
-        )
-
-        batches = []
-        for d in data:
-            batches.append(d[1][ctx.rank].to(ctx.device))
-        dataloader = iter(batches)
-
-        optimizer_no_pipeline = optim.SGD(
-            sharded_sparse_arch_no_pipeline.parameters(), lr=0.1
-        )
-        optimizer_pipeline = optim.SGD(
-            sharded_sparse_arch_pipeline.parameters(), lr=0.1
-        )
-
-        pipeline = TrainPipelineSparseDist(
-            sharded_sparse_arch_pipeline,
-            optimizer_pipeline,
-            ctx.device,
-        )
-
-        for batch in batches[:-2]:
-            batch = batch.to(ctx.device)
-            optimizer_no_pipeline.zero_grad()
-            loss, pred = sharded_sparse_arch_no_pipeline(batch)
-            loss.backward()
-            optimizer_no_pipeline.step()
-
-            pred_pipeline = pipeline.progress(dataloader)
-            torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu())
-
-
-class TrainPipelineGPUTest(MultiProcessTestBase):
-    def setUp(self, backend: str = "nccl") -> None:
-        super().setUp()
-
-        self.pipeline_class = TrainPipelineSparseDist
-        num_features = 4
-        num_weighted_features = 4
-        self.tables = [
-            EmbeddingBagConfig(
-                num_embeddings=(i + 1) * 100,
-                embedding_dim=(i + 1) * 4,
-                name="table_" + str(i),
-                feature_names=["feature_" + str(i)],
-            )
-            for i in range(num_features)
-        ]
-        self.weighted_tables = [
-            EmbeddingBagConfig(
-                num_embeddings=(i + 1) * 100,
-                embedding_dim=(i + 1) * 4,
-                name="weighted_table_" + str(i),
-                feature_names=["weighted_feature_" + str(i)],
-            )
-            for i in range(num_weighted_features)
-        ]
-
-        self.backend = backend
-        if torch.cuda.is_available():
-            self.device = torch.device("cuda")
-        else:
-            self.device = torch.device("cpu")
-
-        if self.backend == "nccl" and self.device == torch.device("cpu"):
-            self.skipTest("NCCL not supported on CPUs.")
-
-    def _generate_data(
-        self,
-        num_batches: int = 5,
-        batch_size: int = 1,
-        max_feature_lengths: Optional[List[int]] = None,
-    ) -> List[Tuple[ModelInput, List[ModelInput]]]:
-        return [
-            ModelInput.generate(
-                tables=self.tables,
-                weighted_tables=self.weighted_tables,
-                batch_size=batch_size,
-                world_size=2,
-                num_float_features=10,
-                max_feature_lengths=max_feature_lengths,
-            )
-            for i in range(num_batches)
-        ]
-
-    def test_fp_ebc_rw(self) -> None:
-        data = self._generate_data(max_feature_lengths=[10, 10, 12, 12])
-        self._run_multi_process_test(
-            callable=fp_ebc,
-            world_size=2,
-            tables=self.tables,
-            weighted_tables=self.weighted_tables,
-            data=data,
-        )
 class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
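For reference, a minimal sketch (not part of this commit) of the TrainPipelineSparseDist usage pattern the deleted test exercised: construct the pipeline from a sharded model, its optimizer, and a device, then call progress() once per step with a batch iterator. The helper name run_pipelined_loop and the plain-list "batches" input are hypothetical, used only for illustration; the constructor and progress() calls mirror the removed code above. Because the pipeline keeps batches in flight to overlap input transfer and sparse all-to-all with compute, the removed test only compared predictions over batches[:-2].

# Minimal usage sketch, assuming a sharded model and per-rank batches like
# those built in the removed test; run_pipelined_loop is a hypothetical helper.
import torch
import torch.optim as optim
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist


def run_pipelined_loop(sharded_model, batches, device: torch.device) -> None:
    optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
    # Same constructor signature as in the removed fp_ebc test:
    # (model, optimizer, device).
    pipeline = TrainPipelineSparseDist(sharded_model, optimizer, device)
    dataloader = iter(batches)
    while True:
        try:
            # Each call runs one overlapped forward/backward/optimizer step
            # and returns the predictions for the batch it just completed.
            pred = pipeline.progress(dataloader)
        except StopIteration:
            # progress() re-raises StopIteration once the iterator is drained.
            break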