Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 49 additions & 42 deletions nemo_rl/distributed/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ def backward(
seq_size = int(vocab_parallel_logits.shape[1])
num_chunks = (seq_size + chunk_size - 1) // chunk_size

all_grad_input = []
grad_input: torch.Tensor = torch.empty_like(
vocab_parallel_logits, dtype=torch.float32
)

for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * chunk_size
Expand All @@ -243,13 +245,18 @@ def backward(
num_classes=partition_vocab_size,
)

grad_input = is_chosen.float().sub_(softmax_output)

grad_input.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1))
# Inplace index into the preallocated grad_input tensor
grad_input_chunk = grad_input[:, chunk_start:chunk_end, :]

all_grad_input.append(grad_input)
grad_input_chunk.copy_(
is_chosen.float().sub_(softmax_output)
) # inplace copy
grad_input_chunk.mul_(
grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1)
)

grad_input = torch.cat(all_grad_input, dim=1)
# Explicitly free before next iteration allocates
del softmax_output, is_chosen, logits

# if you add an argument to the forward method, then you must add a corresponding None here
return grad_input, None, None, None, None, None, None
Expand Down Expand Up @@ -326,7 +333,10 @@ def backward(

B, S, V_local = vocab_parallel_logits.shape
num_chunks = (int(S) + chunk_size - 1) // chunk_size
all_grad_input: list[torch.Tensor] = []

grad_input: torch.Tensor = torch.empty_like(
vocab_parallel_logits, dtype=torch.float32
)

for chunk_idx in range(num_chunks):
s0 = chunk_idx * chunk_size
Expand All @@ -344,41 +354,29 @@ def backward(
go_chunk = grad_output[:, s0:s1, :] # [B, Sc, K]
go_sum = go_chunk.sum(dim=-1, keepdim=True) # [B, Sc, 1]

grad_input = softmax_output.neg()
grad_input = grad_input.mul_(go_sum)
# Inplace index into the preallocated grad_input tensor
grad_input_chunk = grad_input[:, s0:s1, :]

grad_input_chunk.copy_(softmax_output.neg().mul_(go_sum)) # inplace copy

# Positive scatter term: add gradients to selected indices
# Mask grad_output for indices not on this shard
go_masked = go_chunk * in_range.to(dtype=go_chunk.dtype)
# Flatten for scatter_add
flat_grad = grad_input.view(-1)
# compute flattened indices positions
Bc, Sc = go_masked.shape[0], go_masked.shape[1]
# row offset per [B, Sc]
row = (
torch.arange(Bc, device=grad_input.device)
.view(-1, 1)
.expand(-1, Sc)
.reshape(-1)
grad_input_chunk.scatter_add_(2, li, go_masked)

# Explicitly free before next iteration allocates
del (
softmax_output,
log_probs,
logits,
gi,
in_range,
li,
go_chunk,
go_sum,
go_masked,
)
col = torch.arange(Sc, device=grad_input.device).expand(Bc, -1).reshape(-1)
flat_idx_base = (row * Sc + col) * V_local # [Bc*Sc]
# selected flat indices
flat_li = li.reshape(-1, li.shape[-1]) # [Bc*Sc, K]
flat_base_expanded = flat_idx_base.unsqueeze(-1).expand_as(flat_li)
flat_chosen = (flat_base_expanded + flat_li).reshape(-1)
flat_go = go_masked.reshape(-1)
flat_grad.scatter_add_(0, flat_chosen, flat_go)

all_grad_input.append(grad_input)

grad_input_total = (
torch.cat(all_grad_input, dim=1)
if len(all_grad_input) > 1
else all_grad_input[0]
)

return grad_input_total, None, None, None, None, None, None
return grad_input, None, None, None, None, None, None


def dtensor_from_parallel_logits_to_logprobs(
Expand Down Expand Up @@ -1033,7 +1031,10 @@ def backward(

B, S, V_local = vocab_parallel_logits.shape
num_chunks = (int(S) + chunk_size - 1) // chunk_size
grads: list[torch.Tensor] = []

grad_input: torch.Tensor = torch.empty_like(
vocab_parallel_logits, dtype=torch.float32
)

for chunk_idx in range(num_chunks):
s0 = chunk_idx * chunk_size
Expand All @@ -1047,10 +1048,16 @@ def backward(
H_local, op=torch.distributed.ReduceOp.SUM, group=tp_group
)

# Inplace index into the preallocated grad_input tensor
grad_input_chunk = grad_input[:, s0:s1, :]

# dH/dz = softmax * (log_probs - H_all)
grad_chunk = softmax_output * (log_probs - H_local.unsqueeze(-1))
grad_chunk.mul_(grad_output[:, s0:s1].unsqueeze(-1))
grads.append(grad_chunk)
grad_input_chunk.copy_(
softmax_output.mul_(log_probs - H_local.unsqueeze(-1))
) # inplace copy
grad_input_chunk.mul_(grad_output[:, s0:s1].unsqueeze(-1))

# Explicitly free before next iteration allocates
del softmax_output, log_probs, logits, H_local

grad_input = torch.cat(grads, dim=1) if len(grads) > 1 else grads[0]
return grad_input, None, None, None
190 changes: 182 additions & 8 deletions tests/unit/distributed/test_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch

from nemo_rl.distributed.model_utils import (
ChunkedDistributedEntropy,
ChunkedDistributedGatherLogprob,
ChunkedDistributedLogprob,
DistributedLogprob,
Expand Down Expand Up @@ -705,14 +706,25 @@ def test_distributed_logprob_forward_and_backward(self):
full_logits.grad = None

# Compute distributed gradients
distributed_log_probs = DistributedLogprob.apply(
vocab_parallel_logits,
target,
vocab_start_index,
vocab_end_index,
self.tp_group,
False, # inference_only=False to enable backward
)
if chunk_size is not None:
distributed_log_probs = ChunkedDistributedLogprob.apply(
vocab_parallel_logits,
target,
vocab_start_index,
vocab_end_index,
chunk_size,
self.tp_group,
False, # inference_only=False to enable backward
)
else:
distributed_log_probs = DistributedLogprob.apply(
vocab_parallel_logits,
target,
vocab_start_index,
vocab_end_index,
self.tp_group,
False, # inference_only=False to enable backward
)

distributed_loss = torch.sum(distributed_log_probs)
distributed_loss.backward()
Expand Down Expand Up @@ -956,3 +968,165 @@ def test_distributed_logprob_all_tests(

finally:
cluster.shutdown()


@ray.remote(num_gpus=1)
class ChunkedDistributedEntropyTestActor:
    """Single-GPU Ray actor that validates ``ChunkedDistributedEntropy``.

    Each actor owns one rank of a tensor-parallel group and compares the
    distributed, chunked entropy computation (forward and, optionally,
    backward) against a full-vocabulary single-GPU baseline built with
    ``log_softmax``.
    """

    def __init__(self, tp_size, chunk_size, inference_only):
        # tp_size: tensor-parallel world size (vocab is sharded across it).
        # chunk_size: sequence-dimension chunk size forwarded to the op.
        # inference_only: when True, skip the backward/gradient comparison.
        self.tp_size = tp_size
        self.chunk_size = chunk_size
        self.inference_only = inference_only
        # Snapshot of the environment at construction time (not used below;
        # presumably kept for debugging — NOTE(review): confirm it is needed).
        self.env_vars = dict(os.environ)

    def test_chunked_distributed_entropy(self):
        """Run forward/backward parity checks on this rank.

        Returns:
            dict with ``forward_max_diff`` (float) and ``grad_max_diff``
            (float, or ``None`` when ``inference_only`` is set).
        """
        torch.distributed.init_process_group(backend="nccl")

        rank = int(os.environ["RANK"])
        tp_group = torch.distributed.new_group(ranks=list(range(self.tp_size)))

        batch_size = 2
        seq_len = 16
        vocab_size = 256
        # Each rank owns a contiguous [vocab_start_index, vocab_end_index)
        # slice of the vocabulary dimension.
        vocab_part_size = vocab_size // self.tp_size
        vocab_start_index = rank * vocab_part_size
        vocab_end_index = (rank + 1) * vocab_part_size

        # Fixed seed so every rank generates the SAME full logits tensor;
        # each rank then slices out only its own vocab shard below.
        torch.manual_seed(1337)
        full_logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda")

        # === Baseline: single-GPU entropy ===
        baseline_logits = (
            full_logits.clone().detach().requires_grad_(not self.inference_only)
        )
        baseline_log_probs = torch.nn.functional.log_softmax(baseline_logits, dim=-1)
        baseline_probs = baseline_log_probs.exp()
        # H = sum_v p_v * log(p_v) (negative entropy; matches ChunkedDistributedEntropy)
        baseline_entropy = (baseline_probs * baseline_log_probs).sum(dim=-1)  # [B, S]

        if not self.inference_only:
            baseline_entropy.sum().backward()
            # Keep only this rank's vocab shard of the baseline gradient so it
            # is directly comparable with the local (sharded) gradient below.
            baseline_grad = baseline_logits.grad[
                :, :, vocab_start_index:vocab_end_index
            ].clone()

        # === Distributed: ChunkedDistributedEntropy ===
        local_logits = full_logits[:, :, vocab_start_index:vocab_end_index]
        local_logits = (
            local_logits.clone().detach().requires_grad_(not self.inference_only)
        )

        distributed_entropy = ChunkedDistributedEntropy.apply(
            local_logits,
            self.chunk_size,
            tp_group,
            self.inference_only,
        )

        # Forward check: distributed result must match the full-vocab baseline.
        torch.testing.assert_close(
            distributed_entropy, baseline_entropy, rtol=1e-4, atol=1e-4
        )
        forward_diff = torch.max(
            torch.abs(distributed_entropy - baseline_entropy)
        ).item()

        # Backward check: gradient w.r.t. this rank's logit shard must match
        # the corresponding shard of the baseline gradient.
        if not self.inference_only:
            distributed_entropy.sum().backward()
            grad_local = local_logits.grad
            torch.testing.assert_close(grad_local, baseline_grad, rtol=1e-4, atol=1e-4)
            grad_diff = torch.max(torch.abs(grad_local - baseline_grad)).item()
        else:
            grad_diff = None

        return {
            "forward_max_diff": forward_diff,
            "grad_max_diff": grad_diff,
        }


CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN = f"{ChunkedDistributedEntropyTestActor.__module__}.ChunkedDistributedEntropyTestActor"


@pytest.fixture
def register_chunked_distributed_entropy_test_actor():
    """Register the entropy test actor's Python environment for one test.

    Points the actor's registry entry at the system Python executable,
    yields the actor's fully-qualified name, and afterwards restores the
    registry to exactly its pre-test state (re-installing any previous
    value, or deleting the entry if none existed).
    """
    fqn = CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN
    saved_value = ACTOR_ENVIRONMENT_REGISTRY.get(fqn)
    ACTOR_ENVIRONMENT_REGISTRY[fqn] = PY_EXECUTABLES.SYSTEM

    yield fqn

    # Teardown: undo our registration without clobbering unrelated changes.
    if fqn in ACTOR_ENVIRONMENT_REGISTRY:
        if saved_value is None:
            del ACTOR_ENVIRONMENT_REGISTRY[fqn]
        else:
            ACTOR_ENVIRONMENT_REGISTRY[fqn] = saved_value


@pytest.mark.parametrize(
    "tp_size, chunk_size, inference_only",
    [
        (1, 5, False),
        (2, 4, False),
        (1, 3, True),
    ],
)
def test_chunked_distributed_entropy(
    register_chunked_distributed_entropy_test_actor,
    tp_size,
    chunk_size,
    inference_only,
):
    """End-to-end check of ChunkedDistributedEntropy across a TP group.

    Spins up one Ray GPU worker per tensor-parallel rank, runs the actor's
    forward/backward parity test on every worker, and asserts the reported
    max differences are within tolerance. Skips when fewer GPUs are
    available than ``tp_size``.
    """
    world_size = tp_size

    if not torch.cuda.is_available() or torch.cuda.device_count() < world_size:
        pytest.skip(
            f"Not enough GPUs available. Need {world_size}, "
            f"got {torch.cuda.device_count()}"
        )

    cluster = RayVirtualCluster(bundle_ct_per_node_list=[world_size], use_gpus=True)

    try:
        actor_fqn = register_chunked_distributed_entropy_test_actor

        # One-dimensional "tp" sharding: worker i handles TP rank i.
        sharding = NamedSharding(
            layout=np.arange(world_size).reshape(tp_size), names=["tp"]
        )
        builder = RayWorkerBuilder(actor_fqn, tp_size, chunk_size, inference_only)

        worker_group = RayWorkerGroup(
            cluster=cluster,
            remote_worker_builder=builder,
            workers_per_node=None,
            sharding_annotations=sharding,
        )

        try:
            futures = worker_group.run_all_workers_single_data(
                "test_chunked_distributed_entropy"
            )
            results = ray.get(futures)

            for i, result in enumerate(results):
                print(f"Worker {i} forward max diff: {result['forward_max_diff']:.2e}")
                assert result["forward_max_diff"] < 1e-4, (
                    f"Worker {i} forward diff too large: {result['forward_max_diff']}"
                )
                if not inference_only:
                    # Backward was exercised: the gradient diff must exist and
                    # be within tolerance.
                    print(f"Worker {i} grad max diff: {result['grad_max_diff']:.2e}")
                    assert (
                        result["grad_max_diff"] is not None
                        and result["grad_max_diff"] < 1e-4
                    ), f"Worker {i} grad diff too large: {result['grad_max_diff']}"
                else:
                    # inference_only workers skip backward entirely.
                    assert result["grad_max_diff"] is None
        finally:
            # Fix: previously this shutdown only ran on the success path, so a
            # failing assertion leaked live workers until cluster.shutdown().
            worker_group.shutdown(force=True)

    finally:
        cluster.shutdown()
Loading