|
| 1 | +import torch |
| 2 | +import time |
| 3 | +import os |
| 4 | +import torch.multiprocessing as mp |
| 5 | +from typing import List |
| 6 | +from lightllm.utils.log_utils import init_logger |
| 7 | +from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding |
| 8 | +from lightllm.utils.watchdog_utils import Watchdog |
| 9 | + |
| 10 | +logger = init_logger(__name__) |
| 11 | + |
| 12 | + |
| 13 | +def set_seed(): |
| 14 | + import torch |
| 15 | + import random |
| 16 | + import numpy as np |
| 17 | + |
| 18 | + seed = 42 |
| 19 | + torch.manual_seed(seed) |
| 20 | + random.seed(seed) |
| 21 | + np.random.seed(seed) |
| 22 | + if torch.cuda.is_available(): |
| 23 | + torch.cuda.manual_seed(seed) |
| 24 | + torch.cuda.manual_seed_all(seed) |
| 25 | + return |
| 26 | + |
| 27 | + |
| 28 | +@torch.no_grad() |
| 29 | +def test_decode_attentions( |
| 30 | + q_nope_shape: List[int], |
| 31 | + q_rope_shape: List[int], |
| 32 | + kv_nope_shape: List[int], |
| 33 | + kv_rope_shape: List[int], |
| 34 | + test_seq_len: int, |
| 35 | + dtype: torch.dtype, |
| 36 | + test_count: int = 20, |
| 37 | + **run_config, |
| 38 | +): |
| 39 | + set_seed() |
| 40 | + tmp_class = type("TestObj", (object,), {}) |
| 41 | + infer_state = tmp_class() |
| 42 | + infer_state.batch_size = q_nope_shape[0] |
| 43 | + infer_state.max_len_in_batch = test_seq_len |
| 44 | + infer_state.req_manager = tmp_class() |
| 45 | + infer_state.req_manager.req_to_token_indexs = torch.zeros( |
| 46 | + (infer_state.batch_size, infer_state.max_len_in_batch), dtype=torch.int32, device="cuda" |
| 47 | + ) |
| 48 | + infer_state.req_manager.req_to_token_indexs.view(-1)[:] = torch.arange( |
| 49 | + 0, infer_state.batch_size * infer_state.max_len_in_batch, step=1, dtype=torch.int32 |
| 50 | + ).cuda() |
| 51 | + infer_state.b_req_idx = torch.arange(0, infer_state.batch_size, step=1, dtype=torch.int32).cuda() |
| 52 | + infer_state.b_seq_len = torch.full((infer_state.batch_size,), fill_value=test_seq_len, dtype=torch.int32).cuda() |
| 53 | + infer_state.total_token_num_tensor = torch.sum(infer_state.b_seq_len) |
| 54 | + |
| 55 | + input_tuples = [] |
| 56 | + for _ in range(test_count): |
| 57 | + q_nope = torch.randn(q_nope_shape, device="cuda", dtype=dtype) / 10 |
| 58 | + q_rope = torch.randn(q_rope_shape, device="cuda", dtype=dtype) / 10 |
| 59 | + kv_buffer_shape = [ |
| 60 | + (test_seq_len + 10) * infer_state.batch_size, |
| 61 | + kv_nope_shape[1], |
| 62 | + kv_nope_shape[2] + kv_rope_shape[2], |
| 63 | + ] |
| 64 | + kv_buffer = torch.randn(kv_buffer_shape, device="cuda", dtype=dtype) / 10 |
| 65 | + |
| 66 | + kv_nope = kv_buffer[:, :, 0 : kv_nope_shape[2]] |
| 67 | + kv_rope = kv_buffer[:, :, kv_nope_shape[2] :] |
| 68 | + o_tensor = torch.empty_like(q_nope) |
| 69 | + input_tuples.append((q_nope, q_rope, kv_buffer, kv_nope, kv_rope, o_tensor)) |
| 70 | + |
| 71 | + tensor_dict = {} |
| 72 | + |
| 73 | + def inner_alloc_func(shape, dtype=torch.float32, device="cuda"): |
| 74 | + shape = tuple(shape) |
| 75 | + if shape not in tensor_dict: |
| 76 | + ans = torch.empty(shape, dtype=dtype, device=device) |
| 77 | + tensor_dict[shape] = ans |
| 78 | + return ans |
| 79 | + else: |
| 80 | + return tensor_dict[shape] |
| 81 | + |
| 82 | + gqa_token_decode_attention_flash_decoding( |
| 83 | + q_nope, |
| 84 | + q_rope, |
| 85 | + kv_nope, |
| 86 | + kv_rope, |
| 87 | + infer_state, |
| 88 | + q_nope_shape[1], |
| 89 | + q_nope_shape[2], |
| 90 | + q_rope_shape[2], |
| 91 | + None, |
| 92 | + 0.01, |
| 93 | + out=o_tensor, |
| 94 | + alloc_tensor_func=inner_alloc_func, |
| 95 | + **run_config, |
| 96 | + ) |
| 97 | + |
| 98 | + graph = torch.cuda.CUDAGraph() |
| 99 | + with torch.cuda.graph(graph): |
| 100 | + for index in range(test_count): |
| 101 | + q_nope, q_rope, kv_buffer, kv_nope, kv_rope, o_tensor = input_tuples[index] |
| 102 | + gqa_token_decode_attention_flash_decoding( |
| 103 | + q_nope, |
| 104 | + q_rope, |
| 105 | + kv_nope, |
| 106 | + kv_rope, |
| 107 | + infer_state, |
| 108 | + q_nope_shape[1], |
| 109 | + q_nope_shape[2], |
| 110 | + q_rope_shape[2], |
| 111 | + None, |
| 112 | + 0.01, |
| 113 | + out=o_tensor, |
| 114 | + alloc_tensor_func=inner_alloc_func, |
| 115 | + **run_config, |
| 116 | + ) |
| 117 | + |
| 118 | + graph.replay() |
| 119 | + |
| 120 | + torch.cuda.synchronize() |
| 121 | + start = time.time() |
| 122 | + graph.replay() |
| 123 | + torch.cuda.synchronize() |
| 124 | + |
| 125 | + cost_time = (time.time() - start) * 1000 |
| 126 | + |
| 127 | + logger.info(f"bf16 {test_seq_len} cost time: {cost_time} ms") |
| 128 | + return cost_time |
| 129 | + |
| 130 | + |
| 131 | +def worker( |
| 132 | + q_nope_shape: List[int], |
| 133 | + q_rope_shape: List[int], |
| 134 | + kv_nope_shape: List[int], |
| 135 | + kv_rope_shape: List[int], |
| 136 | + test_seq_len: int, |
| 137 | + dtype: torch.dtype, |
| 138 | + test_count: int, |
| 139 | + test_configs, |
| 140 | + queue, |
| 141 | +): |
| 142 | + dog = Watchdog(timeout=10) |
| 143 | + dog.start() |
| 144 | + |
| 145 | + try: |
| 146 | + for index in range(len(test_configs)): |
| 147 | + tuning_config = test_configs[index] |
| 148 | + cost_time = test_decode_attentions( |
| 149 | + q_nope_shape=q_nope_shape, |
| 150 | + q_rope_shape=q_rope_shape, |
| 151 | + kv_nope_shape=kv_nope_shape, |
| 152 | + kv_rope_shape=kv_rope_shape, |
| 153 | + test_seq_len=test_seq_len, |
| 154 | + dtype=dtype, |
| 155 | + test_count=test_count, |
| 156 | + **tuning_config, |
| 157 | + ) |
| 158 | + dog.heartbeat() |
| 159 | + queue.put(cost_time) # Put result in queue |
| 160 | + except Exception as ex: |
| 161 | + logger.error(str(ex) + f"config {tuning_config}") |
| 162 | + import sys |
| 163 | + |
| 164 | + sys.exit(-1) |
| 165 | + pass |
| 166 | + |
| 167 | + |
| 168 | +def get_test_configs(split_id, split_count): |
| 169 | + index = 0 |
| 170 | + for block_n in [16, 32]: |
| 171 | + for block_q_head in [ |
| 172 | + 16, |
| 173 | + ]: |
| 174 | + for stage1_num_warps in [2, 4, 8, 16]: |
| 175 | + for stage1_num_stages in [ |
| 176 | + 1, |
| 177 | + 2, |
| 178 | + 3, |
| 179 | + 4, |
| 180 | + 5, |
| 181 | + 6, |
| 182 | + 7, |
| 183 | + 8, |
| 184 | + 12, |
| 185 | + 15, |
| 186 | + ]: |
| 187 | + for stage2_num_warps in [1, 2, 4]: |
| 188 | + for stage2_num_stages in [ |
| 189 | + 1, |
| 190 | + 3, |
| 191 | + ]: |
| 192 | + t_config = { |
| 193 | + "BLOCK_N": block_n, |
| 194 | + "BLOCK_Q_HEAD": block_q_head, |
| 195 | + "stage1_num_warps": stage1_num_warps, |
| 196 | + "stage1_num_stages": stage1_num_stages, |
| 197 | + "stage2_num_warps": stage2_num_warps, |
| 198 | + "stage2_num_stages": stage2_num_stages, |
| 199 | + } |
| 200 | + if index % split_count == split_id: |
| 201 | + yield t_config |
| 202 | + index += 1 |
| 203 | + else: |
| 204 | + index += 1 |
| 205 | + |
| 206 | + |
| 207 | +def tuning_configs( |
| 208 | + device_id: int, # use for mult mp tunning |
| 209 | + device_count: int, |
| 210 | + q_nope_shape: List[int], |
| 211 | + q_rope_shape: List[int], |
| 212 | + kv_nope_shape: List[int], |
| 213 | + kv_rope_shape: List[int], |
| 214 | + test_seq_len: int, |
| 215 | + dtype: torch.dtype, |
| 216 | + test_count: int, |
| 217 | +): |
| 218 | + os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) |
| 219 | + best_config, best_cost_time = None, 10000000 |
| 220 | + queue = mp.Queue() |
| 221 | + test_configs = [] |
| 222 | + for t_config in get_test_configs(device_id, device_count): |
| 223 | + test_configs.append(t_config) |
| 224 | + if len(test_configs) < 64: |
| 225 | + continue |
| 226 | + |
| 227 | + p = mp.Process( |
| 228 | + target=worker, |
| 229 | + args=( |
| 230 | + q_nope_shape, |
| 231 | + q_rope_shape, |
| 232 | + kv_nope_shape, |
| 233 | + kv_rope_shape, |
| 234 | + test_seq_len, |
| 235 | + dtype, |
| 236 | + test_count, |
| 237 | + test_configs, |
| 238 | + queue, |
| 239 | + ), |
| 240 | + ) |
| 241 | + p.start() |
| 242 | + p.join() |
| 243 | + |
| 244 | + while len(test_configs) != 0: |
| 245 | + try: |
| 246 | + cost_time = queue.get_nowait() |
| 247 | + logger.info(f"get {test_configs[0]} cost_time: {cost_time}") |
| 248 | + if cost_time < best_cost_time: |
| 249 | + best_config = test_configs[0] |
| 250 | + best_cost_time = cost_time |
| 251 | + logger.info(f"cur best {best_config}, {best_cost_time}") |
| 252 | + del test_configs[0:1] |
| 253 | + except: |
| 254 | + logger.info(f"cur best {best_config}, {best_cost_time}") |
| 255 | + del test_configs[0:1] |
| 256 | + break |
| 257 | + |
| 258 | + while len(test_configs) != 0: |
| 259 | + p = mp.Process( |
| 260 | + target=worker, |
| 261 | + args=( |
| 262 | + q_nope_shape, |
| 263 | + q_rope_shape, |
| 264 | + kv_nope_shape, |
| 265 | + kv_rope_shape, |
| 266 | + test_seq_len, |
| 267 | + dtype, |
| 268 | + test_count, |
| 269 | + test_configs, |
| 270 | + queue, |
| 271 | + ), |
| 272 | + ) |
| 273 | + p.start() |
| 274 | + p.join() |
| 275 | + |
| 276 | + while len(test_configs) != 0: |
| 277 | + try: |
| 278 | + cost_time = queue.get_nowait() |
| 279 | + logger.info(f"get {test_configs[0]} cost_time: {cost_time}") |
| 280 | + if cost_time < best_cost_time: |
| 281 | + best_config = test_configs[0] |
| 282 | + best_cost_time = cost_time |
| 283 | + logger.info(f"cur best {best_config}, {best_cost_time}") |
| 284 | + del test_configs[0:1] |
| 285 | + except: |
| 286 | + logger.info(f"cur best {best_config}, {best_cost_time}") |
| 287 | + del test_configs[0:1] |
| 288 | + break |
| 289 | + |
| 290 | + logger.info(f"{best_config} best cost: {best_cost_time}") |
| 291 | + return best_config, best_cost_time |
| 292 | + |
| 293 | + |
| 294 | +if __name__ == "__main__": |
| 295 | + torch.multiprocessing.set_start_method("spawn") |
| 296 | + |
| 297 | + from lightllm.utils.tuning_utils import mp_tuning |
| 298 | + from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig |
| 299 | + |
| 300 | + q_head_num = 16 |
| 301 | + q_head_dim = 512 |
| 302 | + q_rope_dim = 64 |
| 303 | + import collections |
| 304 | + |
| 305 | + store_json_ans = collections.defaultdict(dict) |
| 306 | + for batch_size in [1, 8, 16, 32, 64, 128, 256]: |
| 307 | + for seq_len in [256, 512, 1024, 2048, 4096, 8192]: |
| 308 | + if batch_size * seq_len > 128 * 1024 * 4: |
| 309 | + continue |
| 310 | + |
| 311 | + ans = mp_tuning( |
| 312 | + tuning_configs, |
| 313 | + { |
| 314 | + "q_nope_shape": [batch_size, q_head_num, q_head_dim], |
| 315 | + "q_rope_shape": [batch_size, q_head_num, q_rope_dim], |
| 316 | + "kv_nope_shape": [None, 1, q_head_dim], |
| 317 | + "kv_rope_shape": [None, 1, q_rope_dim], |
| 318 | + "test_seq_len": seq_len, |
| 319 | + "dtype": torch.bfloat16, |
| 320 | + "test_count": 40, |
| 321 | + }, |
| 322 | + ) |
| 323 | + store_json_ans[seq_len][batch_size] = ans |
| 324 | + |
| 325 | + MlaDecodeAttentionKernelConfig.save_config( |
| 326 | + q_head_num=q_head_num, |
| 327 | + q_head_dim=q_head_dim, |
| 328 | + q_rope_dim=q_rope_dim, |
| 329 | + out_dtype=str(torch.bfloat16), |
| 330 | + config_json=store_json_ans, |
| 331 | + ) |
| 332 | + |
| 333 | + pass |
0 commit comments