diff --git a/torchft/checkpointing/pg_transport.py b/torchft/checkpointing/pg_transport.py index 22fb4d9..e73c0bc 100644 --- a/torchft/checkpointing/pg_transport.py +++ b/torchft/checkpointing/pg_transport.py @@ -156,9 +156,9 @@ def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: caveat that the cast tensor may be larger than the original tensor due to the differences in striding. """ - assert type(tensor) is torch.Tensor, ( - f"can only cast standard tensors not {type(tensor)}" - ) + assert ( + type(tensor) is torch.Tensor + ), f"can only cast standard tensors not {type(tensor)}" storage = tensor.untyped_storage() ret = torch.tensor(storage, dtype=dtype, device=tensor.device) assert ret.untyped_storage() is storage, "storage should be the same" @@ -266,9 +266,9 @@ def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor: if isinstance(inplace, DTensor): inplace = inplace._local_tensor t = _cast_tensor(inplace, torch.uint8) - assert t.nbytes == v.nbytes, ( - "inplace tensor storage must be the same size" - ) + assert ( + t.nbytes == v.nbytes + ), "inplace tensor storage must be the same size" else: t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device) diff --git a/torchft/collectives.py b/torchft/collectives.py index 45a8501..513a2d9 100644 --- a/torchft/collectives.py +++ b/torchft/collectives.py @@ -115,12 +115,12 @@ def allocate_reduce_scatter_output( device = tensors[0].device dtype = tensors[0].dtype for i in range(1, len(tensors)): - assert tensors[i].device == tensors[i - 1].device, ( - "All inputs must be on the same device" - ) - assert tensors[i].dtype == tensors[i - 1].dtype, ( - "All inputs must be on the same dtype" - ) + assert ( + tensors[i].device == tensors[i - 1].device + ), "All inputs must be on the same device" + assert ( + tensors[i].dtype == tensors[i - 1].dtype + ), "All inputs must be on the same dtype" padded_sizes = get_padded_sizes(tensors, world_size) diff --git a/torchft/futures.py b/torchft/futures.py index b2f183a..ec1936f 100644 --- a/torchft/futures.py +++ b/torchft/futures.py @@ -262,9 +262,7 @@ def _clear_del_queue(self) -> int: assert ( # 1 from item, 1 from getrefcount refcount == 2 - ), ( - f"items in del_queue reference should not have other references, found {refcount=}" - ) + ), f"items in del_queue reference should not have other references, found {refcount=}" del item count += 1 diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py index 9f72ede..8d1a4a2 100644 --- a/torchft/local_sgd.py +++ b/torchft/local_sgd.py @@ -615,9 +615,9 @@ def __init__( """ if isinstance(outer_optimizer, list): - assert len(outer_optimizer) == len(model_fragments), ( - "The number of outer optimizers must match the number of model fragments" - ) + assert len(outer_optimizer) == len( + model_fragments + ), "The number of outer optimizers must match the number of model fragments" if manager._use_async_quorum: raise ValueError( @@ -790,6 +790,6 @@ def _step_post_hook( self._local_step = 0 return - assert False, ( - f"{self._local_step=} should never be greater than {self._sync_every=}" - ) + assert ( + False + ), f"{self._local_step=} should never be greater than {self._sync_every=}" diff --git a/torchft/manager.py b/torchft/manager.py index 8374542..318cb23 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -621,9 +621,9 @@ def wait_quorum(self) -> None: ProcessGroup will be in a healthy state after this returns. """ - assert self._quorum_future is not None, ( - "must call start_quorum before wait_quorum" - ) + assert ( + self._quorum_future is not None + ), "must call start_quorum before wait_quorum" self._quorum_future.result() @torch.profiler.record_function("torchft::manager::_async_quorum") @@ -730,7 +730,14 @@ def _async_quorum( f"resetting fr recording for quorum id {self._quorum_id}" ) self._update_fr_path() - torch._C._distributed_c10d._reset_fr_recording_nccl() # pyre-ignore + # Only reset FR recording if available (requires NCCL Flight Recorder support) + if hasattr(torch._C._distributed_c10d, "_reset_fr_recording_nccl"): + torch._C._distributed_c10d._reset_fr_recording_nccl() # pyre-ignore + else: + self._logger.warn( + "Unable to reset NCCL flight recorder recording so traces will be " + "incorrect." + ) except Exception as e: self._logger.exception(f"got exception in pg configure: {e}") self.report_error(e) @@ -769,9 +776,9 @@ def _async_quorum( self._group_rank, timeout=self._timeout ) recover_src_replica_rank = quorum.recover_src_replica_rank - assert recover_src_replica_rank is not None, ( - "must have a recover rank when healing" - ) + assert ( + recover_src_replica_rank is not None + ), "must have a recover rank when healing" self._logger.info( f"fetching checkpoint from {recover_src_replica_rank=} with {checkpoint_metadata=}" @@ -830,9 +837,9 @@ def _apply_pending_state_dict(self) -> None: else: self._logger.info("applying pending state dict") - assert len(self._load_state_dict_fns) > 0, ( - "user load_state_dict is not initialized." - ) + assert ( + len(self._load_state_dict_fns) > 0 + ), "user load_state_dict is not initialized." pending_user_state_dict = cast( Dict[str, object], pending_state_dict["user"] @@ -949,9 +956,9 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None: def _manager_state_dict(self) -> Dict[str, object]: with self._state_dict_lock.r_lock(): - assert len(self._user_state_dicts) > 0, ( - "user state_dict is not initialized." - ) + assert ( + len(self._user_state_dicts) > 0 + ), "user state_dict is not initialized." return { "user": {key: value() for key, value in self._user_state_dicts.items()}, "torchft": self.state_dict(), diff --git a/torchft/quantization.py b/torchft/quantization.py index 4c9d72e..2bcdf4e 100644 --- a/torchft/quantization.py +++ b/torchft/quantization.py @@ -464,12 +464,12 @@ def _prepare_quantize_fp8( device = inputs[0].device dtype = inputs[0].dtype for i in range(1, i_num): - assert inputs[i].device == inputs[i - 1].device, ( - "All inputs must be on the same device" - ) - assert inputs[i].dtype == inputs[i - 1].dtype, ( - "All inputs must be on the same dtype" - ) + assert ( + inputs[i].device == inputs[i - 1].device + ), "All inputs must be on the same device" + assert ( + inputs[i].dtype == inputs[i - 1].dtype + ), "All inputs must be on the same dtype" assert dtype in [ torch.float32,