|
| 1 | +import os |
| 2 | +import torch |
| 3 | +import time |
| 4 | +import torch.multiprocessing as mp |
| 5 | +import itertools |
| 6 | +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd |
| 7 | +from lightllm.models.deepseek2.triton_kernel.rotary_emb_config import DeepseekV3RotaryKernelConfig |
| 8 | +from lightllm.utils.watchdog_utils import Watchdog |
| 9 | +from typing import List |
| 10 | +from lightllm.utils.log_utils import init_logger |
| 11 | + |
| 12 | +logger = init_logger(__name__) |
| 13 | + |
| 14 | + |
| 15 | +def set_seed(): |
| 16 | + import torch |
| 17 | + import random |
| 18 | + import numpy as np |
| 19 | + |
| 20 | + seed = 42 |
| 21 | + torch.manual_seed(seed) |
| 22 | + random.seed(seed) |
| 23 | + np.random.seed(seed) |
| 24 | + if torch.cuda.is_available(): |
| 25 | + torch.cuda.manual_seed(seed) |
| 26 | + torch.cuda.manual_seed_all(seed) |
| 27 | + return |
| 28 | + |
| 29 | + |
| 30 | +@torch.no_grad() |
| 31 | +def test_kernel( |
| 32 | + M: int, |
| 33 | + Q_HEAD_NUM: int, |
| 34 | + K_HEAD_NUM: int, |
| 35 | + HEAD_DIM: int, |
| 36 | + dtype: torch.dtype, |
| 37 | + test_count: int, |
| 38 | + **config, |
| 39 | +): |
| 40 | + set_seed() |
| 41 | + input_tuples = [] |
| 42 | + |
| 43 | + q = torch.randn((M, Q_HEAD_NUM, HEAD_DIM), device="cuda", dtype=dtype) / 10 |
| 44 | + k = torch.randn((M, K_HEAD_NUM, HEAD_DIM), device="cuda", dtype=dtype) / 10 |
| 45 | + cos = torch.randn((M, HEAD_DIM // 2), device="cuda", dtype=dtype) |
| 46 | + sin = torch.randn((M, HEAD_DIM // 2), device="cuda", dtype=dtype) |
| 47 | + |
| 48 | + for _ in range(test_count): |
| 49 | + input_tuples.append((q.clone(), k.clone(), cos.clone(), sin.clone())) |
| 50 | + |
| 51 | + # warm_up |
| 52 | + rotary_emb_fwd(q=q, k=k, cos=cos, sin=sin, **config) |
| 53 | + |
| 54 | + graph = torch.cuda.CUDAGraph() |
| 55 | + |
| 56 | + with torch.cuda.graph(graph): |
| 57 | + for index in range(test_count): |
| 58 | + q, k, cos, sin = input_tuples[index] |
| 59 | + rotary_emb_fwd(q=q, k=k, cos=cos, sin=sin, **config) |
| 60 | + |
| 61 | + graph.replay() |
| 62 | + |
| 63 | + torch.cuda.synchronize() |
| 64 | + start_event = torch.cuda.Event(enable_timing=True) |
| 65 | + end_event = torch.cuda.Event(enable_timing=True) |
| 66 | + start_event.record() |
| 67 | + graph.replay() |
| 68 | + end_event.record() |
| 69 | + end_event.synchronize() |
| 70 | + |
| 71 | + cost_time = start_event.elapsed_time(end_event) |
| 72 | + |
| 73 | + logger.info(str(config)) |
| 74 | + logger.info(f"bf16 {M} cost time: {cost_time} ms") |
| 75 | + return cost_time |
| 76 | + |
| 77 | + |
| 78 | +def worker( |
| 79 | + M: int, |
| 80 | + Q_HEAD_NUM: int, |
| 81 | + K_HEAD_NUM: int, |
| 82 | + HEAD_DIM: int, |
| 83 | + dtype: torch.dtype, |
| 84 | + test_count: int, |
| 85 | + test_configs, |
| 86 | + queue, |
| 87 | +): |
| 88 | + dog = Watchdog(timeout=10) |
| 89 | + dog.start() |
| 90 | + try: |
| 91 | + for index in range(len(test_configs)): |
| 92 | + cost_time = test_kernel( |
| 93 | + M=M, |
| 94 | + Q_HEAD_NUM=Q_HEAD_NUM, |
| 95 | + K_HEAD_NUM=K_HEAD_NUM, |
| 96 | + HEAD_DIM=HEAD_DIM, |
| 97 | + dtype=dtype, |
| 98 | + test_count=test_count, |
| 99 | + **test_configs[index], |
| 100 | + ) |
| 101 | + dog.heartbeat() |
| 102 | + queue.put(cost_time) # Put result in queue |
| 103 | + |
| 104 | + except Exception as ex: |
| 105 | + logger.error(str(ex)) |
| 106 | + logger.exception(str(ex)) |
| 107 | + import sys |
| 108 | + |
| 109 | + sys.exit(-1) |
| 110 | + pass |
| 111 | + |
| 112 | + |
| 113 | +def get_test_configs(split_id, split_count): |
| 114 | + index = 0 |
| 115 | + result = itertools.product([1, 2, 4, 8, 16, 32], [1, 2, 4, 8], [1, 2, 3, 4, 5]) |
| 116 | + for BLOCK_SEQ, num_warps, num_stages in result: |
| 117 | + t_config = { |
| 118 | + "BLOCK_SEQ": BLOCK_SEQ, |
| 119 | + "num_warps": num_warps, |
| 120 | + "num_stages": num_stages, |
| 121 | + } |
| 122 | + if index % split_count == split_id: |
| 123 | + yield t_config |
| 124 | + index += 1 |
| 125 | + else: |
| 126 | + index += 1 |
| 127 | + |
| 128 | + |
| 129 | +def tuning_configs( |
| 130 | + device_id: int, # use for mult mp tunning |
| 131 | + device_count: int, |
| 132 | + M: int, |
| 133 | + Q_HEAD_NUM: int, |
| 134 | + K_HEAD_NUM: int, |
| 135 | + HEAD_DIM: int, |
| 136 | + dtype: torch.dtype, |
| 137 | + test_count: int, |
| 138 | +): |
| 139 | + os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) |
| 140 | + best_config, best_cost_time = None, 10000000 |
| 141 | + queue = mp.Queue() |
| 142 | + test_configs = [] |
| 143 | + for t_config in get_test_configs(device_id, device_count): |
| 144 | + test_configs.append(t_config) |
| 145 | + if len(test_configs) < 256: |
| 146 | + continue |
| 147 | + |
| 148 | + p = mp.Process( |
| 149 | + target=worker, |
| 150 | + args=( |
| 151 | + M, |
| 152 | + Q_HEAD_NUM, |
| 153 | + K_HEAD_NUM, |
| 154 | + HEAD_DIM, |
| 155 | + dtype, |
| 156 | + test_count, |
| 157 | + test_configs, |
| 158 | + queue, |
| 159 | + ), |
| 160 | + ) |
| 161 | + p.start() |
| 162 | + p.join() |
| 163 | + while len(test_configs) != 0: |
| 164 | + try: |
| 165 | + cost_time = queue.get_nowait() |
| 166 | + logger.info(f"get {test_configs[0]} cost_time: {cost_time}") |
| 167 | + if cost_time < best_cost_time: |
| 168 | + best_config = test_configs[0] |
| 169 | + best_cost_time = cost_time |
| 170 | + logger.info(f"cur best : {best_config} {best_cost_time}") |
| 171 | + del test_configs[0:1] |
| 172 | + except: |
| 173 | + del test_configs[0:16] |
| 174 | + logger.info(f"cur best : {best_config} {best_cost_time}") |
| 175 | + break |
| 176 | + |
| 177 | + while len(test_configs) != 0: |
| 178 | + p = mp.Process( |
| 179 | + target=worker, |
| 180 | + args=( |
| 181 | + M, |
| 182 | + Q_HEAD_NUM, |
| 183 | + K_HEAD_NUM, |
| 184 | + HEAD_DIM, |
| 185 | + dtype, |
| 186 | + test_count, |
| 187 | + test_configs, |
| 188 | + queue, |
| 189 | + ), |
| 190 | + ) |
| 191 | + p.start() |
| 192 | + p.join() |
| 193 | + |
| 194 | + while len(test_configs) != 0: |
| 195 | + try: |
| 196 | + cost_time = queue.get_nowait() |
| 197 | + logger.info(f"get {test_configs[0]} cost_time: {cost_time}") |
| 198 | + if cost_time < best_cost_time: |
| 199 | + best_config = test_configs[0] |
| 200 | + best_cost_time = cost_time |
| 201 | + logger.info(f"cur best : {best_config} {best_cost_time}") |
| 202 | + del test_configs[0:1] |
| 203 | + except: |
| 204 | + del test_configs[0:16] |
| 205 | + logger.info(f"cur best : {best_config} {best_cost_time}") |
| 206 | + break |
| 207 | + |
| 208 | + logger.info(f"M {M} {best_config} best cost: {best_cost_time}") |
| 209 | + return best_config, best_cost_time |
| 210 | + |
| 211 | + |
| 212 | +if __name__ == "__main__": |
| 213 | + torch.multiprocessing.set_start_method("spawn") |
| 214 | + from lightllm.utils.tuning_utils import mp_tuning |
| 215 | + |
| 216 | + # for deepseekv3 600B |
| 217 | + q_head_num = 128 |
| 218 | + k_head_num = 1 |
| 219 | + head_dim = 64 |
| 220 | + dtype = torch.bfloat16 |
| 221 | + for m in [1, 128, 256, 1024, 2048, 4096, 8192]: |
| 222 | + json_dict = {} |
| 223 | + ans = mp_tuning( |
| 224 | + tuning_configs, |
| 225 | + { |
| 226 | + "M": m, |
| 227 | + "Q_HEAD_NUM": q_head_num, |
| 228 | + "K_HEAD_NUM": k_head_num, |
| 229 | + "HEAD_DIM": head_dim, |
| 230 | + "dtype": dtype, |
| 231 | + "test_count": 20, |
| 232 | + }, |
| 233 | + ) |
| 234 | + json_dict[m] = ans |
| 235 | + DeepseekV3RotaryKernelConfig.save_config( |
| 236 | + M=m, |
| 237 | + Q_HEAD_NUM=q_head_num, |
| 238 | + K_HEAD_NUM=k_head_num, |
| 239 | + HEAD_DIM=head_dim, |
| 240 | + dtype=str(dtype), |
| 241 | + config_json=json_dict, |
| 242 | + ) |
0 commit comments