
Commit bf862e2 ("fix")

1 parent 9a5c8de

File tree: 2 files changed (+4, -10 lines)

lightllm/models/qwen2_vl/infer_struct.py

Lines changed: 3 additions & 5 deletions

```diff
@@ -11,11 +11,9 @@ def __init__(self):
         self.position_sin = None
 
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        if "rope_type" in model.config["rope_scaling"]:
-            rope_scaling_type = model.config["rope_scaling"]["rope_type"]
-        elif "type" in model.config["rope_scaling"]:
-            rope_scaling_type = model.config["rope_scaling"]["type"]
-        if rope_scaling_type == "default":
+        rope_scaling = model.config.get("rope_scaling", {})
+        self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+        if self.rope_type != "mrope":
             super().init_some_extra_state(model, input_ids)
             return
         InferStateInfo.init_some_extra_state(self, model, input_ids)
```
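For context, a minimal standalone sketch (not part of the commit) of how the new lookup resolves the RoPE type; `resolve_rope_type` is a hypothetical helper name, and the config layouts below assume HF-style `rope_scaling` dicts. Both the newer `"rope_type"` key and the legacy `"type"` key are handled, and a missing entry falls through to `None`, which the `!= "mrope"` check treats as the non-mrope path:

```python
# Minimal sketch, not part of the commit: resolve_rope_type is a hypothetical
# helper mirroring the new lookup, assuming HF-style rope_scaling config dicts.

def resolve_rope_type(config: dict):
    rope_scaling = config.get("rope_scaling", {})
    # Prefer the newer "rope_type" key, fall back to the legacy "type" key.
    return rope_scaling.get("rope_type", rope_scaling.get("type", None))

print(resolve_rope_type({"rope_scaling": {"rope_type": "mrope", "mrope_section": [16, 24, 24]}}))  # mrope
print(resolve_rope_type({"rope_scaling": {"type": "default"}}))  # default
print(resolve_rope_type({}))  # None -> "!= mrope", so the base-class path is taken
```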

lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 5 deletions

```diff
@@ -14,17 +14,13 @@ class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer):
     def __init__(self, layer_num, network_config, mode=[]):
         super().__init__(layer_num, network_config, mode)
         self.mrope_section = network_config["rope_scaling"]["mrope_section"]
-        if "rope_type" in network_config["rope_scaling"]:
-            self.rope_scaling_type = network_config["rope_scaling"]["rope_type"]
-        elif "type" in network_config["rope_scaling"]:
-            self.rope_scaling_type = network_config["rope_scaling"]["type"]
         axis_map = []
         for i, n in enumerate(self.mrope_section * 2):
             axis_map += [i % 3] * n
         self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")
 
     def _get_qkv(self, input, infer_state, layer_weight):
-        if self.rope_scaling_type == "default":
+        if infer_state.rope_type != "mrope":
             return super()._get_qkv(input, infer_state, layer_weight)
         q = layer_weight.q_proj.mm(input)
         cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
```
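As a side note, the `axis_map` kept in `__init__` can be reproduced standalone. The sketch below assumes `mrope_section = [16, 24, 24]`, the usual Qwen2-VL value (not stated in this diff); repeating the section list (`* 2`) covers both rotary halves, and `i % 3` tags each head channel with the mrope axis it belongs to:

```python
import torch

# Standalone sketch of the axis_map construction in __init__ above.
# Assumption: mrope_section = [16, 24, 24], the usual Qwen2-VL value
# (16 + 24 + 24 = 64, half of a 128-dim head). Repeating the list (* 2)
# covers both rotary halves; i % 3 tags each channel with its axis:
# 0 = temporal, 1 = height, 2 = width.
mrope_section = [16, 24, 24]
axis_map = []
for i, n in enumerate(mrope_section * 2):
    axis_map += [i % 3] * n

axis_map = torch.tensor(axis_map, dtype=torch.int32)  # device="cuda" in the real layer
print(axis_map.shape)  # torch.Size([128]), one entry per head channel
print(axis_map[:20])   # 16 zeros (temporal channels), then axis 1 begins
```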
