Commit 44d02d5

test: move parallel logprobs test to test_ops.py reflecting folder structure
1 parent cbd41f9 · commit 44d02d5

File tree

3 files changed: 341 additions, 357 deletions


src/forge/util/ops.py

Lines changed: 1 addition & 3 deletions
@@ -7,8 +7,8 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
-
 from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.placement_types import Shard
 
 
 def compute_logprobs(
@@ -169,8 +169,6 @@ def get_vocab_shard_info(
         Tuple of (tp_group, tp_rank, tp_size, vocab_start, vocab_end).
         If not sharded, returns (None, 0, 1, 0, vocab_size).
     """
-    from torch.distributed.tensor.placement_types import Shard
-
     local_logits = logits._local_tensor
     placements = logits.placements
     device_mesh = logits.device_mesh
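For context on what `get_vocab_shard_info` is expected to return (per the docstring kept above), here is a minimal, hypothetical sketch of how such a helper could read the vocab-shard geometry off a DTensor. It is not the repository's implementation; it only assumes the logits are sharded on their last dimension with an evenly divisible vocab, as in the tests below, and the name `vocab_shard_info_sketch` is invented for illustration.

# Hypothetical sketch, not forge's implementation: recover
# (tp_group, tp_rank, tp_size, vocab_start, vocab_end) from a DTensor,
# matching the contract described in the docstring above. Assumes the
# vocab divides evenly across ranks, as in the tests in this commit.
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Shard


def vocab_shard_info_sketch(logits: DTensor):
    local_vocab = logits._local_tensor.shape[-1]
    for mesh_dim, placement in enumerate(logits.placements):
        if isinstance(placement, Shard) and placement.dim == logits.ndim - 1:
            # Sharded on the vocab (last) dimension along this mesh dimension.
            tp_group = logits.device_mesh.get_group(mesh_dim)
            tp_rank = logits.device_mesh.get_local_rank(mesh_dim)
            tp_size = logits.device_mesh.size(mesh_dim)
            vocab_start = tp_rank * local_vocab
            return tp_group, tp_rank, tp_size, vocab_start, vocab_start + local_vocab
    # Not sharded on the vocab dimension: behave like the single-rank case.
    return None, 0, 1, 0, local_vocab

Hoisting the Shard import to module scope, as this commit does, makes that placement check available without a function-local import.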

tests/unit_tests/util/test_ops.py

Lines changed: 340 additions & 1 deletion
@@ -6,8 +6,19 @@
 
 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
-from forge.util.ops import compute_logprobs
+
+from forge.util.ops import (
+    compute_logprobs,
+    compute_logprobs_parallel,
+    get_vocab_shard_info,
+)
+
+from tests.test_utils import gpu_test
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import DTensor, Shard
+from torch.testing._internal.common_fsdp import FSDPTest
 
 
 def _textbook_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor):
@@ -162,3 +173,331 @@ def test_align_comparison(self):
 
         # Both should give the same result
         assert torch.allclose(result_aligned, result_manual, atol=1e-5)
+
+
+class TestParallelLogprobs(FSDPTest):
+    """Test parallel logprobs against reference implementation."""
+
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @gpu_test(gpu_count=2)
+    def test_parallel_logprobs_matches_sequential(self):
+        """Verify parallel logprobs produces same results as sequential version."""
+        torch.manual_seed(42)
+
+        batch_size = 4
+        seq_len = 16
+        vocab_size = 1000  # Must be divisible by world_size
+        target_len = 8
+
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        # Create test data on rank 0 and broadcast to ensure consistency
+        if rank == 0:
+            # Full logits tensor (what we'd have without sharding)
+            full_logits = torch.randn(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            # Target tokens for logprob computation
+            target_ids = torch.randint(
+                0, vocab_size, (batch_size, target_len), device=device
+            )
+        else:
+            full_logits = torch.empty(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.empty(
+                batch_size, target_len, dtype=torch.int64, device=device
+            )
+
+        # Broadcast to all ranks
+        dist.broadcast(full_logits, src=0)
+        dist.broadcast(target_ids, src=0)
+
+        # Compute reference result using sequential version
+        expected = compute_logprobs(full_logits, target_ids, align=True)
+
+        # Create device mesh for tensor parallel
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+
+        # Create DTensor sharded on vocab dimension (dim=2)
+        # Each rank gets vocab_size // world_size columns
+        dtensor_logits = DTensor.from_local(
+            full_logits[
+                :, :, rank * (vocab_size // 2) : (rank + 1) * (vocab_size // 2)
+            ],
+            mesh,
+            placements=[Shard(2)],  # Shard on vocab dimension
+        )
+
+        # Compute parallel result
+        result = compute_logprobs_parallel(dtensor_logits, target_ids, align=True)
+
+        # Verify results match
+        torch.testing.assert_close(
+            result,
+            expected,
+            atol=1e-5,
+            rtol=1e-5,
+            msg="Parallel logprobs should match sequential version",
+        )
+
+    @gpu_test(gpu_count=2)
+    def test_parallel_logprobs_with_temperature(self):
+        """Test parallel logprobs with different temperature values."""
+        torch.manual_seed(123)
+
+        batch_size = 2
+        seq_len = 10
+        vocab_size = 500
+        target_len = 5
+
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        if rank == 0:
+            full_logits = torch.randn(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.randint(
+                0, vocab_size, (batch_size, target_len), device=device
+            )
+        else:
+            full_logits = torch.empty(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.empty(
+                batch_size, target_len, dtype=torch.int64, device=device
+            )
+
+        dist.broadcast(full_logits, src=0)
+        dist.broadcast(target_ids, src=0)
+
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+        local_vocab = vocab_size // self.world_size
+        dtensor_logits = DTensor.from_local(
+            full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
+            mesh,
+            placements=[Shard(2)],
+        )
+
+        for temperature in [0.5, 1.0, 2.0]:
+            expected = compute_logprobs(
+                full_logits, target_ids, temperature=temperature, align=True
+            )
+            result = compute_logprobs_parallel(
+                dtensor_logits, target_ids, temperature=temperature, align=True
+            )
+            torch.testing.assert_close(
+                result,
+                expected,
+                atol=1e-5,
+                rtol=1e-5,
+                msg=f"Failed with temperature={temperature}",
+            )
+
+    @gpu_test(gpu_count=2)
+    def test_parallel_logprobs_align_false(self):
+        """Test parallel logprobs with align=False."""
+        torch.manual_seed(456)
+
+        batch_size = 3
+        seq_len = 8
+        vocab_size = 200
+
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        if rank == 0:
+            full_logits = torch.randn(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            # With align=False, target_ids same length as seq_len
+            target_ids = torch.randint(
+                0, vocab_size, (batch_size, seq_len), device=device
+            )
+        else:
+            full_logits = torch.empty(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.empty(
+                batch_size, seq_len, dtype=torch.int64, device=device
+            )
+
+        dist.broadcast(full_logits, src=0)
+        dist.broadcast(target_ids, src=0)
+
+        expected = compute_logprobs(full_logits, target_ids, align=False)
+
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+        local_vocab = vocab_size // self.world_size
+        dtensor_logits = DTensor.from_local(
+            full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
+            mesh,
+            placements=[Shard(2)],
+        )
+
+        result = compute_logprobs_parallel(dtensor_logits, target_ids, align=False)
+
+        torch.testing.assert_close(
+            result,
+            expected,
+            atol=1e-5,
+            rtol=1e-5,
+            msg="Parallel logprobs with align=False should match",
+        )
+
+    @gpu_test(gpu_count=2)
+    def test_parallel_logprobs_numerical_stability(self):
+        """Test parallel logprobs handles extreme values correctly."""
+        torch.manual_seed(789)
+
+        batch_size = 2
+        seq_len = 4
+        vocab_size = 100
+        target_len = 2
+
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        # Test with large values
+        if rank == 0:
+            full_logits = torch.randn(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            # Add some extreme values
+            full_logits[:, :, 0] = 1000.0
+            full_logits[:, :, 50] = -1000.0
+            target_ids = torch.randint(
+                0, vocab_size, (batch_size, target_len), device=device
+            )
+        else:
+            full_logits = torch.empty(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.empty(
+                batch_size, target_len, dtype=torch.int64, device=device
+            )
+
+        dist.broadcast(full_logits, src=0)
+        dist.broadcast(target_ids, src=0)
+
+        expected = compute_logprobs(full_logits, target_ids, align=True)
+
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+        local_vocab = vocab_size // self.world_size
+        dtensor_logits = DTensor.from_local(
+            full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
+            mesh,
+            placements=[Shard(2)],
+        )
+
+        result = compute_logprobs_parallel(dtensor_logits, target_ids, align=True)
+
+        # Should not have NaN or Inf
+        assert torch.isfinite(result).all(), "Result contains NaN or Inf"
+        assert torch.isfinite(expected).all(), "Expected contains NaN or Inf"
+
+        torch.testing.assert_close(
+            result,
+            expected,
+            atol=1e-4,  # Slightly relaxed for extreme values
+            rtol=1e-4,
+            msg="Parallel logprobs should be numerically stable",
+        )
+
+    @gpu_test(gpu_count=2)
+    def test_get_vocab_shard_info(self):
+        """Test vocab shard info extraction."""
+        torch.manual_seed(111)
+
+        batch_size = 2
+        seq_len = 4
+        vocab_size = 100
+
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        full_logits = torch.randn(
+            batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+        )
+
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+        local_vocab = vocab_size // self.world_size
+        dtensor_logits = DTensor.from_local(
+            full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
+            mesh,
+            placements=[Shard(2)],
+        )
+
+        tp_group, tp_rank, tp_size, vocab_start, vocab_end = get_vocab_shard_info(
+            dtensor_logits
+        )
+
+        assert tp_group is not None, "Should have TP group for sharded tensor"
+        assert tp_rank == rank, f"TP rank should be {rank}, got {tp_rank}"
+        assert tp_size == self.world_size, f"TP size should be {self.world_size}"
+        assert vocab_start == rank * local_vocab, "Vocab start incorrect"
+        assert vocab_end == (rank + 1) * local_vocab, "Vocab end incorrect"
+
+
+class TestParallelLogprobs4GPU(FSDPTest):
+    """Test parallel logprobs with 4 GPUs."""
+
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    @gpu_test(gpu_count=4)
+    def test_parallel_logprobs_4_way_sharding(self):
+        """Test with 4-way vocab sharding."""
+        torch.manual_seed(999)
+
+        batch_size = 8
+        seq_len = 32
+        vocab_size = 1000  # Divisible by 4
+        target_len = 16
+
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        if rank == 0:
+            full_logits = torch.randn(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.randint(
+                0, vocab_size, (batch_size, target_len), device=device
+            )
+        else:
+            full_logits = torch.empty(
+                batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
+            )
+            target_ids = torch.empty(
+                batch_size, target_len, dtype=torch.int64, device=device
+            )
+
+        dist.broadcast(full_logits, src=0)
+        dist.broadcast(target_ids, src=0)
+
+        expected = compute_logprobs(full_logits, target_ids, align=True)
+
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+        local_vocab = vocab_size // self.world_size
+        dtensor_logits = DTensor.from_local(
+            full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
+            mesh,
+            placements=[Shard(2)],
+        )
+
+        result = compute_logprobs_parallel(dtensor_logits, target_ids, align=True)
+
+        torch.testing.assert_close(
+            result,
+            expected,
+            atol=1e-5,
+            rtol=1e-5,
+            msg="4-way parallel logprobs should match sequential",
+        )
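For readers unfamiliar with what these tests exercise, the following is a rough, self-contained sketch of the underlying math: a plain log-softmax gather for the sequential reference, and a vocab-parallel version that rebuilds the global log-sum-exp with two all-reduces and recovers each target's logit from the rank that owns it. This is not forge's compute_logprobs or compute_logprobs_parallel (it ignores the align/shift handling and assumes target_ids already line up with the sequence dimension and an even vocab split); the function names are invented for illustration.

# Illustrative sketch only, assuming targets already aligned with the
# sequence dimension and an even vocab split across ranks.
import torch
import torch.distributed as dist


def sequential_logprobs_sketch(logits, target_ids, temperature=1.0):
    # log p(target) = logit[target]/T - logsumexp(logits/T) over the vocab.
    logprobs = torch.log_softmax(logits / temperature, dim=-1)
    return logprobs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)


def vocab_parallel_logprobs_sketch(
    local_logits, target_ids, vocab_start, vocab_end, tp_group, temperature=1.0
):
    # Each rank holds logits[..., vocab_start:vocab_end]. Two all-reduces
    # rebuild the global logsumexp without gathering the full vocab.
    local_logits = local_logits / temperature

    # Global max for numerical stability (why the extreme-value test passes).
    global_max = local_logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)

    # Global sum of exp(logit - max).
    sum_exp = (local_logits - global_max).exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
    logsumexp = (global_max + sum_exp.log()).squeeze(-1)

    # Only the owning rank has each target's logit; zero elsewhere, then sum.
    in_shard = (target_ids >= vocab_start) & (target_ids < vocab_end)
    local_ids = (target_ids - vocab_start).clamp(0, local_logits.shape[-1] - 1)
    target_logit = local_logits.gather(-1, local_ids.unsqueeze(-1)).squeeze(-1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=tp_group)

    return target_logit - logsumexp

The tests above check exactly these properties against the full-tensor reference: agreement under 2-way and 4-way vocab sharding, temperature scaling, align=True/False handling, and stability with logits pinned at +/-1000.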
