Skip to content

Commit f203c94

Browse files
committed
update PreferenceLoss and DistillationLossFn
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent de4a64b commit f203c94

File tree

4 files changed

+275
-179
lines changed

4 files changed

+275
-179
lines changed

nemo_rl/algorithms/loss_functions.py

Lines changed: 4 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@
2020
from nemo_rl.algorithms.interfaces import LossFunction, LossType
2121
from nemo_rl.algorithms.utils import calculate_kl, masked_mean
2222
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
23-
from nemo_rl.distributed.model_utils import (
24-
ChunkedDistributedEntropy,
25-
ChunkedDistributedGatherLogprob,
26-
_get_tokens_on_this_cp_rank,
27-
allgather_cp_sharded_tensor,
28-
gather_logits_at_global_indices,
29-
)
3023

3124
Tensor = TypeVar("Tensor", bound=torch.Tensor)
3225

@@ -999,165 +992,14 @@ def __init__(self, cfg: DistillationLossConfig):
999992

1000993
def __call__(
1001994
self,
1002-
next_token_logits: torch.Tensor,
995+
student_topk_logprobs: torch.Tensor,
996+
teacher_topk_logprobs: torch.Tensor,
997+
H_all: torch.Tensor | None,
1003998
data: DistillationLossDataDict,
1004999
global_valid_seqs: torch.Tensor,
10051000
global_valid_toks: torch.Tensor,
1006-
vocab_parallel_rank: Optional[int] = None,
1007-
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
1008-
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
10091001
) -> tuple[torch.Tensor, dict[str, Any]]:
10101002
"""Compute distillation loss between teacher and student logits."""
1011-
# Basic shapes
1012-
input_ids = data["input_ids"]
1013-
batch_size = input_ids.shape[0]
1014-
1015-
# CP support: get CP group and size
1016-
cp_group = context_parallel_group
1017-
cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group)
1018-
1019-
# Ensure float32 for stability (match other losses)
1020-
next_token_logits = next_token_logits.to(torch.float32)
1021-
per_token_kl = None
1022-
# Preferred truncated-KL path: teacher provides top-k support per position
1023-
teacher_topk_logits = data["teacher_topk_logits"] # [B, S, k]
1024-
teacher_topk_indices = data["teacher_topk_indices"] # [B, S, k]
1025-
1026-
if teacher_topk_indices.shape[-1] <= 0:
1027-
raise ValueError(
1028-
f"topk must be positive, got {teacher_topk_indices.shape[-1]}. "
1029-
"topk=0 is not supported as it would result in empty tensor operations."
1030-
)
1031-
1032-
# Determine processing path and setup variables
1033-
if vocab_parallel_group is not None:
1034-
assert vocab_parallel_rank is not None, (
1035-
"vocab_parallel_rank must be provided when vocab_parallel_group is provided"
1036-
)
1037-
V_local = int(next_token_logits.shape[-1])
1038-
vocab_start_index = vocab_parallel_rank * V_local
1039-
vocab_end_index = (vocab_parallel_rank + 1) * V_local
1040-
parallel_group = vocab_parallel_group
1041-
logits_tensor = next_token_logits
1042-
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
1043-
device_mesh = next_token_logits.device_mesh
1044-
tp_group = device_mesh.get_group("tp")
1045-
tp_rank = tp_group.rank()
1046-
local_student_logits = next_token_logits.to_local()
1047-
V_local = int(local_student_logits.shape[-1])
1048-
vocab_start_index = tp_rank * V_local
1049-
vocab_end_index = (tp_rank + 1) * V_local
1050-
parallel_group = tp_group
1051-
logits_tensor = local_student_logits
1052-
teacher_topk_indices = teacher_topk_indices.to(local_student_logits.device)
1053-
# For DTensor, derive CP group/size from the device mesh to ensure CP-aware alignment
1054-
if (
1055-
device_mesh.mesh_dim_names is not None
1056-
and "cp" in device_mesh.mesh_dim_names
1057-
):
1058-
cp_group = device_mesh.get_group("cp")
1059-
cp_size = cp_group.size()
1060-
else:
1061-
cp_group = None
1062-
cp_size = 1
1063-
else:
1064-
parallel_group = None
1065-
logits_tensor = next_token_logits
1066-
1067-
# Process based on zero_outside_topk setting
1068-
if self.zero_outside_topk and parallel_group is not None:
1069-
# Distributed processing with chunking
1070-
indices_local = teacher_topk_indices
1071-
pad_len = 0
1072-
if cp_size > 1:
1073-
pad_len = logits_tensor.shape[1] * cp_size - indices_local.shape[1]
1074-
if pad_len > 0:
1075-
indices_local = torch.nn.functional.pad(
1076-
indices_local, (0, 0, 0, pad_len), value=0
1077-
)
1078-
cp_rank = torch.distributed.get_rank(cp_group)
1079-
indices_local = _get_tokens_on_this_cp_rank(
1080-
indices_local, cp_rank, cp_size, seq_dim=1
1081-
)
1082-
1083-
S_local = int(logits_tensor.shape[1])
1084-
chunk_size = max(1, min(S_local, 1024))
1085-
student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( # type: ignore
1086-
logits_tensor,
1087-
indices_local,
1088-
vocab_start_index,
1089-
vocab_end_index,
1090-
chunk_size,
1091-
parallel_group,
1092-
False,
1093-
)
1094-
1095-
if self.kl_type != "forward":
1096-
H_all = ChunkedDistributedEntropy.apply( # type: ignore
1097-
logits_tensor,
1098-
chunk_size,
1099-
parallel_group,
1100-
False,
1101-
)
1102-
1103-
if cp_size > 1:
1104-
student_topk_logprobs = allgather_cp_sharded_tensor(
1105-
student_topk_logprobs, cp_group, seq_dim=1
1106-
)
1107-
if self.kl_type != "forward":
1108-
H_all = allgather_cp_sharded_tensor(H_all, cp_group, seq_dim=1)
1109-
if pad_len > 0:
1110-
student_topk_logprobs = student_topk_logprobs[:, :-pad_len, :]
1111-
if self.kl_type != "forward":
1112-
H_all = H_all[:, :-pad_len]
1113-
elif self.zero_outside_topk:
1114-
# Non-distributed processing
1115-
student_logprobs = torch.nn.functional.log_softmax(logits_tensor, dim=-1)
1116-
student_topk_logprobs = student_logprobs.gather(
1117-
dim=-1, index=teacher_topk_indices.to(student_logprobs.device)
1118-
)
1119-
if self.kl_type != "forward":
1120-
H_all = (student_logprobs.exp() * student_logprobs).sum(-1)
1121-
else:
1122-
# Gather logits at global indices
1123-
if (parallel_group is not None) or (cp_size > 1):
1124-
student_topk_logits = gather_logits_at_global_indices(
1125-
logits_tensor,
1126-
teacher_topk_indices,
1127-
tp_group=parallel_group,
1128-
cp_group=cp_group,
1129-
vocab_start_index=(
1130-
vocab_start_index if parallel_group is not None else 0
1131-
),
1132-
vocab_end_index=(
1133-
vocab_end_index
1134-
if parallel_group is not None
1135-
else int(logits_tensor.shape[-1])
1136-
),
1137-
)
1138-
else:
1139-
student_topk_logits = logits_tensor.gather(
1140-
dim=-1, index=teacher_topk_indices.to(logits_tensor.device)
1141-
)
1142-
student_topk_logprobs = torch.nn.functional.log_softmax(
1143-
student_topk_logits, dim=-1
1144-
)
1145-
1146-
# Move teacher tensors to the same device/dtype as student_topk_logits
1147-
teacher_topk_logits = teacher_topk_logits.to(
1148-
student_topk_logprobs.device, dtype=student_topk_logprobs.dtype
1149-
)
1150-
teacher_topk_logprobs = torch.nn.functional.log_softmax(
1151-
teacher_topk_logits, dim=-1
1152-
)
1153-
1154-
# Single point of next-token alignment after TP/CP processing
1155-
teacher_topk_logprobs = teacher_topk_logprobs[:, :-1, :]
1156-
student_topk_logprobs = student_topk_logprobs[:, :-1, :]
1157-
if self.zero_outside_topk and self.kl_type != "forward":
1158-
# Align H_all with next-token prediction
1159-
H_all = H_all[:, :-1]
1160-
11611003
student_probs = student_topk_logprobs.exp() # [B, S-1, k]
11621004
teacher_probs = teacher_topk_logprobs.exp() # [B, S-1, k]
11631005

@@ -1210,7 +1052,7 @@ def __call__(
12101052

12111053
metrics = {
12121054
"loss": float(kl_loss.item()) if kl_loss.ndim == 0 else kl_loss,
1213-
"num_valid_samples": int(batch_size),
1055+
"num_valid_samples": data["input_ids"].shape[0],
12141056
}
12151057

12161058
return kl_loss, metrics

nemo_rl/distributed/model_utils.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,173 @@ def gather_logits_at_global_indices(
10261026
return gathered_logits
10271027

10281028

1029+
def get_distilllation_topk_logprobs_from_logits(
1030+
student_logits: torch.Tensor,
1031+
teacher_topk_logits: torch.Tensor,
1032+
teacher_topk_indices: torch.Tensor,
1033+
zero_outside_topk: bool,
1034+
calculate_entropy: bool,
1035+
vocab_parallel_rank: Optional[int] = None,
1036+
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
1037+
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
1038+
):
1039+
"""Compute top-k log probabilities from logits."""
1040+
if teacher_topk_indices.shape[-1] <= 0:
1041+
raise ValueError(
1042+
f"topk must be positive, got {teacher_topk_indices.shape[-1]}. "
1043+
"topk=0 is not supported as it would result in empty tensor operations."
1044+
)
1045+
1046+
# Ensure float32 for stability
1047+
student_logits = student_logits.to(torch.float32)
1048+
# Move teacher topk indices to the same device as student logits
1049+
teacher_topk_indices = teacher_topk_indices.to(student_logits.device)
1050+
1051+
# CP support: get CP group and size
1052+
cp_group = context_parallel_group
1053+
cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group)
1054+
1055+
# Process based on the student logits type
1056+
if vocab_parallel_group is not None:
1057+
assert vocab_parallel_rank is not None, (
1058+
"vocab_parallel_rank must be provided when vocab_parallel_group is provided"
1059+
)
1060+
student_logits = student_logits
1061+
parallel_group = vocab_parallel_group
1062+
1063+
V_local = int(student_logits.shape[-1])
1064+
vocab_start_index = vocab_parallel_rank * V_local
1065+
vocab_end_index = (vocab_parallel_rank + 1) * V_local
1066+
1067+
elif isinstance(student_logits, torch.distributed.tensor.DTensor):
1068+
device_mesh = student_logits.device_mesh
1069+
tp_group = device_mesh.get_group("tp")
1070+
1071+
student_logits = student_logits.to_local()
1072+
parallel_group = tp_group
1073+
1074+
tp_rank = tp_group.rank()
1075+
V_local = int(student_logits.shape[-1])
1076+
vocab_start_index = tp_rank * V_local
1077+
vocab_end_index = (tp_rank + 1) * V_local
1078+
1079+
# For DTensor, derive CP group/size from the device mesh to ensure CP-aware alignment
1080+
if (
1081+
device_mesh.mesh_dim_names is not None
1082+
and "cp" in device_mesh.mesh_dim_names
1083+
):
1084+
cp_group = device_mesh.get_group("cp")
1085+
cp_size = cp_group.size()
1086+
else:
1087+
cp_group = None
1088+
cp_size = 1
1089+
1090+
else:
1091+
student_logits = student_logits
1092+
parallel_group = None
1093+
1094+
# Process based on the zero_outside_topk setting
1095+
H_all = None
1096+
if zero_outside_topk:
1097+
# Distributed processing
1098+
if parallel_group is not None:
1099+
indices_local = teacher_topk_indices
1100+
pad_len = 0
1101+
1102+
if cp_size > 1:
1103+
pad_len = student_logits.shape[1] * cp_size - indices_local.shape[1]
1104+
if pad_len > 0:
1105+
indices_local = torch.nn.functional.pad(
1106+
indices_local, (0, 0, 0, pad_len), value=0
1107+
)
1108+
cp_rank = torch.distributed.get_rank(cp_group)
1109+
indices_local = _get_tokens_on_this_cp_rank(
1110+
indices_local, cp_rank, cp_size, seq_dim=1
1111+
)
1112+
1113+
seq_len_local = int(student_logits.shape[1])
1114+
chunk_size = max(1, min(seq_len_local, 1024))
1115+
student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( # type: ignore
1116+
student_logits,
1117+
indices_local,
1118+
vocab_start_index,
1119+
vocab_end_index,
1120+
chunk_size,
1121+
parallel_group,
1122+
False,
1123+
)
1124+
1125+
if calculate_entropy:
1126+
H_all = ChunkedDistributedEntropy.apply( # type: ignore
1127+
student_logits,
1128+
chunk_size,
1129+
parallel_group,
1130+
False,
1131+
)
1132+
1133+
if cp_size > 1:
1134+
student_topk_logprobs = allgather_cp_sharded_tensor(
1135+
student_topk_logprobs, cp_group, seq_dim=1
1136+
)
1137+
if calculate_entropy:
1138+
H_all = allgather_cp_sharded_tensor(H_all, cp_group, seq_dim=1)
1139+
if pad_len > 0:
1140+
student_topk_logprobs = student_topk_logprobs[:, :-pad_len, :]
1141+
if calculate_entropy:
1142+
H_all = H_all[:, :-pad_len]
1143+
1144+
# Non-distributed processing
1145+
else:
1146+
student_logprobs = torch.nn.functional.log_softmax(student_logits, dim=-1)
1147+
student_topk_logprobs = student_logprobs.gather(
1148+
dim=-1, index=teacher_topk_indices
1149+
)
1150+
1151+
if calculate_entropy:
1152+
H_all = (student_logprobs.exp() * student_logprobs).sum(-1)
1153+
1154+
else:
1155+
# Distributed processing
1156+
if parallel_group is not None or cp_size > 1:
1157+
if parallel_group is None:
1158+
vocab_start_index = 0
1159+
vocab_end_index = int(student_logits.shape[-1])
1160+
1161+
student_topk_logits = gather_logits_at_global_indices(
1162+
student_logits,
1163+
teacher_topk_indices,
1164+
tp_group=parallel_group,
1165+
cp_group=cp_group,
1166+
vocab_start_index=vocab_start_index,
1167+
vocab_end_index=vocab_end_index,
1168+
)
1169+
1170+
# Non-distributed processing
1171+
else:
1172+
student_topk_logits = student_logits.gather(
1173+
dim=-1, index=teacher_topk_indices
1174+
)
1175+
1176+
student_topk_logprobs = torch.nn.functional.log_softmax(
1177+
student_topk_logits, dim=-1
1178+
)
1179+
1180+
# Move teacher tensors to the same device/dtype as student_topk_logits
1181+
teacher_topk_logits = teacher_topk_logits.to(
1182+
student_topk_logprobs.device, dtype=student_topk_logprobs.dtype
1183+
)
1184+
teacher_topk_logprobs = torch.nn.functional.log_softmax(teacher_topk_logits, dim=-1)
1185+
1186+
# Single point of next-token alignment after TP/CP processing
1187+
teacher_topk_logprobs = teacher_topk_logprobs[:, :-1, :]
1188+
student_topk_logprobs = student_topk_logprobs[:, :-1, :]
1189+
1190+
if calculate_entropy:
1191+
H_all = H_all[:, :-1]
1192+
1193+
return student_topk_logprobs, teacher_topk_logprobs, H_all
1194+
1195+
10291196
class ChunkedDistributedEntropy(torch.autograd.Function):
10301197
"""Compute H_all = sum_v p_v log p_v across TP with chunking over sequence.
10311198

0 commit comments

Comments
 (0)