src/mistral_inference/main.py (22 changes: 20 additions, 2 deletions)
@@ -109,7 +109,9 @@ def interactive(
 ) -> None:
     if is_torchrun():
         torch.distributed.init_process_group()
-        torch.cuda.set_device(torch.distributed.get_rank())
+        # Use LOCAL_RANK for correct per-process GPU selection (fall back to the global rank if it is not set)
+        local_rank = int(os.environ.get("LOCAL_RANK", torch.distributed.get_rank()))
+        torch.cuda.set_device(local_rank)
         should_print = torch.distributed.get_rank() == 0

         num_pipeline_ranks = torch.distributed.get_world_size()
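
Why LOCAL_RANK rather than the global rank: with torchrun, the global rank keeps counting across nodes while CUDA device ordinals restart at 0 on every node, so selecting the GPU by global rank only works on the first node. A minimal sketch of the distinction, assuming a hypothetical 2-node job with 4 GPUs per node (not part of this PR):

import os

import torch
import torch.distributed as dist

dist.init_process_group()          # torchrun supplies RANK, WORLD_SIZE, MASTER_ADDR/PORT
global_rank = dist.get_rank()      # 0..7 across both nodes
local_rank = int(os.environ.get("LOCAL_RANK", global_rank))  # 0..3 on each node

# On the second node, global_rank is 4..7 but only cuda:0..cuda:3 exist, so
# selecting the device by global rank would point at a GPU that does not exist.
torch.cuda.set_device(local_rank)
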
@@ -164,10 +166,26 @@ def interactive(
             images = []

         if is_torchrun():
+            # Ensure tensor device matches the distributed backend:
+            # NCCL requires CUDA tensors for collectives, so move the metadata to the GPU.
+            backend = dist.get_backend()
+            if backend == "nccl":
+                # Use the same LOCAL_RANK set earlier
+                local_rank = int(os.environ.get("LOCAL_RANK", torch.distributed.get_rank()))
+                if not torch.cuda.is_available():
+                    raise RuntimeError("NCCL backend selected but CUDA is not available on this process")
+                device = torch.device(f"cuda:{local_rank}")
+                length_tensor = length_tensor.to(device)
             dist.broadcast(length_tensor, src=0)

+        # Convert the broadcast tensor back to a CPU Python int for downstream use.
+        if length_tensor.device.type == "cuda":
+            length_value = length_tensor.cpu().item()
+        else:
+            length_value = length_tensor.item()
+
         if not should_print:
-            tokens = int(length_tensor.item()) * [0]
+            tokens = int(length_value) * [0]

         generate_fn = generate if isinstance(model, Transformer) else generate_mamba
         generated_tokens, _ = generate_fn(  # type: ignore[operator]
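
For reference, a minimal standalone sketch of the broadcast pattern the second hunk implements: match the tensor device to the backend, broadcast, then read the value back as a plain Python int. The file name, the example length of 42, and the torchrun invocation are illustrative only and not part of this PR.

# broadcast_sketch.py (hypothetical) -- launch with, e.g.:
#   torchrun --nproc_per_node=2 broadcast_sketch.py
import os

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))

# Rank 0 knows the real prompt length; the other ranks start with a placeholder.
length_tensor = torch.tensor([42 if dist.get_rank() == 0 else 0], dtype=torch.int)

# NCCL collectives require CUDA tensors, so move the metadata to this rank's GPU first.
if dist.get_backend() == "nccl":
    torch.cuda.set_device(local_rank)
    length_tensor = length_tensor.to(torch.device(f"cuda:{local_rank}"))

dist.broadcast(length_tensor, src=0)

# Bring the value back to a plain Python int regardless of where the tensor lives.
length_value = int(length_tensor.cpu().item()) if length_tensor.device.type == "cuda" else int(length_tensor.item())
print(f"rank {dist.get_rank()}: length = {length_value}")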