|
16 | 16 | "outputs": [], |
17 | 17 | "source": [ |
18 | 18 | "import os\n", |
19 | | - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" # define GPU id, remove if you want to use all GPUs available\n", |
| 19 | + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\" # define GPU id, remove if you want to use all GPUs available\n", |
20 | 20 | "import torch\n", |
21 | 21 | "from tqdm import tqdm\n", |
22 | 22 | "import time\n", |
|
26 | 26 | "from medusa.model.medusa_model import MedusaModel\n", |
27 | 27 | "from medusa.model.kv_cache import *\n", |
28 | 28 | "from medusa.model.utils import *\n", |
| 29 | + "from medusa.model.medusa_choices import *\n", |
29 | 30 | "import transformers\n", |
30 | 31 | "from huggingface_hub import hf_hub_download" |
31 | 32 | ] |
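The new import brings in the predefined candidate-tree configurations such as `mc_sim_7b_63`, used further down. As a rough sketch of what such a configuration looks like (the real entries live in `medusa/model/medusa_choices.py`; the tuples below are illustrative, not the actual table):

```python
# Illustrative only -- the real mc_sim_7b_63 is defined in
# medusa/model/medusa_choices.py. Each tuple is one path in the candidate
# tree: element i is the top-k index taken from Medusa head i on that path.
example_choices = [
    (0,),     # head 0's top-1 token
    (0, 0),   # head 0 top-1, then head 1 top-1 stacked on it
    (1,),     # head 0's second-best token
    (0, 1),   # head 0 top-1, then head 1's second-best
]
# A tuned configuration such as mc_sim_7b_63 simply lists more of these
# paths (63 of them), concentrated where acceptance is most likely.
```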
|
55 | 56 | " elapsed_time = end - start\n", |
56 | 57 | " wall_times[key].append(elapsed_time)\n", |
57 | 58 | "\n", |
58 | | - "def medusa_forward(input_ids, model, tokenizer, medusa_buffers, medusa_topk, temperature, posterior_threshold, posterior_alpha, past_key_values, past_key_values_data, current_length_data, steps = 512):\n", |
| 59 | + "def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps = 512):\n", |
59 | 60 | " wall_times = {'medusa': [], 'tree': [], 'posterior': [], 'update': [], 'init': []}\n", |
60 | 61 | " \n", |
61 | 62 | " with timed(wall_times, 'init'):\n", |
62 | | - " reset_medusa_mode(model)\n", |
| 63 | + " if hasattr(model, \"medusa_choices\") and model.medusa_choices == medusa_choices:\n", |
| 64 | + " # Load the cached medusa buffer\n", |
| 65 | + " medusa_buffers = model.medusa_buffers\n", |
| 66 | + " else:\n", |
| 67 | + " # Initialize the medusa buffer\n", |
| 68 | + " medusa_buffers = generate_medusa_buffers(\n", |
| 69 | + " medusa_choices, device=model.base_model.device\n", |
| 70 | + " )\n", |
| 71 | + " model.medusa_buffers = medusa_buffers\n", |
| 72 | + " model.medusa_choices = medusa_choices\n", |
| 73 | + "\n", |
| 74 | + " # Initialize the past key and value states\n", |
| 75 | + " if hasattr(model, \"past_key_values\"):\n", |
| 76 | + " past_key_values = model.past_key_values\n", |
| 77 | + " past_key_values_data = model.past_key_values_data\n", |
| 78 | + " current_length_data = model.current_length_data\n", |
| 79 | + " # Reset the past key and value states\n", |
| 80 | + " current_length_data.zero_()\n", |
| 81 | + " else:\n", |
| 82 | + " (\n", |
| 83 | + " past_key_values,\n", |
| 84 | + " past_key_values_data,\n", |
| 85 | + " current_length_data,\n", |
| 86 | + " ) = initialize_past_key_values(model.base_model)\n", |
| 87 | + " model.past_key_values = past_key_values\n", |
| 88 | + " model.past_key_values_data = past_key_values_data\n", |
| 89 | + " model.current_length_data = current_length_data\n", |
| 90 | + "\n", |
63 | 91 | " input_len = input_ids.shape[1]\n", |
64 | | - " medusa_logits, logits = initialize_medusa(input_ids, model, medusa_buffers['medusa_attn_mask'], past_key_values)\n", |
65 | | - " \n", |
| 92 | + " reset_medusa_mode(model)\n", |
| 93 | + " medusa_logits, logits = initialize_medusa(\n", |
| 94 | + " input_ids, model, medusa_buffers[\"medusa_attn_mask\"], past_key_values\n", |
| 95 | + " )\n", |
66 | 96 | " new_token = 0\n", |
67 | 97 | "\n", |
68 | | - " for idx in range(steps): \n", |
| 98 | + " for idx in range(max_steps): \n", |
69 | 99 | " with timed(wall_times, 'medusa'):\n", |
70 | | - " candidates, tree_candidates = generate_candidates(medusa_logits, logits, medusa_topk, medusa_buffers['tree_indices'], temperature)\n", |
| 100 | + " candidates, tree_candidates = generate_candidates(\n", |
| 101 | + " medusa_logits,\n", |
| 102 | + " logits,\n", |
| 103 | + " medusa_buffers[\"tree_indices\"],\n", |
| 104 | + " medusa_buffers[\"retrieve_indices\"],\n", |
| 105 | + " )\n", |
71 | 106 | "\n", |
72 | 107 | " with timed(wall_times, 'tree'):\n", |
73 | | - " medusa_logits, logits, outputs = tree_decoding(model, tree_candidates, past_key_values, medusa_buffers['medusa_position_ids'], input_ids, medusa_buffers['retrieve_indices'])\n", |
| 108 | + " medusa_logits, logits, outputs = tree_decoding(\n", |
| 109 | + " model,\n", |
| 110 | + " tree_candidates,\n", |
| 111 | + " past_key_values,\n", |
| 112 | + " medusa_buffers[\"medusa_position_ids\"],\n", |
| 113 | + " input_ids,\n", |
| 114 | + " medusa_buffers[\"retrieve_indices\"],\n", |
| 115 | + " )\n", |
74 | 116 | "\n", |
75 | 117 | " with timed(wall_times, 'posterior'):\n", |
76 | | - " best_candidate, accept_length = evaluate_posterior(logits, candidates, temperature, posterior_threshold, posterior_alpha)\n", |
| 118 | + " best_candidate, accept_length = evaluate_posterior(\n", |
| 119 | + " logits, candidates, temperature, posterior_threshold, posterior_alpha\n", |
| 120 | + " )\n", |
77 | 121 | " \n", |
78 | 122 | " with timed(wall_times, 'update'):\n", |
79 | | - " input_ids, logits, medusa_logits, new_token = update_inference_inputs(input_ids, candidates, best_candidate, accept_length, medusa_buffers['retrieve_indices'], outputs, logits, medusa_logits, new_token, past_key_values_data, current_length_data)\n", |
| 123 | + " input_ids, logits, medusa_logits, new_token = update_inference_inputs(\n", |
| 124 | + " input_ids,\n", |
| 125 | + " candidates,\n", |
| 126 | + " best_candidate,\n", |
| 127 | + " accept_length,\n", |
| 128 | + " medusa_buffers[\"retrieve_indices\"],\n", |
| 129 | + " outputs,\n", |
| 130 | + " logits,\n", |
| 131 | + " medusa_logits,\n", |
| 132 | + " new_token,\n", |
| 133 | + " past_key_values_data,\n", |
| 134 | + " current_length_data,\n", |
| 135 | + " )\n", |
80 | 136 | "\n", |
81 | 137 | " if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():\n", |
82 | 138 | " break\n", |
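This refactor moves buffer and KV-cache setup inside `medusa_forward` and caches both on the model object, so only the first call pays the allocation cost; subsequent calls rewind the cache with `current_length_data.zero_()` instead of reallocating. A minimal sketch of that pattern in isolation (the class and names below are illustrative, not the library's API):

```python
import torch

class TinyKVCache:
    """Toy version of the pre-allocate-once pattern used above: one big
    buffer plus a length counter, so resetting is just zeroing the counter
    rather than freeing and reallocating GPU memory."""

    def __init__(self, max_len: int, dim: int, device: str = "cpu"):
        self.data = torch.zeros(max_len, dim, device=device)   # allocated once
        self.current_length = torch.zeros(1, dtype=torch.long)

    def append(self, block: torch.Tensor) -> None:
        n = int(self.current_length)
        self.data[n : n + block.shape[0]] = block
        self.current_length += block.shape[0]

    def reset(self) -> None:
        self.current_length.zero_()   # analogous to current_length_data.zero_()

cache = TinyKVCache(max_len=2048, dim=8)
cache.append(torch.randn(16, 8))   # first "generation" fills the buffer
cache.reset()                      # the next one reuses it from position 0
```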
|
102 | 158 | "model_name = 'FasterDecoding/medusa-vicuna-7b-v1.3'\n", |
103 | 159 | "model = MedusaModel.from_pretrained(\n", |
104 | 160 | " model_name,\n", |
| 161 | + " medusa_num_heads = 4,\n", |
105 | 162 | " torch_dtype=torch.float16,\n", |
106 | 163 | " low_cpu_mem_usage=True,\n", |
107 | 164 | " device_map=\"auto\"\n", |
108 | 165 | ")\n", |
109 | 166 | "tokenizer = model.get_tokenizer()\n", |
110 | 167 | "\n", |
111 | | - "medusa_choices = torch.tensor([1, 7, 6])\n", |
112 | | - "num_heads = len(medusa_choices) - 1\n", |
113 | | - "medusa_topk = medusa_choices[1:]\n", |
114 | | - "\n", |
115 | | - "medusa_buffers = generate_medusa_buffers(medusa_choices, device=model.base_model.device)" |
| 168 | + "medusa_choices = mc_sim_7b_63\n", |
| 169 | + "\n" |
116 | 170 | ] |
117 | 171 | }, |
118 | 172 | { |
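With `medusa_num_heads=4` pinned to match this checkpoint's Medusa heads, the tree paths in `medusa_choices` can be at most four tokens deep, since path element i needs head i to propose it. A quick sanity check one could run after this cell (assumes the imports from the first cell):

```python
# Sanity-check the tree configuration against the number of Medusa heads.
num_nodes = len(medusa_choices)                        # 63 for mc_sim_7b_63
max_depth = max(len(path) for path in medusa_choices)
print(f"{num_nodes} candidate-tree nodes, deepest path = {max_depth} tokens")
# Paths deeper than medusa_num_heads would have no head to propose them.
```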
|
135 | 189 | "posterior_alpha = 0.3" |
136 | 190 | ] |
137 | 191 | }, |
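These two knobs drive the typical-acceptance test inside `evaluate_posterior`: a candidate token is kept when its probability under the base model clears the smaller of a hard floor (`posterior_threshold`) and an entropy-scaled bar (`posterior_alpha`). A paraphrase of that criterion for a single token (a sketch, not the library code; the 0.09 default is an assumption, since only `posterior_alpha = 0.3` is visible in this cell):

```python
import torch

def typical_accept(probs: torch.Tensor, token: int,
                   posterior_threshold: float = 0.09,
                   posterior_alpha: float = 0.3) -> bool:
    """Accept `token` if p(token) > min(threshold, alpha * exp(-entropy)):
    lenient where the base model is uncertain (high entropy), strict where
    it is confident. Paraphrases the check in evaluate_posterior."""
    entropy = -(probs * torch.log(probs.clamp_min(1e-5))).sum()
    bar = min(posterior_threshold, posterior_alpha * torch.exp(-entropy).item())
    return probs[token].item() > bar

probs = torch.softmax(torch.randn(32000), dim=-1)  # dummy next-token distribution
print(typical_accept(probs, token=int(probs.argmax())))
```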
138 | | - { |
139 | | - "cell_type": "markdown", |
140 | | - "metadata": {}, |
141 | | - "source": [ |
142 | | - "## Initializing Past Values\n", |
143 | | - "\n", |
144 | | - "We initialize the dedicated cache for past key values." |
145 | | - ] |
146 | | - }, |
147 | | - { |
148 | | - "cell_type": "code", |
149 | | - "execution_count": null, |
150 | | - "metadata": {}, |
151 | | - "outputs": [], |
152 | | - "source": [ |
153 | | - "past_key_values, past_key_values_data, current_length_data = initialize_past_key_values(model.base_model)" |
154 | | - ] |
155 | | - }, |
156 | 192 | { |
157 | 193 | "cell_type": "markdown", |
158 | 194 | "metadata": {}, |
|
192 | 228 | " torch.as_tensor(input_ids).cuda(),\n", |
193 | 229 | " model,\n", |
194 | 230 | " tokenizer,\n", |
195 | | - " medusa_buffers,\n", |
196 | | - " medusa_topk,\n", |
| 231 | + " medusa_choices,\n", |
197 | 232 | " temperature,\n", |
198 | 233 | " posterior_threshold,\n", |
199 | 234 | " posterior_alpha,\n", |
200 | | - " past_key_values,\n", |
201 | | - " past_key_values_data, current_length_data\n", |
202 | 235 | " )\n", |
203 | 236 | " output_ids = output_ids[0][len(input_ids[0]) :]\n", |
204 | 237 | " print(\"Output length:\", output_ids.size(-1))\n", |
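`medusa_forward` accumulates per-stage timings in the `wall_times` dict it builds up top ('medusa', 'tree', 'posterior', 'update', 'init'). A small helper one might add to turn that into a per-stage breakdown (hypothetical, not part of the notebook):

```python
def summarize_wall_times(wall_times: dict) -> None:
    """Print total seconds and share of wall time per decode-loop stage."""
    total = sum(sum(v) for v in wall_times.values())
    for key, values in wall_times.items():
        stage = sum(values)
        share = 100.0 * stage / total if total else 0.0
        print(f"{key:>10}: {stage:.3f}s over {len(values)} calls ({share:4.1f}%)")
```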
|
299 | 332 | "name": "python", |
300 | 333 | "nbconvert_exporter": "python", |
301 | 334 | "pygments_lexer": "ipython3", |
302 | | - "version": "3.9.16" |
| 335 | + "version": "3.9.18" |
303 | 336 | }, |
304 | 337 | "orig_nbformat": 4 |
305 | 338 | }, |
|