Skip to content

Commit df5e4dc

Browse files
committed
Preallocate gradient tensor for other Chunked* classes
Signed-off-by: mloh <mloh@nvidia.com>
1 parent b16a486 commit df5e4dc

File tree

2 files changed

+218
-44
lines changed

2 files changed

+218
-44
lines changed

nemo_rl/distributed/model_utils.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,10 @@ def backward(
333333

334334
B, S, V_local = vocab_parallel_logits.shape
335335
num_chunks = (int(S) + chunk_size - 1) // chunk_size
336-
all_grad_input: list[torch.Tensor] = []
336+
337+
grad_input: torch.Tensor = torch.empty_like(
338+
vocab_parallel_logits, dtype=torch.float32
339+
)
337340

338341
for chunk_idx in range(num_chunks):
339342
s0 = chunk_idx * chunk_size
@@ -351,41 +354,29 @@ def backward(
351354
go_chunk = grad_output[:, s0:s1, :] # [B, Sc, K]
352355
go_sum = go_chunk.sum(dim=-1, keepdim=True) # [B, Sc, 1]
353356

354-
grad_input = softmax_output.neg()
355-
grad_input = grad_input.mul_(go_sum)
357+
# Inplace index into the preallocated grad_input tensor
358+
grad_input_chunk = grad_input[:, s0:s1, :]
359+
360+
grad_input_chunk.copy_(softmax_output.neg().mul_(go_sum)) # inplace copy
356361

357362
# Positive scatter term: add gradients to selected indices
358-
# Mask grad_output for indices not on this shard
359363
go_masked = go_chunk * in_range.to(dtype=go_chunk.dtype)
360-
# Flatten for scatter_add
361-
flat_grad = grad_input.view(-1)
362-
# compute flattened indices positions
363-
Bc, Sc = go_masked.shape[0], go_masked.shape[1]
364-
# row offset per [B, Sc]
365-
row = (
366-
torch.arange(Bc, device=grad_input.device)
367-
.view(-1, 1)
368-
.expand(-1, Sc)
369-
.reshape(-1)
364+
grad_input_chunk.scatter_add_(2, li, go_masked)
365+
366+
# Explicitly free before next iteration allocates
367+
del (
368+
softmax_output,
369+
log_probs,
370+
logits,
371+
gi,
372+
in_range,
373+
li,
374+
go_chunk,
375+
go_sum,
376+
go_masked,
370377
)
371-
col = torch.arange(Sc, device=grad_input.device).expand(Bc, -1).reshape(-1)
372-
flat_idx_base = (row * Sc + col) * V_local # [Bc*Sc]
373-
# selected flat indices
374-
flat_li = li.reshape(-1, li.shape[-1]) # [Bc*Sc, K]
375-
flat_base_expanded = flat_idx_base.unsqueeze(-1).expand_as(flat_li)
376-
flat_chosen = (flat_base_expanded + flat_li).reshape(-1)
377-
flat_go = go_masked.reshape(-1)
378-
flat_grad.scatter_add_(0, flat_chosen, flat_go)
379-
380-
all_grad_input.append(grad_input)
381-
382-
grad_input_total = (
383-
torch.cat(all_grad_input, dim=1)
384-
if len(all_grad_input) > 1
385-
else all_grad_input[0]
386-
)
387378

388-
return grad_input_total, None, None, None, None, None, None
379+
return grad_input, None, None, None, None, None, None
389380

390381

391382
def dtensor_from_parallel_logits_to_logprobs(
@@ -1040,7 +1031,10 @@ def backward(
10401031

10411032
B, S, V_local = vocab_parallel_logits.shape
10421033
num_chunks = (int(S) + chunk_size - 1) // chunk_size
1043-
grads: list[torch.Tensor] = []
1034+
1035+
grad_input: torch.Tensor = torch.empty_like(
1036+
vocab_parallel_logits, dtype=torch.float32
1037+
)
10441038

10451039
for chunk_idx in range(num_chunks):
10461040
s0 = chunk_idx * chunk_size
@@ -1054,10 +1048,16 @@ def backward(
10541048
H_local, op=torch.distributed.ReduceOp.SUM, group=tp_group
10551049
)
10561050

1051+
# Inplace index into the preallocated grad_input tensor
1052+
grad_input_chunk = grad_input[:, s0:s1, :]
1053+
10571054
# dH/dz = softmax * (log_probs - H_all)
1058-
grad_chunk = softmax_output * (log_probs - H_local.unsqueeze(-1))
1059-
grad_chunk.mul_(grad_output[:, s0:s1].unsqueeze(-1))
1060-
grads.append(grad_chunk)
1055+
grad_input_chunk.copy_(
1056+
softmax_output.mul_(log_probs - H_local.unsqueeze(-1))
1057+
) # inplace copy
1058+
grad_input_chunk.mul_(grad_output[:, s0:s1].unsqueeze(-1))
1059+
1060+
# Explicitly free before next iteration allocates
1061+
del softmax_output, log_probs, logits, H_local
10611062

1062-
grad_input = torch.cat(grads, dim=1) if len(grads) > 1 else grads[0]
10631063
return grad_input, None, None, None

tests/unit/distributed/test_model_utils.py

Lines changed: 182 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch
1919

2020
from nemo_rl.distributed.model_utils import (
21+
ChunkedDistributedEntropy,
2122
ChunkedDistributedGatherLogprob,
2223
ChunkedDistributedLogprob,
2324
DistributedLogprob,
@@ -705,14 +706,25 @@ def test_distributed_logprob_forward_and_backward(self):
705706
full_logits.grad = None
706707

707708
# Compute distributed gradients
708-
distributed_log_probs = DistributedLogprob.apply(
709-
vocab_parallel_logits,
710-
target,
711-
vocab_start_index,
712-
vocab_end_index,
713-
self.tp_group,
714-
False, # inference_only=False to enable backward
715-
)
709+
if chunk_size is not None:
710+
distributed_log_probs = ChunkedDistributedLogprob.apply(
711+
vocab_parallel_logits,
712+
target,
713+
vocab_start_index,
714+
vocab_end_index,
715+
chunk_size,
716+
self.tp_group,
717+
False, # inference_only=False to enable backward
718+
)
719+
else:
720+
distributed_log_probs = DistributedLogprob.apply(
721+
vocab_parallel_logits,
722+
target,
723+
vocab_start_index,
724+
vocab_end_index,
725+
self.tp_group,
726+
False, # inference_only=False to enable backward
727+
)
716728

717729
distributed_loss = torch.sum(distributed_log_probs)
718730
distributed_loss.backward()
@@ -956,3 +968,165 @@ def test_distributed_logprob_all_tests(
956968

957969
finally:
958970
cluster.shutdown()
971+
972+
973+
@ray.remote(num_gpus=1)
974+
class ChunkedDistributedEntropyTestActor:
975+
def __init__(self, tp_size, chunk_size, inference_only):
976+
self.tp_size = tp_size
977+
self.chunk_size = chunk_size
978+
self.inference_only = inference_only
979+
self.env_vars = dict(os.environ)
980+
981+
def test_chunked_distributed_entropy(self):
982+
torch.distributed.init_process_group(backend="nccl")
983+
984+
rank = int(os.environ["RANK"])
985+
tp_group = torch.distributed.new_group(ranks=list(range(self.tp_size)))
986+
987+
batch_size = 2
988+
seq_len = 16
989+
vocab_size = 256
990+
vocab_part_size = vocab_size // self.tp_size
991+
vocab_start_index = rank * vocab_part_size
992+
vocab_end_index = (rank + 1) * vocab_part_size
993+
994+
torch.manual_seed(1337)
995+
full_logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda")
996+
997+
# === Baseline: single-GPU entropy ===
998+
baseline_logits = (
999+
full_logits.clone().detach().requires_grad_(not self.inference_only)
1000+
)
1001+
baseline_log_probs = torch.nn.functional.log_softmax(baseline_logits, dim=-1)
1002+
baseline_probs = baseline_log_probs.exp()
1003+
# H = sum_v p_v * log(p_v) (negative entropy; matches ChunkedDistributedEntropy)
1004+
baseline_entropy = (baseline_probs * baseline_log_probs).sum(dim=-1) # [B, S]
1005+
1006+
if not self.inference_only:
1007+
baseline_entropy.sum().backward()
1008+
baseline_grad = baseline_logits.grad[
1009+
:, :, vocab_start_index:vocab_end_index
1010+
].clone()
1011+
1012+
# === Distributed: ChunkedDistributedEntropy ===
1013+
local_logits = full_logits[:, :, vocab_start_index:vocab_end_index]
1014+
local_logits = (
1015+
local_logits.clone().detach().requires_grad_(not self.inference_only)
1016+
)
1017+
1018+
distributed_entropy = ChunkedDistributedEntropy.apply(
1019+
local_logits,
1020+
self.chunk_size,
1021+
tp_group,
1022+
self.inference_only,
1023+
)
1024+
1025+
# Forward check
1026+
torch.testing.assert_close(
1027+
distributed_entropy, baseline_entropy, rtol=1e-4, atol=1e-4
1028+
)
1029+
forward_diff = torch.max(
1030+
torch.abs(distributed_entropy - baseline_entropy)
1031+
).item()
1032+
1033+
# Backward check
1034+
if not self.inference_only:
1035+
distributed_entropy.sum().backward()
1036+
grad_local = local_logits.grad
1037+
torch.testing.assert_close(grad_local, baseline_grad, rtol=1e-4, atol=1e-4)
1038+
grad_diff = torch.max(torch.abs(grad_local - baseline_grad)).item()
1039+
else:
1040+
grad_diff = None
1041+
1042+
return {
1043+
"forward_max_diff": forward_diff,
1044+
"grad_max_diff": grad_diff,
1045+
}
1046+
1047+
1048+
CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN = f"{ChunkedDistributedEntropyTestActor.__module__}.ChunkedDistributedEntropyTestActor"
1049+
1050+
1051+
@pytest.fixture
1052+
def register_chunked_distributed_entropy_test_actor():
1053+
original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get(
1054+
CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN
1055+
)
1056+
ACTOR_ENVIRONMENT_REGISTRY[CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN] = (
1057+
PY_EXECUTABLES.SYSTEM
1058+
)
1059+
1060+
yield CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN
1061+
1062+
if CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY:
1063+
if original_registry_value is None:
1064+
del ACTOR_ENVIRONMENT_REGISTRY[CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN]
1065+
else:
1066+
ACTOR_ENVIRONMENT_REGISTRY[CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN] = (
1067+
original_registry_value
1068+
)
1069+
1070+
1071+
@pytest.mark.parametrize(
1072+
"tp_size, chunk_size, inference_only",
1073+
[
1074+
(1, 5, False),
1075+
(2, 4, False),
1076+
(1, 3, True),
1077+
],
1078+
)
1079+
def test_chunked_distributed_entropy(
1080+
register_chunked_distributed_entropy_test_actor,
1081+
tp_size,
1082+
chunk_size,
1083+
inference_only,
1084+
):
1085+
world_size = tp_size
1086+
1087+
if not torch.cuda.is_available() or torch.cuda.device_count() < world_size:
1088+
pytest.skip(
1089+
f"Not enough GPUs available. Need {world_size}, "
1090+
f"got {torch.cuda.device_count()}"
1091+
)
1092+
1093+
cluster = RayVirtualCluster(bundle_ct_per_node_list=[world_size], use_gpus=True)
1094+
1095+
try:
1096+
actor_fqn = register_chunked_distributed_entropy_test_actor
1097+
1098+
sharding = NamedSharding(
1099+
layout=np.arange(world_size).reshape(tp_size), names=["tp"]
1100+
)
1101+
builder = RayWorkerBuilder(actor_fqn, tp_size, chunk_size, inference_only)
1102+
1103+
worker_group = RayWorkerGroup(
1104+
cluster=cluster,
1105+
remote_worker_builder=builder,
1106+
workers_per_node=None,
1107+
sharding_annotations=sharding,
1108+
)
1109+
1110+
futures = worker_group.run_all_workers_single_data(
1111+
"test_chunked_distributed_entropy"
1112+
)
1113+
results = ray.get(futures)
1114+
1115+
for i, result in enumerate(results):
1116+
print(f"Worker {i} forward max diff: {result['forward_max_diff']:.2e}")
1117+
assert result["forward_max_diff"] < 1e-4, (
1118+
f"Worker {i} forward diff too large: {result['forward_max_diff']}"
1119+
)
1120+
if not inference_only:
1121+
print(f"Worker {i} grad max diff: {result['grad_max_diff']:.2e}")
1122+
assert (
1123+
result["grad_max_diff"] is not None
1124+
and result["grad_max_diff"] < 1e-4
1125+
), f"Worker {i} grad diff too large: {result['grad_max_diff']}"
1126+
else:
1127+
assert result["grad_max_diff"] is None
1128+
1129+
worker_group.shutdown(force=True)
1130+
1131+
finally:
1132+
cluster.shutdown()

0 commit comments

Comments
 (0)