|
16 | 16 | import os |
17 | 17 | from huggingface_hub import hf_hub_download |
18 | 18 |
|
class MedusaConfig(PretrainedConfig):
    """
    Configuration class for the Medusa speculative-decoding model.

    Stores the Medusa-specific hyperparameters alongside the identity of the
    base language model whose hidden states the Medusa heads consume.

    Args:
        medusa_num_heads (int, optional): Number of Medusa decoding heads.
            Default is 5.  (The docstring previously said 2, contradicting
            the actual default below.)
        medusa_num_layers (int, optional): Number of layers per Medusa head.
            Default is 1.
        base_model_name_or_path (str, optional): Name or path of the base
            model. Default is "lmsys/vicuna-7b-v1.3".
        **kwargs: Additional keyword arguments forwarded to the
            ``PretrainedConfig`` constructor.
    """

    def __init__(
        self,
        medusa_num_heads=5,
        medusa_num_layers=1,
        base_model_name_or_path="lmsys/vicuna-7b-v1.3",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Persist Medusa-specific fields so they round-trip through
        # save_pretrained / from_pretrained like any other config value.
        self.medusa_num_heads = medusa_num_heads
        self.medusa_num_layers = medusa_num_layers
        self.base_model_name_or_path = base_model_name_or_path
19 | 41 |
|
20 | 42 | class ResBlock(nn.Module): |
21 | 43 | """ |
@@ -106,13 +128,34 @@ def from_pretrained( |
106 | 128 | **kwargs, |
107 | 129 | ): |
108 | 130 | # Manually load config to ensure that the medusa_num_heads parameter is loaded |
109 | | - config = AutoConfig.from_pretrained(pretrained_model_name_or_path) |
110 | | - return super().from_pretrained( |
111 | | - pretrained_model_name_or_path, |
112 | | - *args, |
113 | | - **kwargs, |
114 | | - config=config, |
115 | | - ) |
| 131 | + try: |
| 132 | + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) |
| 133 | + return super().from_pretrained( |
| 134 | + pretrained_model_name_or_path, |
| 135 | + *args, |
| 136 | + **kwargs, |
| 137 | + config=config, |
| 138 | + ) |
| 139 | + except: |
| 140 | + config = MedusaConfig.from_pretrained(pretrained_model_name_or_path) |
| 141 | + base_model_config = AutoConfig.from_pretrained(config.base_model_name_or_path) |
| 142 | + base_model_config.medusa_num_heads = 5 # TODO: fix the uploaded config (only include 2 heads) |
| 143 | + base_model_config.medusa_num_layers = config.medusa_num_layers |
| 144 | + model = super().from_pretrained( |
| 145 | + config.base_model_name_or_path, |
| 146 | + *args, |
| 147 | + **kwargs, |
| 148 | + config=base_model_config, |
| 149 | + ) |
| 150 | + medusa_head_path = os.path.join(pretrained_model_name_or_path, "medusa_lm_head.pt") |
| 151 | + if os.path.exists(medusa_head_path): |
| 152 | + filename = medusa_head_path |
| 153 | + else: |
| 154 | + filename = hf_hub_download(pretrained_model_name_or_path, "medusa_lm_head.pt") |
| 155 | + medusa_head_state_dict = torch.load(filename, map_location=model.device) |
| 156 | + model.medusa_head.load_state_dict(medusa_head_state_dict, strict=False) |
| 157 | + return model |
| 158 | + |
116 | 159 |
|
117 | 160 | def get_tokenizer(self): |
118 | 161 | """Get the tokenizer of the base model. |
@@ -326,7 +369,14 @@ def from_pretrained( |
326 | 369 | **kwargs, |
327 | 370 | ): |
328 | 371 | # Manually load config to ensure that the medusa_num_heads parameter is loaded |
329 | | - config = AutoConfig.from_pretrained(pretrained_model_name_or_path) |
| 372 | + try: |
| 373 | + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) |
| 374 | + except: |
| 375 | + # MEDUSA-v0.1 load |
| 376 | + config = MedusaConfig.from_pretrained(pretrained_model_name_or_path) |
| 377 | + base_model_config = AutoConfig.from_pretrained(config.base_model_name_or_path) |
| 378 | + config.model_type = base_model_config.model_type |
| 379 | + |
330 | 380 | if config.model_type == "llama": |
331 | 381 | return MedusaModelLlama.from_pretrained( |
332 | 382 | pretrained_model_name_or_path, |
|
0 commit comments