From 6079ce615ea1d88810718cc7879e72195ce1057e Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Mon, 25 Aug 2025 13:16:59 -0700
Subject: [PATCH] silence destroy_process_group() warning

---
 distributed/tensor_parallelism/fsdp_tp_example.py           | 3 +++
 distributed/tensor_parallelism/sequence_parallel_example.py | 4 ++++
 distributed/tensor_parallelism/tensor_parallel_example.py   | 3 +++
 3 files changed, 10 insertions(+)

diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py
index a44a58ba0f..fb0d5ba1f5 100644
--- a/distributed/tensor_parallelism/fsdp_tp_example.py
+++ b/distributed/tensor_parallelism/fsdp_tp_example.py
@@ -173,3 +173,6 @@
     rank_log(_rank, logger, f"2D iter {i} complete")
 
 rank_log(_rank, logger, "2D training successfully completed!")
+
+if dist.is_initialized():
+    dist.destroy_process_group()
diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py
index 988973af4b..73320f5bcc 100644
--- a/distributed/tensor_parallelism/sequence_parallel_example.py
+++ b/distributed/tensor_parallelism/sequence_parallel_example.py
@@ -22,6 +22,7 @@
 
 import torch
 import torch.nn as nn
+import torch.distributed as dist
 
 from torch.distributed._tensor import Shard
 from torch.distributed.tensor.parallel import (
@@ -107,3 +108,6 @@ def forward(self, x):
     rank_log(_rank, logger, f"Sequence Parallel iter {i} completed")
 
 rank_log(_rank, logger, "Sequence Parallel training completed!")
+
+if dist.is_initialized():
+    dist.destroy_process_group()
diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py
index c42a952ea8..6a4b4ea531 100755
--- a/distributed/tensor_parallelism/tensor_parallel_example.py
+++ b/distributed/tensor_parallelism/tensor_parallel_example.py
@@ -122,3 +122,6 @@ def forward(self, x):
     rank_log(_rank, logger, f"Tensor Parallel iter {i} completed")
 
 rank_log(_rank, logger, "Tensor Parallel training completed!")
+
+if dist.is_initialized():
+    dist.destroy_process_group()
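
For context, the guard added at the end of each example addresses the warning that recent PyTorch releases print when a process exits while the default process group is still initialized. Below is a minimal, self-contained sketch (not part of the patch) of the init/teardown pattern the examples now follow; it assumes launch via torchrun (which supplies RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT), uses the gloo backend purely so it can run on CPU, and elides the actual model and parallelism code:

    import torch
    import torch.distributed as dist

    def main():
        # torchrun sets the rendezvous environment variables, so the default
        # env:// initialization needs no extra arguments.
        dist.init_process_group(backend="gloo")
        rank = dist.get_rank()

        # ... model setup and training loop would go here ...
        t = torch.ones(1) * rank
        dist.all_reduce(t)  # a trivial collective so the sketch exercises the group

        # Tear down the default process group before exit; the is_initialized()
        # guard keeps this safe if initialization was skipped.
        if dist.is_initialized():
            dist.destroy_process_group()

    if __name__ == "__main__":
        main()

Run with, for example, torchrun --nproc-per-node 4 on the script; without the final destroy_process_group() call, newer PyTorch builds log a warning at shutdown about the process group not being destroyed.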