Skip to content

Commit c70834d

Browse files
committed
fix: stop setting device_id for compatibility
(cherry picked from commit 54ff390c9a04418b3c123d8e7a2037ecdb42d8ea)
1 parent df74927 commit c70834d

File tree

2 files changed

+0
-8
lines changed

2 files changed

+0
-8
lines changed

trainer/src/agentrl/trainer/components/nccl_tensor_comm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def __init__(self, worker: AbstractTrainWorker, addr, port, world_size):
4242
world_size=world_size,
4343
rank=0,
4444
group_name=f"nccl_comm_{addr}_{port}",
45-
device_id=torch.device(torch.cuda.current_device()),
4645
)
4746

4847
def send(self, bucket_size):
@@ -85,7 +84,6 @@ def __init__(self, worker: AbstractAsyncRolloutWorker, addr, port, world_size, o
8584
world_size=world_size,
8685
rank=offset + worker.rank,
8786
group_name=f"nccl_comm_{addr}_{port}",
88-
device_id=torch.device("cuda:0"),
8987
)
9088

9189
async def async_receive(self):

trainer/src/agentrl/trainer/workers/fsdp_worker.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,6 @@ class FSDPWorker(AbstractTrainWorker):
8282

8383
def __init__(self, config):
8484
super().__init__()
85-
if "CUDA_VISIBLE_DEVICES" in os.environ:
86-
device = os.environ.pop("CUDA_VISIBLE_DEVICES")
87-
os.environ["LOCAL_RANK"] = device
88-
else:
89-
device = os.environ["LOCAL_RANK"]
90-
torch.cuda.set_device(f"cuda:{device}")
9185
self.config = config
9286

9387
def init_distributed(self, addr, port):

0 commit comments

Comments (0)