20 | 20 | logger = init_logger(__name__)
21 | 21 |
22 | 22 |
23 | | -class FlashAttentionStateExtraInfo: |
24 | | - def __init__(self, model): |
25 | | - num_heads = model.config["num_attention_heads"] |
26 | | - self.tp_q_head_num = num_heads // get_dp_world_size() |
27 | | - self.qk_nope_head_dim = model.qk_nope_head_dim |
28 | | - self.qk_rope_head_dim = model.qk_rope_head_dim |
29 | | - self.kv_lora_rank = model.kv_lora_rank |
30 | | - self.q_data_type = model.data_type |
31 | | - self.kv_data_type = model.data_type |
32 | | - self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(get_current_device_id()) |
33 | | - self.max_seq_length = model.max_seq_length |
34 | | - self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) |
35 | | - self.kv_indices_buffer = [ |
36 | | - torch.empty(model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32).to( |
37 | | - get_current_device_id() |
38 | | - ), |
39 | | - torch.empty(model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32).to( |
40 | | - get_current_device_id() |
41 | | - ), |
42 | | - ] |
43 | | - if model.config["rope_scaling"] is not None: |
44 | | - rope_scaling = model.config["rope_scaling"] |
45 | | - mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) |
46 | | - scaling_factor = rope_scaling["factor"] |
47 | | - if mscale_all_dim: |
48 | | - mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) |
49 | | - self.softmax_scale = self.softmax_scale * mscale * mscale |
50 | | - |
51 | | - |
52 | 23 | class FlashInferStateExtraInfo: |
53 | 24 | def __init__(self, model): |
54 | 25 | num_heads = model.config["num_attention_heads"] |
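The class removed above folds DeepSeek's yarn mscale correction into the attention softmax scale. Below is a minimal standalone sketch of that computation, assuming the standard DeepSeek-V2 definition of `get_deepseek_mscale` (the helper is imported elsewhere in the repo, not shown in this hunk) and illustrative head dimensions; the concrete numbers are hypothetical, not taken from this diff:

```python
import math


def get_deepseek_mscale(scale: float, mscale_all_dim: float) -> float:
    # Yarn-style mscale as defined in DeepSeek-V2 (assumed here; the real
    # helper lives outside this hunk). Identity when no rope scaling is used.
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale_all_dim * math.log(scale) + 1.0


# Mirror of the removed softmax_scale logic, with illustrative
# DeepSeek-V2 head dims (qk_nope_head_dim=128, qk_rope_head_dim=64).
qk_nope_head_dim, qk_rope_head_dim = 128, 64
softmax_scale = (qk_nope_head_dim + qk_rope_head_dim) ** (-0.5)

rope_scaling = {"factor": 40, "mscale_all_dim": 1.0}  # hypothetical config
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0)
if mscale_all_dim:
    mscale = get_deepseek_mscale(rope_scaling["factor"], mscale_all_dim)
    # mscale enters squared, presumably because the correction applies to
    # both the query and key sides of the QK^T product.
    softmax_scale = softmax_scale * mscale * mscale

print(softmax_scale)
```

Note that the base scale uses the full per-head query/key width (`qk_nope_head_dim + qk_rope_head_dim`) rather than a single head dim, matching the MLA layout where the rope and nope components are concatenated before attention.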