diff --git a/src/mistral_inference/main.py b/src/mistral_inference/main.py
index e27e0b31..741137f1 100644
--- a/src/mistral_inference/main.py
+++ b/src/mistral_inference/main.py
@@ -109,7 +109,9 @@ def interactive(
 ) -> None:
     if is_torchrun():
         torch.distributed.init_process_group()
-        torch.cuda.set_device(torch.distributed.get_rank())
+        # Use LOCAL_RANK for correct per-process GPU selection (fall back to the global rank if unset)
+        local_rank = int(os.environ.get("LOCAL_RANK", torch.distributed.get_rank()))
+        torch.cuda.set_device(local_rank)
         should_print = torch.distributed.get_rank() == 0
 
         num_pipeline_ranks = torch.distributed.get_world_size()
@@ -164,10 +166,26 @@ def interactive(
             images = []
 
         if is_torchrun():
+            # Ensure the tensor device matches the distributed backend:
+            # NCCL requires CUDA tensors for collectives, so move the metadata to the GPU.
+            backend = dist.get_backend()
+            if backend == "nccl":
+                # Use the same LOCAL_RANK as selected earlier
+                local_rank = int(os.environ.get("LOCAL_RANK", torch.distributed.get_rank()))
+                if not torch.cuda.is_available():
+                    raise RuntimeError("NCCL backend selected but CUDA is not available on this process")
+                device = torch.device(f"cuda:{local_rank}")
+                length_tensor = length_tensor.to(device)
             dist.broadcast(length_tensor, src=0)
 
+            # Convert the broadcast tensor back to a plain Python int for downstream use.
+            if length_tensor.device.type == "cuda":
+                length_value = length_tensor.cpu().item()
+            else:
+                length_value = length_tensor.item()
+
         if not should_print:
-            tokens = int(length_tensor.item()) * [0]
+            tokens = int(length_value) * [0]
 
         generate_fn = generate if isinstance(model, Transformer) else generate_mamba
         generated_tokens, _ = generate_fn(  # type: ignore[operator]
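
For review context only, here is a minimal standalone sketch of the pattern the patch applies: pin each process to its `LOCAL_RANK` GPU and broadcast a small metadata tensor on a device the active backend accepts. It assumes a `torchrun` launch with `LOCAL_RANK` set; the script name, the `broadcast_length` helper, and the example length value are illustrative and not part of the repository.

```python
# Sketch: pin each process to its LOCAL_RANK GPU, then broadcast a small
# metadata tensor on a device the active backend can handle
# (NCCL needs CUDA tensors; gloo works with CPU tensors).
# Run with e.g.: torchrun --nproc_per_node=2 broadcast_length_sketch.py
import os

import torch
import torch.distributed as dist


def broadcast_length(length: int) -> int:
    """Broadcast a single integer from rank 0 to all ranks and return it."""
    backend = dist.get_backend()
    local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))

    # Pick a tensor device compatible with the backend's collectives.
    if backend == "nccl":
        if not torch.cuda.is_available():
            raise RuntimeError("NCCL backend selected but CUDA is not available")
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")

    length_tensor = torch.tensor([length], dtype=torch.int, device=device)
    dist.broadcast(length_tensor, src=0)

    # .item() returns a Python number regardless of the tensor's device.
    return int(length_tensor.item())


if __name__ == "__main__":
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
    local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    # Rank 0 knows the real length; the other ranks pass a placeholder.
    value = 42 if dist.get_rank() == 0 else 0
    print(f"rank {dist.get_rank()} received length {broadcast_length(value)}")

    dist.destroy_process_group()
```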