|
20 | 20 | from nemo_rl.algorithms.interfaces import LossFunction, LossType |
21 | 21 | from nemo_rl.algorithms.utils import calculate_kl, masked_mean |
22 | 22 | from nemo_rl.distributed.batched_data_dict import BatchedDataDict |
23 | | -from nemo_rl.distributed.model_utils import ( |
24 | | - ChunkedDistributedEntropy, |
25 | | - ChunkedDistributedGatherLogprob, |
26 | | - _get_tokens_on_this_cp_rank, |
27 | | - allgather_cp_sharded_tensor, |
28 | | - gather_logits_at_global_indices, |
29 | | -) |
30 | 23 |
|
31 | 24 | Tensor = TypeVar("Tensor", bound=torch.Tensor) |
32 | 25 |
|
@@ -999,165 +992,14 @@ def __init__(self, cfg: DistillationLossConfig): |
999 | 992 |
|
1000 | 993 | def __call__( |
1001 | 994 | self, |
1002 | | - next_token_logits: torch.Tensor, |
| 995 | + student_topk_logprobs: torch.Tensor, |
| 996 | + teacher_topk_logprobs: torch.Tensor, |
| 997 | + H_all: torch.Tensor | None, |
1003 | 998 | data: DistillationLossDataDict, |
1004 | 999 | global_valid_seqs: torch.Tensor, |
1005 | 1000 | global_valid_toks: torch.Tensor, |
1006 | | - vocab_parallel_rank: Optional[int] = None, |
1007 | | - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, |
1008 | | - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, |
1009 | 1001 | ) -> tuple[torch.Tensor, dict[str, Any]]: |
1010 | 1002 | """Compute distillation loss between teacher and student logits.""" |
1011 | | - # Basic shapes |
1012 | | - input_ids = data["input_ids"] |
1013 | | - batch_size = input_ids.shape[0] |
1014 | | - |
1015 | | - # CP support: get CP group and size |
1016 | | - cp_group = context_parallel_group |
1017 | | - cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) |
1018 | | - |
1019 | | - # Ensure float32 for stability (match other losses) |
1020 | | - next_token_logits = next_token_logits.to(torch.float32) |
1021 | | - per_token_kl = None |
1022 | | - # Preferred truncated-KL path: teacher provides top-k support per position |
1023 | | - teacher_topk_logits = data["teacher_topk_logits"] # [B, S, k] |
1024 | | - teacher_topk_indices = data["teacher_topk_indices"] # [B, S, k] |
1025 | | - |
1026 | | - if teacher_topk_indices.shape[-1] <= 0: |
1027 | | - raise ValueError( |
1028 | | - f"topk must be positive, got {teacher_topk_indices.shape[-1]}. " |
1029 | | - "topk=0 is not supported as it would result in empty tensor operations." |
1030 | | - ) |
1031 | | - |
1032 | | - # Determine processing path and setup variables |
1033 | | - if vocab_parallel_group is not None: |
1034 | | - assert vocab_parallel_rank is not None, ( |
1035 | | - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" |
1036 | | - ) |
1037 | | - V_local = int(next_token_logits.shape[-1]) |
1038 | | - vocab_start_index = vocab_parallel_rank * V_local |
1039 | | - vocab_end_index = (vocab_parallel_rank + 1) * V_local |
1040 | | - parallel_group = vocab_parallel_group |
1041 | | - logits_tensor = next_token_logits |
1042 | | - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): |
1043 | | - device_mesh = next_token_logits.device_mesh |
1044 | | - tp_group = device_mesh.get_group("tp") |
1045 | | - tp_rank = tp_group.rank() |
1046 | | - local_student_logits = next_token_logits.to_local() |
1047 | | - V_local = int(local_student_logits.shape[-1]) |
1048 | | - vocab_start_index = tp_rank * V_local |
1049 | | - vocab_end_index = (tp_rank + 1) * V_local |
1050 | | - parallel_group = tp_group |
1051 | | - logits_tensor = local_student_logits |
1052 | | - teacher_topk_indices = teacher_topk_indices.to(local_student_logits.device) |
1053 | | - # For DTensor, derive CP group/size from the device mesh to ensure CP-aware alignment |
1054 | | - if ( |
1055 | | - device_mesh.mesh_dim_names is not None |
1056 | | - and "cp" in device_mesh.mesh_dim_names |
1057 | | - ): |
1058 | | - cp_group = device_mesh.get_group("cp") |
1059 | | - cp_size = cp_group.size() |
1060 | | - else: |
1061 | | - cp_group = None |
1062 | | - cp_size = 1 |
1063 | | - else: |
1064 | | - parallel_group = None |
1065 | | - logits_tensor = next_token_logits |
1066 | | - |
1067 | | - # Process based on zero_outside_topk setting |
1068 | | - if self.zero_outside_topk and parallel_group is not None: |
1069 | | - # Distributed processing with chunking |
1070 | | - indices_local = teacher_topk_indices |
1071 | | - pad_len = 0 |
1072 | | - if cp_size > 1: |
1073 | | - pad_len = logits_tensor.shape[1] * cp_size - indices_local.shape[1] |
1074 | | - if pad_len > 0: |
1075 | | - indices_local = torch.nn.functional.pad( |
1076 | | - indices_local, (0, 0, 0, pad_len), value=0 |
1077 | | - ) |
1078 | | - cp_rank = torch.distributed.get_rank(cp_group) |
1079 | | - indices_local = _get_tokens_on_this_cp_rank( |
1080 | | - indices_local, cp_rank, cp_size, seq_dim=1 |
1081 | | - ) |
1082 | | - |
1083 | | - S_local = int(logits_tensor.shape[1]) |
1084 | | - chunk_size = max(1, min(S_local, 1024)) |
1085 | | - student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( # type: ignore |
1086 | | - logits_tensor, |
1087 | | - indices_local, |
1088 | | - vocab_start_index, |
1089 | | - vocab_end_index, |
1090 | | - chunk_size, |
1091 | | - parallel_group, |
1092 | | - False, |
1093 | | - ) |
1094 | | - |
1095 | | - if self.kl_type != "forward": |
1096 | | - H_all = ChunkedDistributedEntropy.apply( # type: ignore |
1097 | | - logits_tensor, |
1098 | | - chunk_size, |
1099 | | - parallel_group, |
1100 | | - False, |
1101 | | - ) |
1102 | | - |
1103 | | - if cp_size > 1: |
1104 | | - student_topk_logprobs = allgather_cp_sharded_tensor( |
1105 | | - student_topk_logprobs, cp_group, seq_dim=1 |
1106 | | - ) |
1107 | | - if self.kl_type != "forward": |
1108 | | - H_all = allgather_cp_sharded_tensor(H_all, cp_group, seq_dim=1) |
1109 | | - if pad_len > 0: |
1110 | | - student_topk_logprobs = student_topk_logprobs[:, :-pad_len, :] |
1111 | | - if self.kl_type != "forward": |
1112 | | - H_all = H_all[:, :-pad_len] |
1113 | | - elif self.zero_outside_topk: |
1114 | | - # Non-distributed processing |
1115 | | - student_logprobs = torch.nn.functional.log_softmax(logits_tensor, dim=-1) |
1116 | | - student_topk_logprobs = student_logprobs.gather( |
1117 | | - dim=-1, index=teacher_topk_indices.to(student_logprobs.device) |
1118 | | - ) |
1119 | | - if self.kl_type != "forward": |
1120 | | - H_all = (student_logprobs.exp() * student_logprobs).sum(-1) |
1121 | | - else: |
1122 | | - # Gather logits at global indices |
1123 | | - if (parallel_group is not None) or (cp_size > 1): |
1124 | | - student_topk_logits = gather_logits_at_global_indices( |
1125 | | - logits_tensor, |
1126 | | - teacher_topk_indices, |
1127 | | - tp_group=parallel_group, |
1128 | | - cp_group=cp_group, |
1129 | | - vocab_start_index=( |
1130 | | - vocab_start_index if parallel_group is not None else 0 |
1131 | | - ), |
1132 | | - vocab_end_index=( |
1133 | | - vocab_end_index |
1134 | | - if parallel_group is not None |
1135 | | - else int(logits_tensor.shape[-1]) |
1136 | | - ), |
1137 | | - ) |
1138 | | - else: |
1139 | | - student_topk_logits = logits_tensor.gather( |
1140 | | - dim=-1, index=teacher_topk_indices.to(logits_tensor.device) |
1141 | | - ) |
1142 | | - student_topk_logprobs = torch.nn.functional.log_softmax( |
1143 | | - student_topk_logits, dim=-1 |
1144 | | - ) |
1145 | | - |
1146 | | - # Move teacher tensors to the same device/dtype as student_topk_logits |
1147 | | - teacher_topk_logits = teacher_topk_logits.to( |
1148 | | - student_topk_logprobs.device, dtype=student_topk_logprobs.dtype |
1149 | | - ) |
1150 | | - teacher_topk_logprobs = torch.nn.functional.log_softmax( |
1151 | | - teacher_topk_logits, dim=-1 |
1152 | | - ) |
1153 | | - |
1154 | | - # Single point of next-token alignment after TP/CP processing |
1155 | | - teacher_topk_logprobs = teacher_topk_logprobs[:, :-1, :] |
1156 | | - student_topk_logprobs = student_topk_logprobs[:, :-1, :] |
1157 | | - if self.zero_outside_topk and self.kl_type != "forward": |
1158 | | - # Align H_all with next-token prediction |
1159 | | - H_all = H_all[:, :-1] |
1160 | | - |
1161 | 1003 | student_probs = student_topk_logprobs.exp() # [B, S-1, k] |
1162 | 1004 | teacher_probs = teacher_topk_logprobs.exp() # [B, S-1, k] |
1163 | 1005 |
|
@@ -1210,7 +1052,7 @@ def __call__( |
1210 | 1052 |
|
1211 | 1053 | metrics = { |
1212 | 1054 | "loss": float(kl_loss.item()) if kl_loss.ndim == 0 else kl_loss, |
1213 | | - "num_valid_samples": int(batch_size), |
| 1055 | + "num_valid_samples": data["input_ids"].shape[0], |
1214 | 1056 | } |
1215 | 1057 |
|
1216 | 1058 | return kl_loss, metrics |
0 commit comments