|
16 | 16 |
|
17 | 17 | import numpy as np
|
18 | 18 | import paddle
|
| 19 | +import paddle.nn.functional as F |
| 20 | +from paddle.distributed import fleet |
| 21 | +from paddle.distributed.fleet.layers.mpu import mp_ops |
19 | 22 | from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy
|
20 | 23 |
|
21 | 24 | from ..models.ppo_model_utils import (
|
@@ -57,6 +60,13 @@ def compute_logprob(self, input_ids: paddle.Tensor, position_ids: paddle.Tensor
|
57 | 60 | Raises:
|
58 | 61 | None.
|
59 | 62 | """
|
| 63 | + if self.args.use_fused_head_and_loss_fn: |
| 64 | + return self.compute_fused_logprob( |
| 65 | + input_ids=input_ids, |
| 66 | + position_ids=position_ids, |
| 67 | + **kwargs, |
| 68 | + ) |
| 69 | + |
60 | 70 | log_probs_list = []
|
61 | 71 | batch_size, sequence_length = input_ids.shape
|
62 | 72 | per_device_logprob_batch_size = self.args.per_device_logprob_batch_size
|
@@ -147,6 +157,153 @@ def compute_logprob(self, input_ids: paddle.Tensor, position_ids: paddle.Tensor
|
147 | 157 |
|
148 | 158 | return paddle.concat(log_probs_list, axis=0)
|
149 | 159 |
|
def compute_fused_logprob(
    self, input_ids: paddle.Tensor, position_ids: paddle.Tensor = None, loop_chunk_size=1024, **kwargs
):
    """Compute per-token log-probabilities without building the full logits tensor.

    The model is asked for its final hidden states together with the LM-head
    parameters, and the head projection plus cross entropy is then applied to
    ``loop_chunk_size`` flattened token rows at a time. The negated per-token
    cross entropy is returned as the log-probability of each label token.

    Args:
        input_ids: 2-D tensor of token ids; it is unpacked as
            ``(batch_size, sequence_length)``.
        position_ids: Optional position ids. When given they are sliced per
            micro-batch and forwarded to the model.
        loop_chunk_size: Number of flattened token rows projected through the
            LM head per inner-loop step. Must be positive.
        **kwargs: Only ``"prompt"`` is read. When present, the size of its
            last dimension minus one is the first hidden-state position kept;
            otherwise positions start at 0.

    Returns:
        The log-probabilities of all micro-batches concatenated along axis 0.
        Each micro-batch contributes a tensor shaped like
        ``hidden_states[:, response_start:-1]`` without its last dimension.

    Raises:
        ValueError: If ``loop_chunk_size`` is not positive, or if
            ``tensor_parallel_degree > 1`` while ``tensor_parallel_output``
            is not enabled in the model config.
    """
    if loop_chunk_size <= 0:
        raise ValueError(f"loop_chunk_size must be a positive integer, got {loop_chunk_size}.")

    # Unpacking two values also enforces that input_ids is 2-D.
    batch_size, _ = input_ids.shape
    per_device_logprob_batch_size = self.args.per_device_logprob_batch_size
    num_batches = (batch_size + per_device_logprob_batch_size - 1) // per_device_logprob_batch_size

    startend_row_indices = create_startend_row_indices(input_ids, self.tokenizer.pad_token_id)
    response_start = kwargs["prompt"].shape[-1] - 1 if "prompt" in kwargs else 0

    model_config = self.model.config
    tensor_parallel_degree = model_config.tensor_parallel_degree
    tensor_parallel_output = model_config.tensor_parallel_output

    # Validate the configuration once, before any forward pass is paid for.
    # An explicit raise is used so the check survives `python -O`.
    if tensor_parallel_degree > 1 and not tensor_parallel_output:
        raise ValueError(
            "When tensor_parallel_degree > 1 and use_fused_head_and_loss_fn, "
            "tensor_parallel_output needs to be set to True."
        )

    # Parallel configuration is loop-invariant, so resolve it a single time.
    use_tensor_parallel = tensor_parallel_degree > 1
    model_parallel_group = None
    if use_tensor_parallel:
        hcg = fleet.get_hybrid_communicate_group()
        model_parallel_group = hcg.get_model_parallel_group()
        # The actual world size reported by the communicate group decides
        # which cross-entropy implementation is used.
        use_tensor_parallel = hcg.get_model_parallel_world_size() > 1

    use_remove_padding = self.args.use_remove_padding
    if use_remove_padding:
        from ..utils.bert_padding import pad_input, prepare_flashmask_inputs

    use_fp32_compute = self.args.use_fp32_compute
    temperature = self.args.temperature

    log_probs_list = []
    for batch_idx in range(num_batches):
        # Calculate the start and end indices for the current micro-batch.
        start_index = batch_idx * per_device_logprob_batch_size
        end_index = min(start_index + per_device_logprob_batch_size, batch_size)

        # Extract the current micro-batch.
        current_input_ids = input_ids[start_index:end_index]
        current_startend_row_indices = (
            startend_row_indices[start_index:end_index] if startend_row_indices is not None else None
        )
        current_position_ids = position_ids[start_index:end_index] if position_ids is not None else None
        # Labels are taken from the padded layout, before the inputs are
        # possibly rewritten by the remove-padding path below.
        current_labels = current_input_ids[:, response_start + 1 :]

        if use_remove_padding:
            update_inputs = prepare_flashmask_inputs(
                current_input_ids,
                current_position_ids,
                self.tokenizer.pad_token_id,
                model_config.sequence_parallel,
                model_config.tensor_parallel_degree,
            )
            current_input_ids = update_inputs["input_ids"]
            current_position_ids = update_inputs["position_ids"]
            current_startend_row_indices = update_inputs["attn_mask_startend_row_indices"]
            unpad_indices = update_inputs["indices"]
            raw_input_shape = update_inputs["raw_input_shape"]
            pad_size = update_inputs["pad_size"]

        # NOTE: for use_fused_head_and_loss_fn. The forward pass is run with
        # the model's `training` flag set; presumably that is what makes it
        # return hidden states and LM-head parameters instead of logits
        # (TODO confirm against the model implementation). The previous flag
        # is restored even if the forward pass raises.
        previous_training = self.model.training
        self.model.training = True
        try:
            hidden_states, lm_head_weight, lm_head_bias, transpose_y = self.model(
                current_input_ids,
                position_ids=current_position_ids,
                attn_mask_startend_row_indices=current_startend_row_indices,
            )
        finally:
            self.model.training = previous_training

        if use_remove_padding:
            if pad_size > 0:
                hidden_states = hidden_states[:, :-pad_size]

            # Scatter the packed rows back to the recorded raw input shape.
            hidden_states = pad_input(
                hidden_states.squeeze(0), unpad_indices, batch=raw_input_shape[0], seqlen=raw_input_shape[1]
            ).contiguous()

        if use_fp32_compute and hidden_states.dtype != paddle.float32:
            hidden_states = hidden_states.cast(paddle.float32)
            lm_head_weight = lm_head_weight.cast(paddle.float32)
            if lm_head_bias is not None:
                lm_head_bias = lm_head_bias.cast(paddle.float32)

        # Keep positions response_start .. second-to-last so that each kept
        # hidden state lines up with the label one position to its right.
        hidden_states = hidden_states[:, response_start:-1, :]
        dtype = hidden_states.dtype
        original_shape = hidden_states.shape

        # Flatten to (tokens, hidden) rows and a matching 1-D label vector.
        hidden_states = hidden_states.reshape([-1, original_shape[-1]])
        labels = current_labels.reshape([-1])
        n_tokens = hidden_states.shape[0]

        # Convert the LM-head parameters to the hidden-state dtype.
        lm_head_weight_cast = lm_head_weight.astype(dtype)
        lm_head_bias_cast = lm_head_bias.astype(dtype) if lm_head_bias is not None else None

        log_prob_chunks = []
        for token_start_idx in range(0, n_tokens, loop_chunk_size):
            token_end_idx = min(token_start_idx + loop_chunk_size, n_tokens)
            hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
            labels_chunk = labels[token_start_idx:token_end_idx]

            # Calculate the logits of the current chunk only (not a fused linear).
            logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y)
            if lm_head_bias_cast is not None:
                logits_chunk_cast += lm_head_bias_cast

            logits_chunk = logits_chunk_cast.astype("float32")
            logits_chunk = logits_chunk / temperature

            # The log-probability is the negated per-token cross entropy.
            if use_tensor_parallel:
                token_loss_chunk = mp_ops._c_softmax_with_cross_entropy(
                    logits_chunk,
                    labels_chunk,
                    group=model_parallel_group,
                    return_softmax=False,
                )
            else:
                token_loss_chunk = F.cross_entropy(logits_chunk, labels_chunk, reduction="none")
            log_prob_chunks.append(-token_loss_chunk.squeeze(axis=-1))

        log_probs = paddle.concat(log_prob_chunks, axis=-1).reshape(original_shape[:-1])
        log_probs_list.append(log_probs)

        # Drop the chunk references before asking the allocator to release memory.
        log_prob_chunks = None
        paddle.device.cuda.empty_cache()

    return paddle.concat(log_probs_list, axis=0)
| 306 | + |
150 | 307 | def update_actor(self, rl_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]:
|
151 | 308 | # inputs shared by policy and value trainer
|
152 | 309 | input_ids = rl_batch["input_ids"].contiguous() # length: src+tgt
|
|
0 commit comments