Skip to content

Commit 54e1283

Browse files
committed
update PreferenceLoss and DistillationLossFn
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent 4c7c2ea commit 54e1283

File tree

4 files changed

+275
-179
lines changed

4 files changed

+275
-179
lines changed

nemo_rl/algorithms/loss_functions.py

Lines changed: 4 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@
2020
from nemo_rl.algorithms.interfaces import LossFunction, LossType
2121
from nemo_rl.algorithms.utils import calculate_kl, masked_mean
2222
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
23-
from nemo_rl.distributed.model_utils import (
24-
ChunkedDistributedEntropy,
25-
ChunkedDistributedGatherLogprob,
26-
_get_tokens_on_this_cp_rank,
27-
allgather_cp_sharded_tensor,
28-
gather_logits_at_global_indices,
29-
)
3023

3124
Tensor = TypeVar("Tensor", bound=torch.Tensor)
3225

@@ -932,165 +925,14 @@ def __init__(self, cfg: DistillationLossConfig):
932925

933926
def __call__(
934927
self,
935-
next_token_logits: torch.Tensor,
928+
student_topk_logprobs: torch.Tensor,
929+
teacher_topk_logprobs: torch.Tensor,
930+
H_all: torch.Tensor | None,
936931
data: DistillationLossDataDict,
937932
global_valid_seqs: torch.Tensor,
938933
global_valid_toks: torch.Tensor,
939-
vocab_parallel_rank: Optional[int] = None,
940-
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
941-
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
942934
) -> tuple[torch.Tensor, dict[str, Any]]:
943935
"""Compute distillation loss between teacher and student logits."""
944-
# Basic shapes
945-
input_ids = data["input_ids"]
946-
batch_size = input_ids.shape[0]
947-
948-
# CP support: get CP group and size
949-
cp_group = context_parallel_group
950-
cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group)
951-
952-
# Ensure float32 for stability (match other losses)
953-
next_token_logits = next_token_logits.to(torch.float32)
954-
per_token_kl = None
955-
# Preferred truncated-KL path: teacher provides top-k support per position
956-
teacher_topk_logits = data["teacher_topk_logits"] # [B, S, k]
957-
teacher_topk_indices = data["teacher_topk_indices"] # [B, S, k]
958-
959-
if teacher_topk_indices.shape[-1] <= 0:
960-
raise ValueError(
961-
f"topk must be positive, got {teacher_topk_indices.shape[-1]}. "
962-
"topk=0 is not supported as it would result in empty tensor operations."
963-
)
964-
965-
# Determine processing path and setup variables
966-
if vocab_parallel_group is not None:
967-
assert vocab_parallel_rank is not None, (
968-
"vocab_parallel_rank must be provided when vocab_parallel_group is provided"
969-
)
970-
V_local = int(next_token_logits.shape[-1])
971-
vocab_start_index = vocab_parallel_rank * V_local
972-
vocab_end_index = (vocab_parallel_rank + 1) * V_local
973-
parallel_group = vocab_parallel_group
974-
logits_tensor = next_token_logits
975-
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
976-
device_mesh = next_token_logits.device_mesh
977-
tp_group = device_mesh.get_group("tp")
978-
tp_rank = tp_group.rank()
979-
local_student_logits = next_token_logits.to_local()
980-
V_local = int(local_student_logits.shape[-1])
981-
vocab_start_index = tp_rank * V_local
982-
vocab_end_index = (tp_rank + 1) * V_local
983-
parallel_group = tp_group
984-
logits_tensor = local_student_logits
985-
teacher_topk_indices = teacher_topk_indices.to(local_student_logits.device)
986-
# For DTensor, derive CP group/size from the device mesh to ensure CP-aware alignment
987-
if (
988-
device_mesh.mesh_dim_names is not None
989-
and "cp" in device_mesh.mesh_dim_names
990-
):
991-
cp_group = device_mesh.get_group("cp")
992-
cp_size = cp_group.size()
993-
else:
994-
cp_group = None
995-
cp_size = 1
996-
else:
997-
parallel_group = None
998-
logits_tensor = next_token_logits
999-
1000-
# Process based on zero_outside_topk setting
1001-
if self.zero_outside_topk and parallel_group is not None:
1002-
# Distributed processing with chunking
1003-
indices_local = teacher_topk_indices
1004-
pad_len = 0
1005-
if cp_size > 1:
1006-
pad_len = logits_tensor.shape[1] * cp_size - indices_local.shape[1]
1007-
if pad_len > 0:
1008-
indices_local = torch.nn.functional.pad(
1009-
indices_local, (0, 0, 0, pad_len), value=0
1010-
)
1011-
cp_rank = torch.distributed.get_rank(cp_group)
1012-
indices_local = _get_tokens_on_this_cp_rank(
1013-
indices_local, cp_rank, cp_size, seq_dim=1
1014-
)
1015-
1016-
S_local = int(logits_tensor.shape[1])
1017-
chunk_size = max(1, min(S_local, 1024))
1018-
student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( # type: ignore
1019-
logits_tensor,
1020-
indices_local,
1021-
vocab_start_index,
1022-
vocab_end_index,
1023-
chunk_size,
1024-
parallel_group,
1025-
False,
1026-
)
1027-
1028-
if self.kl_type != "forward":
1029-
H_all = ChunkedDistributedEntropy.apply( # type: ignore
1030-
logits_tensor,
1031-
chunk_size,
1032-
parallel_group,
1033-
False,
1034-
)
1035-
1036-
if cp_size > 1:
1037-
student_topk_logprobs = allgather_cp_sharded_tensor(
1038-
student_topk_logprobs, cp_group, seq_dim=1
1039-
)
1040-
if self.kl_type != "forward":
1041-
H_all = allgather_cp_sharded_tensor(H_all, cp_group, seq_dim=1)
1042-
if pad_len > 0:
1043-
student_topk_logprobs = student_topk_logprobs[:, :-pad_len, :]
1044-
if self.kl_type != "forward":
1045-
H_all = H_all[:, :-pad_len]
1046-
elif self.zero_outside_topk:
1047-
# Non-distributed processing
1048-
student_logprobs = torch.nn.functional.log_softmax(logits_tensor, dim=-1)
1049-
student_topk_logprobs = student_logprobs.gather(
1050-
dim=-1, index=teacher_topk_indices.to(student_logprobs.device)
1051-
)
1052-
if self.kl_type != "forward":
1053-
H_all = (student_logprobs.exp() * student_logprobs).sum(-1)
1054-
else:
1055-
# Gather logits at global indices
1056-
if (parallel_group is not None) or (cp_size > 1):
1057-
student_topk_logits = gather_logits_at_global_indices(
1058-
logits_tensor,
1059-
teacher_topk_indices,
1060-
tp_group=parallel_group,
1061-
cp_group=cp_group,
1062-
vocab_start_index=(
1063-
vocab_start_index if parallel_group is not None else 0
1064-
),
1065-
vocab_end_index=(
1066-
vocab_end_index
1067-
if parallel_group is not None
1068-
else int(logits_tensor.shape[-1])
1069-
),
1070-
)
1071-
else:
1072-
student_topk_logits = logits_tensor.gather(
1073-
dim=-1, index=teacher_topk_indices.to(logits_tensor.device)
1074-
)
1075-
student_topk_logprobs = torch.nn.functional.log_softmax(
1076-
student_topk_logits, dim=-1
1077-
)
1078-
1079-
# Move teacher tensors to the same device/dtype as student_topk_logits
1080-
teacher_topk_logits = teacher_topk_logits.to(
1081-
student_topk_logprobs.device, dtype=student_topk_logprobs.dtype
1082-
)
1083-
teacher_topk_logprobs = torch.nn.functional.log_softmax(
1084-
teacher_topk_logits, dim=-1
1085-
)
1086-
1087-
# Single point of next-token alignment after TP/CP processing
1088-
teacher_topk_logprobs = teacher_topk_logprobs[:, :-1, :]
1089-
student_topk_logprobs = student_topk_logprobs[:, :-1, :]
1090-
if self.zero_outside_topk and self.kl_type != "forward":
1091-
# Align H_all with next-token prediction
1092-
H_all = H_all[:, :-1]
1093-
1094936
student_probs = student_topk_logprobs.exp() # [B, S-1, k]
1095937
teacher_probs = teacher_topk_logprobs.exp() # [B, S-1, k]
1096938

@@ -1143,7 +985,7 @@ def __call__(
1143985

1144986
metrics = {
1145987
"loss": float(kl_loss.item()) if kl_loss.ndim == 0 else kl_loss,
1146-
"num_valid_samples": int(batch_size),
988+
"num_valid_samples": data["input_ids"].shape[0],
1147989
}
1148990

1149991
return kl_loss, metrics

nemo_rl/distributed/model_utils.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,173 @@ def gather_logits_at_global_indices(
10261026
return gathered_logits
10271027

10281028

1029+
def get_distilllation_topk_logprobs_from_logits(
1030+
student_logits: torch.Tensor,
1031+
teacher_topk_logits: torch.Tensor,
1032+
teacher_topk_indices: torch.Tensor,
1033+
zero_outside_topk: bool,
1034+
calculate_entropy: bool,
1035+
vocab_parallel_rank: Optional[int] = None,
1036+
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
1037+
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
1038+
):
1039+
"""Compute top-k log probabilities from logits."""
1040+
if teacher_topk_indices.shape[-1] <= 0:
1041+
raise ValueError(
1042+
f"topk must be positive, got {teacher_topk_indices.shape[-1]}. "
1043+
"topk=0 is not supported as it would result in empty tensor operations."
1044+
)
1045+
1046+
# Ensure float32 for stability
1047+
student_logits = student_logits.to(torch.float32)
1048+
# Move teacher topk indices to the same device as student logits
1049+
teacher_topk_indices = teacher_topk_indices.to(student_logits.device)
1050+
1051+
# CP support: get CP group and size
1052+
cp_group = context_parallel_group
1053+
cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group)
1054+
1055+
# Process based on the student logits type
1056+
if vocab_parallel_group is not None:
1057+
assert vocab_parallel_rank is not None, (
1058+
"vocab_parallel_rank must be provided when vocab_parallel_group is provided"
1059+
)
1060+
student_logits = student_logits
1061+
parallel_group = vocab_parallel_group
1062+
1063+
V_local = int(student_logits.shape[-1])
1064+
vocab_start_index = vocab_parallel_rank * V_local
1065+
vocab_end_index = (vocab_parallel_rank + 1) * V_local
1066+
1067+
elif isinstance(student_logits, torch.distributed.tensor.DTensor):
1068+
device_mesh = student_logits.device_mesh
1069+
tp_group = device_mesh.get_group("tp")
1070+
1071+
student_logits = student_logits.to_local()
1072+
parallel_group = tp_group
1073+
1074+
tp_rank = tp_group.rank()
1075+
V_local = int(student_logits.shape[-1])
1076+
vocab_start_index = tp_rank * V_local
1077+
vocab_end_index = (tp_rank + 1) * V_local
1078+
1079+
# For DTensor, derive CP group/size from the device mesh to ensure CP-aware alignment
1080+
if (
1081+
device_mesh.mesh_dim_names is not None
1082+
and "cp" in device_mesh.mesh_dim_names
1083+
):
1084+
cp_group = device_mesh.get_group("cp")
1085+
cp_size = cp_group.size()
1086+
else:
1087+
cp_group = None
1088+
cp_size = 1
1089+
1090+
else:
1091+
student_logits = student_logits
1092+
parallel_group = None
1093+
1094+
# Process based on the zero_outside_topk setting
1095+
H_all = None
1096+
if zero_outside_topk:
1097+
# Distributed processing
1098+
if parallel_group is not None:
1099+
indices_local = teacher_topk_indices
1100+
pad_len = 0
1101+
1102+
if cp_size > 1:
1103+
pad_len = student_logits.shape[1] * cp_size - indices_local.shape[1]
1104+
if pad_len > 0:
1105+
indices_local = torch.nn.functional.pad(
1106+
indices_local, (0, 0, 0, pad_len), value=0
1107+
)
1108+
cp_rank = torch.distributed.get_rank(cp_group)
1109+
indices_local = _get_tokens_on_this_cp_rank(
1110+
indices_local, cp_rank, cp_size, seq_dim=1
1111+
)
1112+
1113+
seq_len_local = int(student_logits.shape[1])
1114+
chunk_size = max(1, min(seq_len_local, 1024))
1115+
student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( # type: ignore
1116+
student_logits,
1117+
indices_local,
1118+
vocab_start_index,
1119+
vocab_end_index,
1120+
chunk_size,
1121+
parallel_group,
1122+
False,
1123+
)
1124+
1125+
if calculate_entropy:
1126+
H_all = ChunkedDistributedEntropy.apply( # type: ignore
1127+
student_logits,
1128+
chunk_size,
1129+
parallel_group,
1130+
False,
1131+
)
1132+
1133+
if cp_size > 1:
1134+
student_topk_logprobs = allgather_cp_sharded_tensor(
1135+
student_topk_logprobs, cp_group, seq_dim=1
1136+
)
1137+
if calculate_entropy:
1138+
H_all = allgather_cp_sharded_tensor(H_all, cp_group, seq_dim=1)
1139+
if pad_len > 0:
1140+
student_topk_logprobs = student_topk_logprobs[:, :-pad_len, :]
1141+
if calculate_entropy:
1142+
H_all = H_all[:, :-pad_len]
1143+
1144+
# Non-distributed processing
1145+
else:
1146+
student_logprobs = torch.nn.functional.log_softmax(student_logits, dim=-1)
1147+
student_topk_logprobs = student_logprobs.gather(
1148+
dim=-1, index=teacher_topk_indices
1149+
)
1150+
1151+
if calculate_entropy:
1152+
H_all = (student_logprobs.exp() * student_logprobs).sum(-1)
1153+
1154+
else:
1155+
# Distributed processing
1156+
if parallel_group is not None or cp_size > 1:
1157+
if parallel_group is None:
1158+
vocab_start_index = 0
1159+
vocab_end_index = int(student_logits.shape[-1])
1160+
1161+
student_topk_logits = gather_logits_at_global_indices(
1162+
student_logits,
1163+
teacher_topk_indices,
1164+
tp_group=parallel_group,
1165+
cp_group=cp_group,
1166+
vocab_start_index=vocab_start_index,
1167+
vocab_end_index=vocab_end_index,
1168+
)
1169+
1170+
# Non-distributed processing
1171+
else:
1172+
student_topk_logits = student_logits.gather(
1173+
dim=-1, index=teacher_topk_indices
1174+
)
1175+
1176+
student_topk_logprobs = torch.nn.functional.log_softmax(
1177+
student_topk_logits, dim=-1
1178+
)
1179+
1180+
# Move teacher tensors to the same device/dtype as student_topk_logits
1181+
teacher_topk_logits = teacher_topk_logits.to(
1182+
student_topk_logprobs.device, dtype=student_topk_logprobs.dtype
1183+
)
1184+
teacher_topk_logprobs = torch.nn.functional.log_softmax(teacher_topk_logits, dim=-1)
1185+
1186+
# Single point of next-token alignment after TP/CP processing
1187+
teacher_topk_logprobs = teacher_topk_logprobs[:, :-1, :]
1188+
student_topk_logprobs = student_topk_logprobs[:, :-1, :]
1189+
1190+
if calculate_entropy:
1191+
H_all = H_all[:, :-1]
1192+
1193+
return student_topk_logprobs, teacher_topk_logprobs, H_all
1194+
1195+
10291196
class ChunkedDistributedEntropy(torch.autograd.Function):
10301197
"""Compute H_all = sum_v p_v log p_v across TP with chunking over sequence.
10311198

0 commit comments

Comments
 (0)