diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index fb17ee1661..e4ba81f8f9 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -220,7 +220,9 @@ def backward( seq_size = int(vocab_parallel_logits.shape[1]) num_chunks = (seq_size + chunk_size - 1) // chunk_size - all_grad_input = [] + grad_input: torch.Tensor = torch.empty_like( + vocab_parallel_logits, dtype=torch.float32 + ) for chunk_idx in range(num_chunks): chunk_start = chunk_idx * chunk_size @@ -243,13 +245,18 @@ def backward( num_classes=partition_vocab_size, ) - grad_input = is_chosen.float().sub_(softmax_output) - - grad_input.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1)) + # Inplace index into the preallocated grad_input tensor + grad_input_chunk = grad_input[:, chunk_start:chunk_end, :] - all_grad_input.append(grad_input) + grad_input_chunk.copy_( + is_chosen.float().sub_(softmax_output) + ) # inplace copy + grad_input_chunk.mul_( + grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1) + ) - grad_input = torch.cat(all_grad_input, dim=1) + # Explicitly free before next iteration allocates + del softmax_output, is_chosen, logits # if you add an argument to the forward method, then you must add a corresponding None here return grad_input, None, None, None, None, None, None @@ -326,7 +333,10 @@ def backward( B, S, V_local = vocab_parallel_logits.shape num_chunks = (int(S) + chunk_size - 1) // chunk_size - all_grad_input: list[torch.Tensor] = [] + + grad_input: torch.Tensor = torch.empty_like( + vocab_parallel_logits, dtype=torch.float32 + ) for chunk_idx in range(num_chunks): s0 = chunk_idx * chunk_size @@ -344,41 +354,29 @@ def backward( go_chunk = grad_output[:, s0:s1, :] # [B, Sc, K] go_sum = go_chunk.sum(dim=-1, keepdim=True) # [B, Sc, 1] - grad_input = softmax_output.neg() - grad_input = grad_input.mul_(go_sum) + # Inplace index into the preallocated grad_input tensor + grad_input_chunk = grad_input[:, s0:s1, :] + + grad_input_chunk.copy_(softmax_output.neg().mul_(go_sum)) # inplace copy # Positive scatter term: add gradients to selected indices - # Mask grad_output for indices not on this shard go_masked = go_chunk * in_range.to(dtype=go_chunk.dtype) - # Flatten for scatter_add - flat_grad = grad_input.view(-1) - # compute flattened indices positions - Bc, Sc = go_masked.shape[0], go_masked.shape[1] - # row offset per [B, Sc] - row = ( - torch.arange(Bc, device=grad_input.device) - .view(-1, 1) - .expand(-1, Sc) - .reshape(-1) + grad_input_chunk.scatter_add_(2, li, go_masked) + + # Explicitly free before next iteration allocates + del ( + softmax_output, + log_probs, + logits, + gi, + in_range, + li, + go_chunk, + go_sum, + go_masked, ) - col = torch.arange(Sc, device=grad_input.device).expand(Bc, -1).reshape(-1) - flat_idx_base = (row * Sc + col) * V_local # [Bc*Sc] - # selected flat indices - flat_li = li.reshape(-1, li.shape[-1]) # [Bc*Sc, K] - flat_base_expanded = flat_idx_base.unsqueeze(-1).expand_as(flat_li) - flat_chosen = (flat_base_expanded + flat_li).reshape(-1) - flat_go = go_masked.reshape(-1) - flat_grad.scatter_add_(0, flat_chosen, flat_go) - - all_grad_input.append(grad_input) - - grad_input_total = ( - torch.cat(all_grad_input, dim=1) - if len(all_grad_input) > 1 - else all_grad_input[0] - ) - return grad_input_total, None, None, None, None, None, None + return grad_input, None, None, None, None, None, None def dtensor_from_parallel_logits_to_logprobs( @@ -1033,7 +1031,10 @@ def backward( B, S, V_local = vocab_parallel_logits.shape num_chunks = (int(S) + chunk_size - 1) // chunk_size - grads: list[torch.Tensor] = [] + + grad_input: torch.Tensor = torch.empty_like( + vocab_parallel_logits, dtype=torch.float32 + ) for chunk_idx in range(num_chunks): s0 = chunk_idx * chunk_size @@ -1047,10 +1048,16 @@ def backward( H_local, op=torch.distributed.ReduceOp.SUM, group=tp_group ) + # Inplace index into the preallocated grad_input tensor + grad_input_chunk = grad_input[:, s0:s1, :] + # dH/dz = softmax * (log_probs - H_all) - grad_chunk = softmax_output * (log_probs - H_local.unsqueeze(-1)) - grad_chunk.mul_(grad_output[:, s0:s1].unsqueeze(-1)) - grads.append(grad_chunk) + grad_input_chunk.copy_( + softmax_output.mul_(log_probs - H_local.unsqueeze(-1)) + ) # inplace copy + grad_input_chunk.mul_(grad_output[:, s0:s1].unsqueeze(-1)) + + # Explicitly free before next iteration allocates + del softmax_output, log_probs, logits, H_local - grad_input = torch.cat(grads, dim=1) if len(grads) > 1 else grads[0] return grad_input, None, None, None diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py index 8637ad22fe..2579121ec8 100644 --- a/tests/unit/distributed/test_model_utils.py +++ b/tests/unit/distributed/test_model_utils.py @@ -18,6 +18,7 @@ import torch from nemo_rl.distributed.model_utils import ( + ChunkedDistributedEntropy, ChunkedDistributedGatherLogprob, ChunkedDistributedLogprob, DistributedLogprob, @@ -705,14 +706,25 @@ def test_distributed_logprob_forward_and_backward(self): full_logits.grad = None # Compute distributed gradients - distributed_log_probs = DistributedLogprob.apply( - vocab_parallel_logits, - target, - vocab_start_index, - vocab_end_index, - self.tp_group, - False, # inference_only=False to enable backward - ) + if chunk_size is not None: + distributed_log_probs = ChunkedDistributedLogprob.apply( + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + chunk_size, + self.tp_group, + False, # inference_only=False to enable backward + ) + else: + distributed_log_probs = DistributedLogprob.apply( + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + self.tp_group, + False, # inference_only=False to enable backward + ) distributed_loss = torch.sum(distributed_log_probs) distributed_loss.backward() @@ -956,3 +968,165 @@ def test_distributed_logprob_all_tests( finally: cluster.shutdown() + + +@ray.remote(num_gpus=1) +class ChunkedDistributedEntropyTestActor: + def __init__(self, tp_size, chunk_size, inference_only): + self.tp_size = tp_size + self.chunk_size = chunk_size + self.inference_only = inference_only + self.env_vars = dict(os.environ) + + def test_chunked_distributed_entropy(self): + torch.distributed.init_process_group(backend="nccl") + + rank = int(os.environ["RANK"]) + tp_group = torch.distributed.new_group(ranks=list(range(self.tp_size))) + + batch_size = 2 + seq_len = 16 + vocab_size = 256 + vocab_part_size = vocab_size // self.tp_size + vocab_start_index = rank * vocab_part_size + vocab_end_index = (rank + 1) * vocab_part_size + + torch.manual_seed(1337) + full_logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda") + + # === Baseline: single-GPU entropy === + baseline_logits = ( + full_logits.clone().detach().requires_grad_(not self.inference_only) + ) + baseline_log_probs = torch.nn.functional.log_softmax(baseline_logits, dim=-1) + baseline_probs = baseline_log_probs.exp() + # H = sum_v p_v * log(p_v) (negative entropy; matches ChunkedDistributedEntropy) + baseline_entropy = (baseline_probs * baseline_log_probs).sum(dim=-1) # [B, S] + + if not self.inference_only: + baseline_entropy.sum().backward() + baseline_grad = baseline_logits.grad[ + :, :, vocab_start_index:vocab_end_index + ].clone() + + # === Distributed: ChunkedDistributedEntropy === + local_logits = full_logits[:, :, vocab_start_index:vocab_end_index] + local_logits = ( + local_logits.clone().detach().requires_grad_(not self.inference_only) + ) + + distributed_entropy = ChunkedDistributedEntropy.apply( + local_logits, + self.chunk_size, + tp_group, + self.inference_only, + ) + + # Forward check + torch.testing.assert_close( + distributed_entropy, baseline_entropy, rtol=1e-4, atol=1e-4 + ) + forward_diff = torch.max( + torch.abs(distributed_entropy - baseline_entropy) + ).item() + + # Backward check + if not self.inference_only: + distributed_entropy.sum().backward() + grad_local = local_logits.grad + torch.testing.assert_close(grad_local, baseline_grad, rtol=1e-4, atol=1e-4) + grad_diff = torch.max(torch.abs(grad_local - baseline_grad)).item() + else: + grad_diff = None + + return { + "forward_max_diff": forward_diff, + "grad_max_diff": grad_diff, + } + + +CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN = f"{ChunkedDistributedEntropyTestActor.__module__}.ChunkedDistributedEntropyTestActor" + + +@pytest.fixture +def register_chunked_distributed_entropy_test_actor(): + original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get( + CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN + ) + ACTOR_ENVIRONMENT_REGISTRY[CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN] = ( + PY_EXECUTABLES.SYSTEM + ) + + yield CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN + + if CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY: + if original_registry_value is None: + del ACTOR_ENVIRONMENT_REGISTRY[CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN] + else: + ACTOR_ENVIRONMENT_REGISTRY[CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN] = ( + original_registry_value + ) + + +@pytest.mark.parametrize( + "tp_size, chunk_size, inference_only", + [ + (1, 5, False), + (2, 4, False), + (1, 3, True), + ], +) +def test_chunked_distributed_entropy( + register_chunked_distributed_entropy_test_actor, + tp_size, + chunk_size, + inference_only, +): + world_size = tp_size + + if not torch.cuda.is_available() or torch.cuda.device_count() < world_size: + pytest.skip( + f"Not enough GPUs available. Need {world_size}, " + f"got {torch.cuda.device_count()}" + ) + + cluster = RayVirtualCluster(bundle_ct_per_node_list=[world_size], use_gpus=True) + + try: + actor_fqn = register_chunked_distributed_entropy_test_actor + + sharding = NamedSharding( + layout=np.arange(world_size).reshape(tp_size), names=["tp"] + ) + builder = RayWorkerBuilder(actor_fqn, tp_size, chunk_size, inference_only) + + worker_group = RayWorkerGroup( + cluster=cluster, + remote_worker_builder=builder, + workers_per_node=None, + sharding_annotations=sharding, + ) + + futures = worker_group.run_all_workers_single_data( + "test_chunked_distributed_entropy" + ) + results = ray.get(futures) + + for i, result in enumerate(results): + print(f"Worker {i} forward max diff: {result['forward_max_diff']:.2e}") + assert result["forward_max_diff"] < 1e-4, ( + f"Worker {i} forward diff too large: {result['forward_max_diff']}" + ) + if not inference_only: + print(f"Worker {i} grad max diff: {result['grad_max_diff']:.2e}") + assert ( + result["grad_max_diff"] is not None + and result["grad_max_diff"] < 1e-4 + ), f"Worker {i} grad diff too large: {result['grad_max_diff']}" + else: + assert result["grad_max_diff"] is None + + worker_group.shutdown(force=True) + + finally: + cluster.shutdown()