import os

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, PretrainedConfig

from .kv_cache import initialize_past_key_values
from .medusa_choices import mc_sim_7b_63
from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from .utils import *


class MedusaConfig(PretrainedConfig):
    """
    Configuration class for the Medusa model.

    Args:
        medusa_num_heads (int, optional): Number of Medusa heads. Defaults to 4.
        medusa_num_layers (int, optional): Number of ResBlock layers in each Medusa head. Defaults to 1.
        base_model_name_or_path (str, optional): Name or path of the base model. Defaults to "lmsys/vicuna-7b-v1.3".
        **kwargs: Additional keyword arguments passed to the parent class constructor.
    """

    def __init__(
        self,
        medusa_num_heads=4,
        medusa_num_layers=1,
        base_model_name_or_path="lmsys/vicuna-7b-v1.3",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.medusa_num_heads = medusa_num_heads
        self.medusa_num_layers = medusa_num_layers
        self.base_model_name_or_path = base_model_name_or_path
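
# A minimal sketch of config round-tripping (the directory path is
# hypothetical); save_pretrained/from_pretrained are inherited from
# PretrainedConfig:
#
#   config = MedusaConfig(medusa_num_heads=4, medusa_num_layers=1)
#   config.save_pretrained("./medusa-head-checkpoint")  # writes config.json
#   config = MedusaConfig.from_pretrained("./medusa-head-checkpoint")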

class ResBlock(nn.Module):
    """
    A residual block.

    This module applies a linear transformation followed by a SiLU activation,
    then adds the result to the original input, forming a residual connection.

    Args:
        hidden_size (int): The size of the hidden layers in the block.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        # Zero-initialize the weight so the block starts out near-identity:
        # with a zero weight, the output reduces to x + SiLU(bias).
        torch.nn.init.zeros_(self.linear.weight)
        # Use SiLU activation to stay consistent with the Llama model.
        self.act = nn.SiLU()

    def forward(self, x):
        """
        Forward pass of the ResBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output after the residual connection and activation.
        """
        return x + self.act(self.linear(x))
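
# A quick sanity check of the zero-init behavior (illustrative only): because
# the weight starts at zero, linear(x) returns just the bias, so the block
# output is x + SiLU(bias):
#
#   block = ResBlock(hidden_size=16)
#   x = torch.randn(2, 16)
#   assert torch.allclose(block(x), x + block.act(block.linear.bias))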

class MedusaModel(nn.Module):
    """The Medusa language model head.

    This module creates a series of prediction heads (one per Medusa head) on
    top of a given base model. Each head is composed of a sequence of residual
    blocks followed by a linear layer.
    """

    def __init__(
        self,
        base_model,
        medusa_num_heads=4,
        medusa_num_layers=1,
        base_model_name_or_path="lmsys/vicuna-7b-v1.3",
    ):
        """
        Args:
            base_model (nn.Module): The base language model to be used.
            medusa_num_heads (int, optional): Number of additional tokens to predict. Defaults to 4.
            medusa_num_layers (int, optional): Number of ResBlock layers in each Medusa head. Defaults to 1.
            base_model_name_or_path (str, optional): Name or path of the base model. Defaults to "lmsys/vicuna-7b-v1.3".
        """
        super().__init__()
        self.base_model = base_model
        self.config = base_model.config
        self.hidden_size = base_model.lm_head.weight.shape[-1]
        self.vocab_size = base_model.lm_head.weight.shape[0]
        self.medusa = medusa_num_heads
        self.medusa_num_layers = medusa_num_layers
        self.base_model_name_or_path = base_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
        # Create a list of Medusa heads, instantiating a fresh ResBlock per
        # layer so the blocks do not share parameters.
        self.medusa_head = nn.ModuleList(
            [
                nn.Sequential(
                    *[ResBlock(self.hidden_size) for _ in range(medusa_num_layers)],
                    nn.Linear(self.hidden_size, self.vocab_size, bias=False),
                )
                for _ in range(medusa_num_heads)
            ]
        )
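        # Each head is Sequential(ResBlock x medusa_num_layers, Linear), so
        # head i's final projection is self.medusa_head[i][-1] and, for the
        # one-layer default, its checkpoint keys look like "i.0.linear.weight"
        # and "i.1.weight".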

        # Ensure medusa_head's dtype and device align with the base model.
        self.medusa_head.to(self.base_model.dtype).to(self.base_model.device)

        # Initialize each head's output projection from the base model's LM
        # head weights.
        for i in range(medusa_num_heads):
            self.medusa_head[i][-1].weight.data[:] = base_model.lm_head.weight.data[:]

    def get_tokenizer(self):
        """Get the tokenizer of the base model.

        Returns:
            Tokenizer: The tokenizer of the base model.
        """
        return self.tokenizer

    @classmethod
    def from_pretrained(
        cls,
        medusa_head_name_or_path,
        base_model=None,
        medusa_num_heads=None,
        **kwargs,
    ):
        """
        Args:
            medusa_head_name_or_path (str): Name or path of the Medusa head to load.
            base_model (str, optional): If given, overrides the base model recorded in the Medusa config.
            medusa_num_heads (int, optional): If given, overrides the number of heads recorded in the Medusa config.
            **kwargs: Additional keyword arguments for loading the base model.

        Returns:
            MedusaModel: A MedusaModel instance loaded from the given path.
        """
        medusa_config = MedusaConfig.from_pretrained(medusa_head_name_or_path)
        if medusa_num_heads is not None:
            print("Overriding medusa_num_heads as:", medusa_num_heads)
            medusa_config.medusa_num_heads = medusa_num_heads
        if base_model is not None:
            print("Overriding base_model as:", base_model)
            medusa_config.base_model_name_or_path = base_model

        base_model = KVLlamaForCausalLM.from_pretrained(
            medusa_config.base_model_name_or_path, **kwargs
        )

        model = cls(
            base_model,
            medusa_config.medusa_num_heads,
            medusa_config.medusa_num_layers,
            medusa_config.base_model_name_or_path,
        )
        # Load the Medusa head weights from a local checkpoint if present,
        # otherwise from the Hugging Face Hub.
        medusa_head_path = os.path.join(medusa_head_name_or_path, "medusa_lm_head.pt")
        if os.path.exists(medusa_head_path):
            filename = medusa_head_path
        else:
            filename = hf_hub_download(medusa_head_name_or_path, "medusa_lm_head.pt")
        medusa_head_state_dict = torch.load(filename, map_location=base_model.device)
        model.medusa_head.load_state_dict(medusa_head_state_dict, strict=False)

        return model
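
    # A typical load (illustrative; "FasterDecoding/medusa-vicuna-7b-v1.3" is
    # the released Medusa head for Vicuna-7B, and the extra kwargs are simply
    # forwarded to the base model's from_pretrained):
    #
    #   model = MedusaModel.from_pretrained(
    #       "FasterDecoding/medusa-vicuna-7b-v1.3",
    #       torch_dtype=torch.float16,
    #   )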

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        past_key_values=None,
        output_orig=False,
        position_ids=None,
    ):
        """Forward pass of the MedusaModel.

        Args:
            input_ids (torch.Tensor, optional): Input token IDs.
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Ground-truth labels for loss computation.
            past_key_values (tuple, optional): Past key and value states for attention.
            output_orig (bool, optional): Whether to also output predictions from the original LM head.
            position_ids (torch.Tensor, optional): Position IDs.

        Returns:
            torch.Tensor: Logits from all Medusa heads, stacked along dim 0.
            (Optional) Original predictions from the base model's LM head.
        """
        with torch.inference_mode():
            # Pass input through the base model.
            outputs = self.base_model.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
            if output_orig:
                orig = self.base_model.lm_head(outputs[0])
        # Clone the hidden states outside inference mode so the Medusa heads
        # can run with autograd (e.g. while training the heads).
        hidden_states = outputs[0].clone()
        medusa_logits = []
        # TODO: Consider parallelizing this loop for efficiency.
        for i in range(self.medusa):
            medusa_logits.append(self.medusa_head[i](hidden_states))
        if output_orig:
            return torch.stack(medusa_logits, dim=0), outputs, orig
        return torch.stack(medusa_logits, dim=0)
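    # Shape sketch (illustrative): with batch size B, sequence length T, and
    # K Medusa heads, forward() returns logits of shape (K, B, T, vocab_size);
    # with output_orig=True it also returns the base-model outputs and the
    # original LM-head logits of shape (B, T, vocab_size):
    #
    #   medusa_logits, outputs, orig = model(input_ids, output_orig=True)
    #   assert medusa_logits.shape == (model.medusa, *orig.shape)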

    def medusa_generate(
        self,
        input_ids,
        attention_mask=None,
        temperature=0.0,
        max_steps=512,
        # The hyperparameters below configure Medusa's tree of candidates, e.g.
        # top-1 prediction for the next token, top-7 predictions for the next
        # next token, and top-6 predictions for the next next next token.
        medusa_choices=mc_sim_7b_63,
        posterior_threshold=0.09,  # threshold for validating Medusa output
        # another threshold hyperparameter, recommended to be sqrt(posterior_threshold)
        posterior_alpha=0.3,
    ):
        """
        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor, optional): Attention mask.
            temperature (float, optional): Temperature for typical acceptance.
            max_steps (int, optional): Maximum number of decoding steps.
            medusa_choices (list, optional): A list describing the candidate tree explored by the Medusa heads.
            posterior_threshold (float, optional): Threshold for posterior validation.
            posterior_alpha (float, optional): Another threshold hyperparameter, recommended to be sqrt(posterior_threshold).

        Yields:
            dict: A dictionary with key "text" containing the decoded output so far.

        Warning: only batch size 1 is supported for now!
        """
        assert input_ids.shape[0] == 1, "Only batch size 1 is supported for now!"
        # Avoid modifying input_ids in place.
        input_ids = input_ids.clone()

        # Cache the Medusa buffers (the fixed patterns for tree attention).
        if hasattr(self, "medusa_choices") and self.medusa_choices == medusa_choices:
            # Load the cached Medusa buffers.
            medusa_buffers = self.medusa_buffers
        else:
            # Initialize the Medusa buffers.
            medusa_buffers = generate_medusa_buffers(
                medusa_choices, device=self.base_model.device
            )
            self.medusa_buffers = medusa_buffers
            self.medusa_choices = medusa_choices
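        # For reference, a medusa_choices spec such as mc_sim_7b_63 enumerates
        # paths in the candidate tree as sequences of per-head top-k indices
        # (e.g. [0] is head 1's top choice and [0, 2] extends it with head 2's
        # third choice); mc_sim_7b_63 contains 63 such nodes.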

        # Initialize the past key and value states, reusing the cached buffers
        # if they already exist.
        if hasattr(self, "past_key_values"):
            past_key_values = self.past_key_values
            past_key_values_data = self.past_key_values_data
            current_length_data = self.current_length_data
            # Reset the past key and value states.
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(self.base_model)
            self.past_key_values = past_key_values
            self.past_key_values_data = past_key_values_data
            self.current_length_data = current_length_data

        input_len = input_ids.shape[1]

        reset_medusa_mode(self)
        # Initialize the tree attention mask and process the prefill tokens.
        medusa_logits, logits = initialize_medusa(
            input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values
        )

        new_token = 0

        for idx in range(max_steps):
            # Generate candidates with top-k predictions from the Medusa heads.
            candidates, tree_candidates = generate_candidates(
                medusa_logits,
                logits,
                medusa_buffers["tree_indices"],
                medusa_buffers["retrieve_indices"],
            )

            # Use tree attention to verify the candidates and get predictions.
            medusa_logits, logits, outputs = tree_decoding(
                self,
                tree_candidates,
                past_key_values,
                medusa_buffers["medusa_position_ids"],
                input_ids,
                medusa_buffers["retrieve_indices"],
            )

            # Evaluate the posterior of the candidates to select the accepted
            # candidate prefix.
            best_candidate, accept_length = evaluate_posterior(
                logits, candidates, temperature, posterior_threshold, posterior_alpha
            )
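            # For context: with temperature > 0, evaluate_posterior is expected
            # to apply Medusa's typical-acceptance test, accepting a candidate
            # token x roughly when
            #     p(x) > min(posterior_threshold, posterior_alpha * exp(-H(p)))
            # where H(p) is the entropy of the base model's next-token
            # distribution; with temperature == 0 it reduces to a greedy match
            # against the base model's argmax.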

            # Update input_ids and logits with the accepted prefix.
            input_ids, logits, medusa_logits, new_token = update_inference_inputs(
                input_ids,
                candidates,
                best_candidate,
                accept_length,
                medusa_buffers["retrieve_indices"],
                outputs,
                logits,
                medusa_logits,
                new_token,
                past_key_values_data,
                current_length_data,
            )

            # Stream the text generated so far.
            yield {
                "text": self.tokenizer.decode(
                    input_ids[0, input_len:],
                    skip_special_tokens=True,
                    spaces_between_special_tokens=False,
                    clean_up_tokenization_spaces=True,
                )
            }

            # Stop once an EOS token has been produced.
            if self.tokenizer.eos_token_id in input_ids[0, input_len:]:
                break
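
# A minimal streaming loop (illustrative; prompt formatting for Vicuna-style
# chat models is left to the caller):
#
#   model = MedusaModel.from_pretrained("FasterDecoding/medusa-vicuna-7b-v1.3")
#   tokenizer = model.get_tokenizer()
#   input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
#   input_ids = input_ids.to(model.base_model.device)
#   for step in model.medusa_generate(input_ids, temperature=0.7, max_steps=64):
#       print(step["text"])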