
Commit f6aab15: try it

1 parent a702266 commit f6aab15


src/lightning/fabric/strategies/ddp.py

Lines changed: 11 additions & 16 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 from contextlib import AbstractContextManager, nullcontext
 from datetime import timedelta
 from typing import Any, Literal, Optional, Union
@@ -158,24 +157,20 @@ def all_reduce(
     def barrier(self, *args: Any, **kwargs: Any) -> None:
         if not _distributed_is_initialized():
             return
-        backend = torch.distributed.get_backend()
-        if backend == "nccl":
+        if torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())
-            return
-        # For CPU backends (e.g., gloo), recent PyTorch may attempt to resolve an accelerator and crash on CPU-only runs.
-        try:
-            torch.distributed.barrier()
-        except RuntimeError as e:
-            # Handle: "Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first."
-            if "PrivateUse1HooksInterface" in str(e):
-                # Use explicit CPU device if supported in this PyTorch version
-                if "device" in inspect.signature(torch.distributed.barrier).parameters:
-                    torch.distributed.barrier(device=torch.device("cpu"))
+        else:
+            # Handle PyTorch bug where barrier() fails on CPU with "PrivateUse1HooksInterface" error
+            try:
+                torch.distributed.barrier()
+            except RuntimeError as e:
+                if "PrivateUse1HooksInterface" in str(e):
+                    # Fallback: Use all_reduce as barrier - all processes must participate
+                    # This achieves the same synchronization effect as barrier()
+                    dummy_tensor = torch.tensor(0.0, device=self.root_device)
+                    torch.distributed.all_reduce(dummy_tensor)
                 else:
-                    # Older versions shouldn't trigger this path; re-raise to avoid masking other issues
                     raise
-            else:
-                raise
 
     @override
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
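For reference, the fallback in this commit works because all_reduce is itself a
blocking collective: every rank must participate before any rank returns, which
gives the same synchronization guarantee as barrier(). Below is a minimal
standalone sketch of the same pattern outside Lightning; the safe_barrier
helper name, the gloo process group, and the torchrun launch are illustrative
assumptions, not part of this commit.

import torch
import torch.distributed as dist


def safe_barrier(root_device: torch.device) -> None:
    """Synchronize all ranks, tolerating the CPU-only barrier() crash."""
    if dist.get_backend() == "nccl":
        # NCCL barriers should be pinned to the local GPU.
        dist.barrier(device_ids=[torch.cuda.current_device()])
        return
    try:
        dist.barrier()
    except RuntimeError as e:
        if "PrivateUse1HooksInterface" not in str(e):
            raise  # Unrelated failure: re-raise rather than mask it
        # Fallback: a no-op all_reduce blocks every rank until all ranks
        # have joined, matching barrier() semantics.
        dummy = torch.tensor(0.0, device=root_device)
        dist.all_reduce(dummy)


# Launch with: torchrun --nproc_per_node=2 sketch.py
if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    safe_barrier(torch.device("cpu"))
    dist.destroy_process_group()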
