Skip to content

Commit 4152c0b

Browse files
committed
fix dist log prob test
1 parent 99ba48f commit 4152c0b

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

tests/test_shardformer/test_layer/test_dist_log_prob.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import pytest
22
import torch
3-
from coati.distributed.utils import log_probs_from_logits
43

54
import colossalai
65
from colossalai.logging import disable_existing_loggers
@@ -12,6 +11,22 @@
1211
)
1312

1413

14+
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
15+
"""
16+
Compute the log probabilities from logits for the given labels.
17+
18+
Args:
19+
logits (torch.Tensor): The input logits.
20+
labels (torch.Tensor): The target labels.
21+
22+
Returns:
23+
torch.Tensor: The log probabilities corresponding to the labels.
24+
"""
25+
log_probs = torch.log_softmax(logits, dim=-1)
26+
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
27+
return per_label_logps.squeeze(-1)
28+
29+
1530
def check_dist_log_prob(rank, world_size, port):
1631
disable_existing_loggers()
1732
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")

0 commit comments

Comments
 (0)