Skip to content

Commit 4e4daa7

Browse files
committed
fix: check for cpu tensor contiguity as well
Signed-off-by: Hao Lin <linhaomails@gmail.com>
1 parent 99353b9 commit 4e4daa7

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

rlinf/scheduler/collective/collective_group.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -484,9 +484,9 @@ def _atomic_send_tensor(
484484
assert object_type == CollectiveGroup.TENSOR, (
485485
"The object must be a torch.Tensor when using send_tensor"
486486
)
487-
if tensor_data.has_accel_tensor and not tensor.is_contiguous():
487+
if not tensor.is_contiguous():
488488
raise ValueError(
489-
"All CUDA tensors must be contiguous when using P2P communication. Otherwise the recv side might recv wrong tensor data. Consider using .contiguous() to make the tensors contiguous."
489+
"All tensors must be contiguous when using P2P communication. Otherwise the recv side might recv wrong tensor data. Consider using .contiguous() to make the tensors contiguous."
490490
)
491491

492492
self._init_process_group(options=options)
@@ -962,7 +962,7 @@ def _get_object_info(self, object: torch.Tensor | Any) -> tuple[int, TensorData]
962962
cpu_tensor_mask, cpu_tensors, accel_tensors = self._partition_tensors(
963963
[object]
964964
)
965-
self._check_tensor_contiguous(accel_tensors)
965+
self._check_tensor_contiguous(accel_tensors + cpu_tensors)
966966
object_type = CollectiveGroup.TENSOR
967967
tensor_data = TensorData(
968968
cpu_tensor_mask=cpu_tensor_mask,
@@ -976,7 +976,7 @@ def _get_object_info(self, object: torch.Tensor | Any) -> tuple[int, TensorData]
976976
cpu_tensor_mask, cpu_tensors, accel_tensors = self._partition_tensors(
977977
list(object)
978978
)
979-
self._check_tensor_contiguous(accel_tensors)
979+
self._check_tensor_contiguous(accel_tensors + cpu_tensors)
980980
object_type = CollectiveGroup.TENSOR_LIST
981981
tensor_data = TensorData(
982982
cpu_tensor_mask=cpu_tensor_mask,
@@ -991,7 +991,7 @@ def _get_object_info(self, object: torch.Tensor | Any) -> tuple[int, TensorData]
991991
cpu_tensor_mask, cpu_tensors, accel_tensors = self._partition_tensors(
992992
values
993993
)
994-
self._check_tensor_contiguous(accel_tensors)
994+
self._check_tensor_contiguous(accel_tensors + cpu_tensors)
995995
object_type = CollectiveGroup.TENSOR_DICT
996996
tensor_data = TensorData(
997997
cpu_tensor_mask=cpu_tensor_mask,
@@ -1009,7 +1009,7 @@ def _get_object_info(self, object: torch.Tensor | Any) -> tuple[int, TensorData]
10091009
cpu_tensors,
10101010
accel_tensors,
10111011
) = self._partition_tensors(tensors_list)
1012-
self._check_tensor_contiguous(accel_tensors)
1012+
self._check_tensor_contiguous(accel_tensors + cpu_tensors)
10131013
object_type = CollectiveGroup.DATACLASS_WITH_TENSORS
10141014
tensor_data = TensorData(
10151015
cpu_tensor_mask=cpu_tensor_mask,
@@ -1026,7 +1026,7 @@ def _check_tensor_contiguous(self, tensors: Iterable[torch.Tensor]):
10261026
"""Check if the tensors are contiguous."""
10271027
if not all(t.is_contiguous() for t in tensors):
10281028
raise ValueError(
1029-
"All CUDA/Accelerator tensors must be contiguous when using P2P communication. Otherwise the recv side might recv wrong tensor data. Consider using .contiguous() to make the tensors contiguous."
1029+
"All tensors must be contiguous when using P2P communication. Otherwise the recv side might recv wrong tensor data. Consider using .contiguous() to make the tensors contiguous."
10301030
)
10311031

10321032
def _check_same_device_with_peer(self):

0 commit comments

Comments
 (0)