|
| 1 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Tests for DistributedLogprob and ChunkedDistributedLogprob using mp.spawn. |
| 16 | +
|
| 17 | +These tests use the distributed_test_runner fixture (torch.multiprocessing.spawn) |
| 18 | +so that code coverage is captured by pytest-cov, unlike the Ray actor-based tests |
| 19 | +in test_model_utils.py where execution happens in separate Ray worker processes. |
| 20 | +""" |
| 21 | + |
| 22 | +import functools |
| 23 | + |
| 24 | +import pytest |
| 25 | +import torch |
| 26 | + |
| 27 | +from nemo_rl.distributed.model_utils import ( |
| 28 | + ChunkedDistributedEntropy, |
| 29 | + ChunkedDistributedGatherLogprob, |
| 30 | + ChunkedDistributedLogprob, |
| 31 | + DistributedLogprob, |
| 32 | + _compute_distributed_log_softmax, |
| 33 | +) |
| 34 | + |
| 35 | + |
| 36 | +def _torch_baseline_logprob(full_logits, target): |
| 37 | + """Single-GPU PyTorch baseline for log probability computation.""" |
| 38 | + log_softmax = torch.nn.functional.log_softmax(full_logits, dim=-1) |
| 39 | + log_probs = torch.gather(log_softmax, -1, target.unsqueeze(-1)).squeeze(-1) |
| 40 | + target_mask = target >= 0 |
| 41 | + log_probs = log_probs * target_mask.float() |
| 42 | + return log_probs |
| 43 | + |
| 44 | + |
| 45 | +def _run_logprob_forward_and_backward(rank, world_size, tp_size, chunk_size): |
| 46 | + """Test DistributedLogprob / ChunkedDistributedLogprob forward and backward passes.""" |
| 47 | + tp_group = torch.distributed.new_group(ranks=list(range(tp_size))) |
| 48 | + |
| 49 | + batch_size = 4 |
| 50 | + seq_len = 8 |
| 51 | + full_vocab_size = 1024 |
| 52 | + vocab_part_size = full_vocab_size // tp_size |
| 53 | + |
| 54 | + vocab_start_index = rank * vocab_part_size |
| 55 | + vocab_end_index = (rank + 1) * vocab_part_size |
| 56 | + |
| 57 | + torch.manual_seed(42) |
| 58 | + full_logits = torch.randn( |
| 59 | + batch_size, seq_len, full_vocab_size, device="cuda", requires_grad=True |
| 60 | + ) |
| 61 | + |
| 62 | + vocab_parallel_logits = ( |
| 63 | + full_logits[:, :, vocab_start_index:vocab_end_index] |
| 64 | + .clone() |
| 65 | + .detach() |
| 66 | + .requires_grad_(True) |
| 67 | + ) |
| 68 | + |
| 69 | + torch.manual_seed(43) |
| 70 | + target = torch.randint(0, full_vocab_size, (batch_size, seq_len), device="cuda") |
| 71 | + |
| 72 | + # === FORWARD PASS === |
| 73 | + baseline_log_probs_forward = _torch_baseline_logprob( |
| 74 | + full_logits.clone().detach(), target |
| 75 | + ) |
| 76 | + |
| 77 | + if chunk_size is not None: |
| 78 | + distributed_log_probs_inference = ChunkedDistributedLogprob.apply( |
| 79 | + vocab_parallel_logits.clone().detach(), |
| 80 | + target, |
| 81 | + vocab_start_index, |
| 82 | + vocab_end_index, |
| 83 | + chunk_size, |
| 84 | + tp_group, |
| 85 | + True, |
| 86 | + ) |
| 87 | + else: |
| 88 | + distributed_log_probs_inference = DistributedLogprob.apply( |
| 89 | + vocab_parallel_logits.clone().detach(), |
| 90 | + target, |
| 91 | + vocab_start_index, |
| 92 | + vocab_end_index, |
| 93 | + tp_group, |
| 94 | + True, |
| 95 | + ) |
| 96 | + |
| 97 | + torch.testing.assert_close( |
| 98 | + distributed_log_probs_inference, |
| 99 | + baseline_log_probs_forward, |
| 100 | + rtol=1e-4, |
| 101 | + atol=1e-4, |
| 102 | + ) |
| 103 | + |
| 104 | + # === BACKWARD PASS === |
| 105 | + baseline_log_probs = _torch_baseline_logprob(full_logits, target) |
| 106 | + baseline_loss = torch.sum(baseline_log_probs) |
| 107 | + baseline_loss.backward() |
| 108 | + baseline_grad = full_logits.grad[:, :, vocab_start_index:vocab_end_index].clone() |
| 109 | + |
| 110 | + full_logits.grad = None |
| 111 | + |
| 112 | + if chunk_size is not None: |
| 113 | + distributed_log_probs = ChunkedDistributedLogprob.apply( |
| 114 | + vocab_parallel_logits, |
| 115 | + target, |
| 116 | + vocab_start_index, |
| 117 | + vocab_end_index, |
| 118 | + chunk_size, |
| 119 | + tp_group, |
| 120 | + False, |
| 121 | + ) |
| 122 | + else: |
| 123 | + distributed_log_probs = DistributedLogprob.apply( |
| 124 | + vocab_parallel_logits, |
| 125 | + target, |
| 126 | + vocab_start_index, |
| 127 | + vocab_end_index, |
| 128 | + tp_group, |
| 129 | + False, |
| 130 | + ) |
| 131 | + |
| 132 | + distributed_loss = torch.sum(distributed_log_probs) |
| 133 | + distributed_loss.backward() |
| 134 | + distributed_grad = vocab_parallel_logits.grad |
| 135 | + |
| 136 | + torch.testing.assert_close( |
| 137 | + distributed_grad, baseline_grad, rtol=1e-4, atol=1e-4 |
| 138 | + ) |
| 139 | + torch.testing.assert_close( |
| 140 | + distributed_log_probs, baseline_log_probs, rtol=1e-4, atol=1e-4 |
| 141 | + ) |
| 142 | + |
| 143 | + |
| 144 | +def _run_log_softmax(rank, world_size, tp_size): |
| 145 | + """Test _compute_distributed_log_softmax against PyTorch baseline.""" |
| 146 | + tp_group = torch.distributed.new_group(ranks=list(range(tp_size))) |
| 147 | + |
| 148 | + batch_size = 3 |
| 149 | + seq_len = 5 |
| 150 | + full_vocab_size = 256 |
| 151 | + vocab_part_size = full_vocab_size // tp_size |
| 152 | + |
| 153 | + vocab_start_index = rank * vocab_part_size |
| 154 | + vocab_end_index = (rank + 1) * vocab_part_size |
| 155 | + |
| 156 | + torch.manual_seed(42) |
| 157 | + full_logits = torch.randn(batch_size, seq_len, full_vocab_size, device="cuda") |
| 158 | + vocab_parallel_logits = full_logits[:, :, vocab_start_index:vocab_end_index].clone() |
| 159 | + |
| 160 | + baseline_log_softmax = torch.nn.functional.log_softmax(full_logits, dim=-1) |
| 161 | + expected = baseline_log_softmax[:, :, vocab_start_index:vocab_end_index] |
| 162 | + |
| 163 | + distributed = _compute_distributed_log_softmax(vocab_parallel_logits, tp_group) |
| 164 | + |
| 165 | + torch.testing.assert_close(distributed, expected, rtol=1e-5, atol=1e-5) |
| 166 | + |
| 167 | + |
| 168 | +def _run_edge_cases(rank, world_size, tp_size): |
| 169 | + """Test numerical stability and boundary conditions for DistributedLogprob.""" |
| 170 | + tp_group = torch.distributed.new_group(ranks=list(range(tp_size))) |
| 171 | + |
| 172 | + batch_size = 2 |
| 173 | + seq_len = 3 |
| 174 | + full_vocab_size = 64 |
| 175 | + vocab_part_size = full_vocab_size // tp_size |
| 176 | + |
| 177 | + vocab_start_index = rank * vocab_part_size |
| 178 | + vocab_end_index = (rank + 1) * vocab_part_size |
| 179 | + |
| 180 | + # Large logits — should not produce NaN or Inf |
| 181 | + torch.manual_seed(42) |
| 182 | + large_logits = torch.randn(batch_size, seq_len, full_vocab_size, device="cuda") * 100 |
| 183 | + vocab_parallel_logits = large_logits[:, :, vocab_start_index:vocab_end_index].clone() |
| 184 | + |
| 185 | + torch.manual_seed(43) |
| 186 | + target = torch.randint(0, full_vocab_size, (batch_size, seq_len), device="cuda") |
| 187 | + |
| 188 | + log_probs = DistributedLogprob.apply( |
| 189 | + vocab_parallel_logits, target, vocab_start_index, vocab_end_index, tp_group, True |
| 190 | + ) |
| 191 | + |
| 192 | + assert not torch.isnan(log_probs).any(), "Log probs contain NaN" |
| 193 | + assert not torch.isinf(log_probs).any(), "Log probs contain Inf" |
| 194 | + |
| 195 | + # All targets pointing to vocab index 0 |
| 196 | + zero_target = torch.zeros(batch_size, seq_len, dtype=torch.long, device="cuda") |
| 197 | + |
| 198 | + log_probs_zero = DistributedLogprob.apply( |
| 199 | + vocab_parallel_logits, zero_target, vocab_start_index, vocab_end_index, tp_group, True |
| 200 | + ) |
| 201 | + |
| 202 | + torch.manual_seed(42) |
| 203 | + baseline_large_logits = torch.randn(batch_size, seq_len, full_vocab_size, device="cuda") * 100 |
| 204 | + baseline_log_probs = _torch_baseline_logprob(baseline_large_logits, zero_target) |
| 205 | + |
| 206 | + torch.testing.assert_close(log_probs_zero, baseline_log_probs, rtol=1e-4, atol=1e-4) |
| 207 | + |
| 208 | + |
| 209 | +# --------------------------------------------------------------------------- |
| 210 | +# Pytest test functions |
| 211 | +# --------------------------------------------------------------------------- |
| 212 | + |
| 213 | + |
| 214 | +@pytest.mark.parametrize( |
| 215 | + "tp_size, chunk_size", |
| 216 | + [ |
| 217 | + (1, None), |
| 218 | + (2, None), |
| 219 | + (1, 4), |
| 220 | + (2, 4), |
| 221 | + ], |
| 222 | +) |
| 223 | +def test_distributed_logprob_forward_and_backward( |
| 224 | + distributed_test_runner, tp_size, chunk_size |
| 225 | +): |
| 226 | + test_fn = functools.partial( |
| 227 | + _run_logprob_forward_and_backward, tp_size=tp_size, chunk_size=chunk_size |
| 228 | + ) |
| 229 | + distributed_test_runner(test_fn, world_size=tp_size) |
| 230 | + |
| 231 | + |
| 232 | +@pytest.mark.parametrize("tp_size", [1, 2]) |
| 233 | +def test_distributed_log_softmax(distributed_test_runner, tp_size): |
| 234 | + test_fn = functools.partial(_run_log_softmax, tp_size=tp_size) |
| 235 | + distributed_test_runner(test_fn, world_size=tp_size) |
| 236 | + |
| 237 | + |
| 238 | +def test_distributed_logprob_edge_cases(distributed_test_runner): |
| 239 | + test_fn = functools.partial(_run_edge_cases, tp_size=2) |
| 240 | + distributed_test_runner(test_fn, world_size=2) |
| 241 | + |
| 242 | + |
| 243 | +# --------------------------------------------------------------------------- |
| 244 | +# ChunkedDistributedGatherLogprob |
| 245 | +# --------------------------------------------------------------------------- |
| 246 | + |
| 247 | + |
| 248 | +def _run_chunked_gather_logprob(rank, world_size, tp_size, chunk_size, inference_only): |
| 249 | + """Test ChunkedDistributedGatherLogprob forward (and optionally backward).""" |
| 250 | + tp_group = torch.distributed.new_group(ranks=list(range(tp_size))) |
| 251 | + |
| 252 | + batch_size = 2 |
| 253 | + seq_len = 16 |
| 254 | + vocab_size = 256 |
| 255 | + gather_k = 3 |
| 256 | + |
| 257 | + torch.manual_seed(1337) |
| 258 | + full_logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda") |
| 259 | + global_indices = torch.randint( |
| 260 | + low=0, high=vocab_size, size=(batch_size, seq_len, gather_k), device="cuda" |
| 261 | + ) |
| 262 | + |
| 263 | + vocab_part_size = vocab_size // tp_size |
| 264 | + vocab_start_index = rank * vocab_part_size |
| 265 | + vocab_end_index = (rank + 1) * vocab_part_size |
| 266 | + |
| 267 | + # Baseline: single-GPU log_softmax + gather |
| 268 | + baseline_logits = full_logits.clone().detach().requires_grad_(not inference_only) |
| 269 | + baseline_log_probs = torch.nn.functional.log_softmax(baseline_logits, dim=-1) |
| 270 | + baseline_selected = torch.gather( |
| 271 | + baseline_log_probs, dim=-1, index=global_indices |
| 272 | + ) |
| 273 | + |
| 274 | + if not inference_only: |
| 275 | + torch.gather( |
| 276 | + baseline_log_probs, dim=-1, index=global_indices |
| 277 | + ).sum().backward() |
| 278 | + baseline_grad = baseline_logits.grad[:, :, vocab_start_index:vocab_end_index] |
| 279 | + |
| 280 | + # Distributed path |
| 281 | + local_logits = full_logits[:, :, vocab_start_index:vocab_end_index] |
| 282 | + local_logits = local_logits.clone().detach().requires_grad_(not inference_only) |
| 283 | + |
| 284 | + gathered = ChunkedDistributedGatherLogprob.apply( |
| 285 | + local_logits, |
| 286 | + global_indices, |
| 287 | + vocab_start_index, |
| 288 | + vocab_end_index, |
| 289 | + chunk_size, |
| 290 | + tp_group, |
| 291 | + inference_only, |
| 292 | + ) |
| 293 | + |
| 294 | + torch.testing.assert_close(gathered, baseline_selected, rtol=1e-4, atol=1e-4) |
| 295 | + |
| 296 | + if not inference_only: |
| 297 | + gathered.sum().backward() |
| 298 | + torch.testing.assert_close( |
| 299 | + local_logits.grad, baseline_grad, rtol=1e-4, atol=1e-4 |
| 300 | + ) |
| 301 | + |
| 302 | + |
| 303 | +@pytest.mark.parametrize( |
| 304 | + "tp_size, chunk_size, inference_only", |
| 305 | + [ |
| 306 | + (1, 5, False), |
| 307 | + (2, 4, False), |
| 308 | + (1, 3, True), |
| 309 | + ], |
| 310 | +) |
| 311 | +def test_chunked_distributed_gather_logprob( |
| 312 | + distributed_test_runner, tp_size, chunk_size, inference_only |
| 313 | +): |
| 314 | + test_fn = functools.partial( |
| 315 | + _run_chunked_gather_logprob, |
| 316 | + tp_size=tp_size, |
| 317 | + chunk_size=chunk_size, |
| 318 | + inference_only=inference_only, |
| 319 | + ) |
| 320 | + distributed_test_runner(test_fn, world_size=tp_size) |
| 321 | + |
| 322 | + |
| 323 | +# --------------------------------------------------------------------------- |
| 324 | +# ChunkedDistributedEntropy |
| 325 | +# --------------------------------------------------------------------------- |
| 326 | + |
| 327 | + |
| 328 | +def _run_chunked_distributed_entropy( |
| 329 | + rank, world_size, tp_size, chunk_size, inference_only |
| 330 | +): |
| 331 | + """Test ChunkedDistributedEntropy forward (and optionally backward).""" |
| 332 | + tp_group = torch.distributed.new_group(ranks=list(range(tp_size))) |
| 333 | + |
| 334 | + batch_size = 2 |
| 335 | + seq_len = 16 |
| 336 | + vocab_size = 256 |
| 337 | + vocab_part_size = vocab_size // tp_size |
| 338 | + vocab_start_index = rank * vocab_part_size |
| 339 | + vocab_end_index = (rank + 1) * vocab_part_size |
| 340 | + |
| 341 | + torch.manual_seed(1337) |
| 342 | + full_logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda") |
| 343 | + |
| 344 | + # Baseline: single-GPU entropy H = sum_v p_v * log(p_v) |
| 345 | + baseline_logits = full_logits.clone().detach().requires_grad_(not inference_only) |
| 346 | + baseline_log_probs = torch.nn.functional.log_softmax(baseline_logits, dim=-1) |
| 347 | + baseline_probs = baseline_log_probs.exp() |
| 348 | + baseline_entropy = (baseline_probs * baseline_log_probs).sum(dim=-1) |
| 349 | + |
| 350 | + if not inference_only: |
| 351 | + baseline_entropy.sum().backward() |
| 352 | + baseline_grad = baseline_logits.grad[ |
| 353 | + :, :, vocab_start_index:vocab_end_index |
| 354 | + ].clone() |
| 355 | + |
| 356 | + # Distributed path |
| 357 | + local_logits = full_logits[:, :, vocab_start_index:vocab_end_index] |
| 358 | + local_logits = local_logits.clone().detach().requires_grad_(not inference_only) |
| 359 | + |
| 360 | + distributed_entropy = ChunkedDistributedEntropy.apply( |
| 361 | + local_logits, |
| 362 | + chunk_size, |
| 363 | + tp_group, |
| 364 | + inference_only, |
| 365 | + ) |
| 366 | + |
| 367 | + torch.testing.assert_close( |
| 368 | + distributed_entropy, baseline_entropy, rtol=1e-4, atol=1e-4 |
| 369 | + ) |
| 370 | + |
| 371 | + if not inference_only: |
| 372 | + distributed_entropy.sum().backward() |
| 373 | + torch.testing.assert_close( |
| 374 | + local_logits.grad, baseline_grad, rtol=1e-4, atol=1e-4 |
| 375 | + ) |
| 376 | + |
| 377 | + |
| 378 | +@pytest.mark.parametrize( |
| 379 | + "tp_size, chunk_size, inference_only", |
| 380 | + [ |
| 381 | + (1, 5, False), |
| 382 | + (2, 4, False), |
| 383 | + (1, 3, True), |
| 384 | + ], |
| 385 | +) |
| 386 | +def test_chunked_distributed_entropy( |
| 387 | + distributed_test_runner, tp_size, chunk_size, inference_only |
| 388 | +): |
| 389 | + test_fn = functools.partial( |
| 390 | + _run_chunked_distributed_entropy, |
| 391 | + tp_size=tp_size, |
| 392 | + chunk_size=chunk_size, |
| 393 | + inference_only=inference_only, |
| 394 | + ) |
| 395 | + distributed_test_runner(test_fn, world_size=tp_size) |
0 commit comments