
fix: remove hardcoded cuda:0 in OTA loss for multi-GPU support#2146

Open
Mr-Neutr0n wants to merge 1 commit into WongKinYiu:main from Mr-Neutr0n:fix/ota-loss-multi-gpu-device

Conversation

@Mr-Neutr0n

Bug

The OTA loss computation hardcodes cuda:0 for tensor allocation, causing failures during multi-GPU (DDP) training when the model runs on other GPUs.

In utils/loss.py, several locations use .cuda() or device='cuda:0' when creating tensors, which forces them onto GPU 0 regardless of where the model and input data reside. This causes device mismatch errors during distributed training.
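A minimal sketch of the failure mode, assuming two GPUs are visible; the tensor names here are hypothetical and only illustrate the device mismatch described above, not the actual loss.py code:

```python
import torch

# Under DDP, each rank runs the model on its own GPU, e.g. cuda:1,
# while a hardcoded allocation always lands on cuda:0.
if torch.cuda.device_count() > 1:
    targets = torch.zeros(8, 6, device="cuda:1")            # labels follow the model's GPU
    cost = torch.zeros(targets.shape[0], device="cuda:0")   # hardcoded device

    try:
        cost += targets[:, 0]  # cross-device op
    except RuntimeError as err:
        # "Expected all tensors to be on the same device, but found at least
        #  two devices, cuda:0 and cuda:1 ..."
        print(err)
```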

Fix

Replaced hardcoded device references with dynamic device inference from input tensors:

  • .cuda() → .to(logits.device) in the RankSort, aLRPLoss, and APLoss forward methods
  • device='cuda:0' → device=targets.device in the build_targets and build_targets2 methods

This is consistent with how device handling is already done elsewhere in the same file (e.g., torch.ones(7, device=targets.device)).
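A short sketch of the resulting pattern, using hypothetical function and variable names rather than the actual utils/loss.py code:

```python
import torch

def loss_grads_sketch(logits: torch.Tensor) -> torch.Tensor:
    # Before: torch.zeros(logits.shape).cuda()   -> always lands on GPU 0
    # After:  follow the device of the input tensor.
    return torch.zeros(logits.shape).to(logits.device)

def anchor_indices_sketch(targets: torch.Tensor, na: int = 3) -> torch.Tensor:
    # Before: torch.arange(na, device='cuda:0')
    # After:  allocate alongside the targets, matching the existing
    #         torch.ones(7, device=targets.device) style in the same file.
    return torch.arange(na, device=targets.device).float().view(na, 1)
```

Following the input tensor's device keeps the loss computation agnostic to which rank the model lives on.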
