|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | -import inspect |
15 | 14 | from contextlib import AbstractContextManager, nullcontext |
16 | 15 | from datetime import timedelta |
17 | 16 | from typing import Any, Literal, Optional, Union |
@@ -158,24 +157,20 @@ def all_reduce( |
158 | 157 | def barrier(self, *args: Any, **kwargs: Any) -> None: |
159 | 158 | if not _distributed_is_initialized(): |
160 | 159 | return |
161 | | - backend = torch.distributed.get_backend() |
162 | | - if backend == "nccl": |
| 160 | + if torch.distributed.get_backend() == "nccl": |
163 | 161 | torch.distributed.barrier(device_ids=self._determine_ddp_device_ids()) |
164 | | - return |
165 | | - # For CPU backends (e.g., gloo), recent PyTorch may attempt to resolve an accelerator and crash on CPU-only runs. |
166 | | - try: |
167 | | - torch.distributed.barrier() |
168 | | - except RuntimeError as e: |
169 | | - # Handle: "Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first." |
170 | | - if "PrivateUse1HooksInterface" in str(e): |
171 | | - # Use explicit CPU device if supported in this PyTorch version |
172 | | - if "device" in inspect.signature(torch.distributed.barrier).parameters: |
173 | | - torch.distributed.barrier(device=torch.device("cpu")) |
| 162 | + else: |
| 163 | + # Handle PyTorch bug where barrier() fails on CPU with "PrivateUse1HooksInterface" error |
| 164 | + try: |
| 165 | + torch.distributed.barrier() |
| 166 | + except RuntimeError as e: |
| 167 | + if "PrivateUse1HooksInterface" in str(e): |
| 168 | + # Fallback: Use all_reduce as barrier - all processes must participate |
| 169 | + # This achieves the same synchronization effect as barrier() |
| 170 | + dummy_tensor = torch.tensor(0.0, device=self.root_device) |
| 171 | + torch.distributed.all_reduce(dummy_tensor) |
174 | 172 | else: |
175 | | - # Older versions shouldn't trigger this path; re-raise to avoid masking other issues |
176 | 173 | raise |
177 | | - else: |
178 | | - raise |
179 | 174 |
|
180 | 175 | @override |
181 | 176 | def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: |
|
0 commit comments