Skip to content

Commit fc8730c

Browse files
committed
Use distributed_test_runner to ensure visibility for code coverage
Signed-off-by: mloh <mloh@nvidia.com>
1 parent cab076f commit fc8730c

File tree

2 files changed

+403
-182
lines changed

2 files changed

+403
-182
lines changed
Lines changed: 395 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,395 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for DistributedLogprob and ChunkedDistributedLogprob using mp.spawn.
16+
17+
These tests use the distributed_test_runner fixture (torch.multiprocessing.spawn)
18+
so that code coverage is captured by pytest-cov, unlike the Ray actor-based tests
19+
in test_model_utils.py where execution happens in separate Ray worker processes.
20+
"""
21+
22+
import functools
23+
24+
import pytest
25+
import torch
26+
27+
from nemo_rl.distributed.model_utils import (
28+
ChunkedDistributedEntropy,
29+
ChunkedDistributedGatherLogprob,
30+
ChunkedDistributedLogprob,
31+
DistributedLogprob,
32+
_compute_distributed_log_softmax,
33+
)
34+
35+
36+
def _torch_baseline_logprob(full_logits, target):
37+
"""Single-GPU PyTorch baseline for log probability computation."""
38+
log_softmax = torch.nn.functional.log_softmax(full_logits, dim=-1)
39+
log_probs = torch.gather(log_softmax, -1, target.unsqueeze(-1)).squeeze(-1)
40+
target_mask = target >= 0
41+
log_probs = log_probs * target_mask.float()
42+
return log_probs
43+
44+
45+
def _run_logprob_forward_and_backward(rank, world_size, tp_size, chunk_size):
    """Test DistributedLogprob / ChunkedDistributedLogprob forward and backward passes.

    Runs on one spawned process per rank. Each rank holds a contiguous
    vocab shard of the same (seed-fixed) full-logits tensor; the distributed
    result and its gradient must match a single-GPU baseline.

    Args:
        rank: This process's rank (shard index into the vocab dimension).
        world_size: Total number of spawned processes.
        tp_size: Tensor-parallel group size; ranks [0, tp_size) form the group.
        chunk_size: If not None, exercise ChunkedDistributedLogprob with this
            sequence chunk size; otherwise exercise DistributedLogprob.
    """
    tp_group = torch.distributed.new_group(ranks=list(range(tp_size)))

    batch_size = 4
    seq_len = 8
    full_vocab_size = 1024
    # Vocab is sharded evenly; this rank owns [vocab_start_index, vocab_end_index).
    vocab_part_size = full_vocab_size // tp_size

    vocab_start_index = rank * vocab_part_size
    vocab_end_index = (rank + 1) * vocab_part_size

    # Fixed seed so every rank materializes the SAME full-logits tensor and
    # can slice out its own shard consistently.
    torch.manual_seed(42)
    full_logits = torch.randn(
        batch_size, seq_len, full_vocab_size, device="cuda", requires_grad=True
    )

    # Local shard as an independent leaf so its .grad can be compared against
    # the corresponding slice of the baseline gradient.
    vocab_parallel_logits = (
        full_logits[:, :, vocab_start_index:vocab_end_index]
        .clone()
        .detach()
        .requires_grad_(True)
    )

    # Different seed so targets are independent of the logits.
    torch.manual_seed(43)
    target = torch.randint(0, full_vocab_size, (batch_size, seq_len), device="cuda")

    # === FORWARD PASS ===
    baseline_log_probs_forward = _torch_baseline_logprob(
        full_logits.clone().detach(), target
    )

    # Final positional arg True = inference-only mode (no saved tensors for
    # backward), presumably — TODO confirm against model_utils signature.
    if chunk_size is not None:
        distributed_log_probs_inference = ChunkedDistributedLogprob.apply(
            vocab_parallel_logits.clone().detach(),
            target,
            vocab_start_index,
            vocab_end_index,
            chunk_size,
            tp_group,
            True,
        )
    else:
        distributed_log_probs_inference = DistributedLogprob.apply(
            vocab_parallel_logits.clone().detach(),
            target,
            vocab_start_index,
            vocab_end_index,
            tp_group,
            True,
        )

    torch.testing.assert_close(
        distributed_log_probs_inference,
        baseline_log_probs_forward,
        rtol=1e-4,
        atol=1e-4,
    )

    # === BACKWARD PASS ===
    # Baseline gradient: d(sum log p)/d(logits), sliced to this rank's shard.
    baseline_log_probs = _torch_baseline_logprob(full_logits, target)
    baseline_loss = torch.sum(baseline_log_probs)
    baseline_loss.backward()
    baseline_grad = full_logits.grad[:, :, vocab_start_index:vocab_end_index].clone()

    # Clear so any accidental second backward would be visible, not summed.
    full_logits.grad = None

    # Same ops with inference-only False so autograd state is saved.
    if chunk_size is not None:
        distributed_log_probs = ChunkedDistributedLogprob.apply(
            vocab_parallel_logits,
            target,
            vocab_start_index,
            vocab_end_index,
            chunk_size,
            tp_group,
            False,
        )
    else:
        distributed_log_probs = DistributedLogprob.apply(
            vocab_parallel_logits,
            target,
            vocab_start_index,
            vocab_end_index,
            tp_group,
            False,
        )

    distributed_loss = torch.sum(distributed_log_probs)
    distributed_loss.backward()
    distributed_grad = vocab_parallel_logits.grad

    torch.testing.assert_close(
        distributed_grad, baseline_grad, rtol=1e-4, atol=1e-4
    )
    torch.testing.assert_close(
        distributed_log_probs, baseline_log_probs, rtol=1e-4, atol=1e-4
    )
142+
143+
144+
def _run_log_softmax(rank, world_size, tp_size):
    """Compare _compute_distributed_log_softmax with a single-GPU baseline.

    Each rank holds a contiguous vocab shard of the same seed-fixed logits;
    the distributed helper's output must equal the corresponding slice of a
    full-vocabulary ``log_softmax``.
    """
    tp_group = torch.distributed.new_group(ranks=list(range(tp_size)))

    batch_size = 3
    seq_len = 5
    full_vocab_size = 256

    # This rank's contiguous slice of the vocab dimension.
    shard_width = full_vocab_size // tp_size
    lo = rank * shard_width
    hi = lo + shard_width

    # Same seed on every rank -> identical full logits everywhere.
    torch.manual_seed(42)
    full_logits = torch.randn(batch_size, seq_len, full_vocab_size, device="cuda")

    # Baseline over the full vocabulary, restricted to this rank's shard.
    expected = torch.nn.functional.log_softmax(full_logits, dim=-1)[:, :, lo:hi]

    # Distributed path only ever sees the local shard.
    local_shard = full_logits[:, :, lo:hi].clone()
    distributed = _compute_distributed_log_softmax(local_shard, tp_group)

    torch.testing.assert_close(distributed, expected, rtol=1e-5, atol=1e-5)
166+
167+
168+
def _run_edge_cases(rank, world_size, tp_size):
    """Test numerical stability and boundary conditions for DistributedLogprob.

    Two cases: (1) very large-magnitude logits must not yield NaN/Inf, and
    (2) targets all equal to vocab index 0 — a shard-boundary index owned
    only by rank 0 — must still match the single-GPU baseline.

    Args:
        rank: This process's rank (shard index into the vocab dimension).
        world_size: Total number of spawned processes.
        tp_size: Tensor-parallel group size; ranks [0, tp_size) form the group.
    """
    tp_group = torch.distributed.new_group(ranks=list(range(tp_size)))

    batch_size = 2
    seq_len = 3
    full_vocab_size = 64
    vocab_part_size = full_vocab_size // tp_size

    vocab_start_index = rank * vocab_part_size
    vocab_end_index = (rank + 1) * vocab_part_size

    # Large logits — should not produce NaN or Inf.
    # Seed is fixed so all ranks generate the same full tensor; *100 pushes
    # exp() toward overflow unless the implementation subtracts the max.
    torch.manual_seed(42)
    large_logits = torch.randn(batch_size, seq_len, full_vocab_size, device="cuda") * 100
    vocab_parallel_logits = large_logits[:, :, vocab_start_index:vocab_end_index].clone()

    torch.manual_seed(43)
    target = torch.randint(0, full_vocab_size, (batch_size, seq_len), device="cuda")

    # Final positional arg True = inference-only mode, presumably — TODO
    # confirm against the DistributedLogprob signature in model_utils.
    log_probs = DistributedLogprob.apply(
        vocab_parallel_logits, target, vocab_start_index, vocab_end_index, tp_group, True
    )

    assert not torch.isnan(log_probs).any(), "Log probs contain NaN"
    assert not torch.isinf(log_probs).any(), "Log probs contain Inf"

    # All targets pointing to vocab index 0 (owned only by rank 0's shard).
    zero_target = torch.zeros(batch_size, seq_len, dtype=torch.long, device="cuda")

    log_probs_zero = DistributedLogprob.apply(
        vocab_parallel_logits, zero_target, vocab_start_index, vocab_end_index, tp_group, True
    )

    # Re-seed to regenerate the identical large-logits tensor for the
    # single-GPU baseline (the earlier tensor was sliced per-rank).
    torch.manual_seed(42)
    baseline_large_logits = torch.randn(batch_size, seq_len, full_vocab_size, device="cuda") * 100
    baseline_log_probs = _torch_baseline_logprob(baseline_large_logits, zero_target)

    torch.testing.assert_close(log_probs_zero, baseline_log_probs, rtol=1e-4, atol=1e-4)
207+
208+
209+
# ---------------------------------------------------------------------------
210+
# Pytest test functions
211+
# ---------------------------------------------------------------------------
212+
213+
214+
@pytest.mark.parametrize(
    "tp_size, chunk_size",
    [
        (1, None),
        (2, None),
        (1, 4),
        (2, 4),
    ],
)
def test_distributed_logprob_forward_and_backward(
    distributed_test_runner, tp_size, chunk_size
):
    """Run the logprob forward/backward check, chunked and unchunked, at TP 1 and 2."""
    distributed_test_runner(
        functools.partial(
            _run_logprob_forward_and_backward,
            tp_size=tp_size,
            chunk_size=chunk_size,
        ),
        world_size=tp_size,
    )
230+
231+
232+
@pytest.mark.parametrize("tp_size", [1, 2])
def test_distributed_log_softmax(distributed_test_runner, tp_size):
    """Check the distributed log-softmax helper at TP sizes 1 and 2."""
    distributed_test_runner(
        functools.partial(_run_log_softmax, tp_size=tp_size), world_size=tp_size
    )
236+
237+
238+
def test_distributed_logprob_edge_cases(distributed_test_runner):
    """Exercise numerical-stability edge cases with a 2-way TP group."""
    distributed_test_runner(
        functools.partial(_run_edge_cases, tp_size=2), world_size=2
    )
241+
242+
243+
# ---------------------------------------------------------------------------
244+
# ChunkedDistributedGatherLogprob
245+
# ---------------------------------------------------------------------------
246+
247+
248+
def _run_chunked_gather_logprob(rank, world_size, tp_size, chunk_size, inference_only):
    """Test ChunkedDistributedGatherLogprob forward (and optionally backward).

    Gathers log-probs at ``gather_k`` arbitrary vocab indices per position and
    compares output (and, unless ``inference_only``, the gradient of the local
    vocab shard) against a single-GPU log_softmax + gather baseline.

    Args:
        rank: This process's rank (shard index into the vocab dimension).
        world_size: Total number of spawned processes.
        tp_size: Tensor-parallel group size; ranks [0, tp_size) form the group.
        chunk_size: Sequence chunk size for the chunked implementation.
        inference_only: If True, skip the backward/gradient comparison.
    """
    tp_group = torch.distributed.new_group(ranks=list(range(tp_size)))

    batch_size = 2
    seq_len = 16
    vocab_size = 256
    gather_k = 3

    # Fixed seed so every rank builds the same full logits and indices.
    torch.manual_seed(1337)
    full_logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda")
    global_indices = torch.randint(
        low=0, high=vocab_size, size=(batch_size, seq_len, gather_k), device="cuda"
    )

    vocab_part_size = vocab_size // tp_size
    vocab_start_index = rank * vocab_part_size
    vocab_end_index = (rank + 1) * vocab_part_size

    # Baseline: single-GPU log_softmax + gather
    baseline_logits = full_logits.clone().detach().requires_grad_(not inference_only)
    baseline_log_probs = torch.nn.functional.log_softmax(baseline_logits, dim=-1)
    baseline_selected = torch.gather(
        baseline_log_probs, dim=-1, index=global_indices
    )

    if not inference_only:
        # Backprop through the forward result directly (the original code
        # re-ran the same gather here; gradients are identical either way).
        baseline_selected.sum().backward()
        baseline_grad = baseline_logits.grad[:, :, vocab_start_index:vocab_end_index]

    # Distributed path: operates on this rank's vocab shard only.
    local_logits = full_logits[:, :, vocab_start_index:vocab_end_index]
    local_logits = local_logits.clone().detach().requires_grad_(not inference_only)

    gathered = ChunkedDistributedGatherLogprob.apply(
        local_logits,
        global_indices,
        vocab_start_index,
        vocab_end_index,
        chunk_size,
        tp_group,
        inference_only,
    )

    torch.testing.assert_close(gathered, baseline_selected, rtol=1e-4, atol=1e-4)

    if not inference_only:
        gathered.sum().backward()
        torch.testing.assert_close(
            local_logits.grad, baseline_grad, rtol=1e-4, atol=1e-4
        )
301+
302+
303+
@pytest.mark.parametrize(
    "tp_size, chunk_size, inference_only",
    [
        (1, 5, False),
        (2, 4, False),
        (1, 3, True),
    ],
)
def test_chunked_distributed_gather_logprob(
    distributed_test_runner, tp_size, chunk_size, inference_only
):
    """Run the chunked gather-logprob check across TP sizes and modes."""
    distributed_test_runner(
        functools.partial(
            _run_chunked_gather_logprob,
            tp_size=tp_size,
            chunk_size=chunk_size,
            inference_only=inference_only,
        ),
        world_size=tp_size,
    )
321+
322+
323+
# ---------------------------------------------------------------------------
324+
# ChunkedDistributedEntropy
325+
# ---------------------------------------------------------------------------
326+
327+
328+
def _run_chunked_distributed_entropy(
    rank, world_size, tp_size, chunk_size, inference_only
):
    """Test ChunkedDistributedEntropy forward (and optionally backward).

    Compares the distributed result (and, unless ``inference_only``, the
    gradient of the local vocab shard) against a single-GPU baseline.

    Args:
        rank: This process's rank (shard index into the vocab dimension).
        world_size: Total number of spawned processes.
        tp_size: Tensor-parallel group size; ranks [0, tp_size) form the group.
        chunk_size: Sequence chunk size for the chunked implementation.
        inference_only: If True, skip the backward/gradient comparison.
    """
    tp_group = torch.distributed.new_group(ranks=list(range(tp_size)))

    batch_size = 2
    seq_len = 16
    vocab_size = 256
    vocab_part_size = vocab_size // tp_size
    vocab_start_index = rank * vocab_part_size
    vocab_end_index = (rank + 1) * vocab_part_size

    # Fixed seed so every rank builds the same full logits tensor.
    torch.manual_seed(1337)
    full_logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda")

    # Baseline: single-GPU sum_v p_v * log(p_v). NOTE(review): this is the
    # NEGATIVE entropy (-H); it appears to match ChunkedDistributedEntropy's
    # sign convention — confirm against model_utils.
    baseline_logits = full_logits.clone().detach().requires_grad_(not inference_only)
    baseline_log_probs = torch.nn.functional.log_softmax(baseline_logits, dim=-1)
    baseline_probs = baseline_log_probs.exp()
    baseline_entropy = (baseline_probs * baseline_log_probs).sum(dim=-1)

    if not inference_only:
        baseline_entropy.sum().backward()
        baseline_grad = baseline_logits.grad[
            :, :, vocab_start_index:vocab_end_index
        ].clone()

    # Distributed path: operates on this rank's vocab shard only.
    local_logits = full_logits[:, :, vocab_start_index:vocab_end_index]
    local_logits = local_logits.clone().detach().requires_grad_(not inference_only)

    distributed_entropy = ChunkedDistributedEntropy.apply(
        local_logits,
        chunk_size,
        tp_group,
        inference_only,
    )

    torch.testing.assert_close(
        distributed_entropy, baseline_entropy, rtol=1e-4, atol=1e-4
    )

    if not inference_only:
        distributed_entropy.sum().backward()
        torch.testing.assert_close(
            local_logits.grad, baseline_grad, rtol=1e-4, atol=1e-4
        )
376+
377+
378+
@pytest.mark.parametrize(
    "tp_size, chunk_size, inference_only",
    [
        (1, 5, False),
        (2, 4, False),
        (1, 3, True),
    ],
)
def test_chunked_distributed_entropy(
    distributed_test_runner, tp_size, chunk_size, inference_only
):
    """Run the chunked distributed-entropy check across TP sizes and modes."""
    distributed_test_runner(
        functools.partial(
            _run_chunked_distributed_entropy,
            tp_size=tp_size,
            chunk_size=chunk_size,
            inference_only=inference_only,
        ),
        world_size=tp_size,
    )

0 commit comments

Comments
 (0)