Skip to content

Commit 9bf9d5d

Browse files
committed
Merge branch 'qwen1-vl-mrope-fix' of https://github.com/ModelTC/lightllm into qwen2-vl-mrope-fix
2 parents c803ea9 + ee4710c commit 9bf9d5d

File tree

5 files changed

+19
-59
lines changed

5 files changed

+19
-59
lines changed

lightllm/models/llama/model.py

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def _init_custom(self):
118118
scaling_type = rope_scaling["type"]
119119
else:
120120
raise ValueError(f"Unknown RoPE scaling format {rope_scaling}")
121-
if scaling_type == "default":
121+
if scaling_type == "default" or "mrope_section" in rope_scaling:
122122
self._init_to_get_rotary()
123123
elif scaling_type == "yarn":
124124
self._init_to_get_yarn_rotary()
@@ -129,7 +129,7 @@ def _init_custom(self):
129129
elif scaling_type == "llama3":
130130
self._init_to_get_llama3_rotary()
131131
elif scaling_type == "mrope":
132-
self._init_to_get_mrope_rotary()
132+
self._init_to_get_rotary()
133133
else:
134134
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
135135
return
@@ -373,47 +373,3 @@ def _init_to_get_llama3_rotary(self, default_base=10000):
373373
self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
374374
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
375375
return
376-
377-
def _init_to_get_mrope_rotary(self, default_base=10000):
378-
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
379-
if self.config.get("rope_scaling", {}) is None:
380-
rope_scaling_factor = 1.0
381-
else:
382-
rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
383-
384-
base = self.config.get("rope_theta", float(default_base))
385-
386-
if "max_sequence_length" in self.config:
387-
max_seq_len = self.config["max_sequence_length"]
388-
else:
389-
max_position_embeddings = self.config.get(
390-
"max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
391-
)
392-
max_seq_len = max_position_embeddings * rope_scaling_factor
393-
394-
# NTK
395-
try:
396-
ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
397-
assert ntk_alpha >= 1
398-
if ntk_alpha > 1:
399-
logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
400-
max_seq_len *= ntk_alpha
401-
base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
402-
except:
403-
pass
404-
405-
self.inv_freq = 1.0 / (
406-
base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
407-
)
408-
409-
t = (
410-
torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
411-
/ rope_scaling_factor
412-
)
413-
freqs = torch.outer(t, self.inv_freq) # (T, D/2)
414-
freqs = torch.cat((freqs, freqs), dim=-1) # (T, D)
415-
416-
self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
417-
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
418-
419-
return

lightllm/models/qwen2_vl/infer_struct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
3131
b_position_delta[batch_idx] = position_delta
3232
position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device)
3333
position_ids = position_ids.unsqueeze(0).expand(3, -1)
34-
self.position_cos = model._cos_cached[position_ids.unsqueeze(1)] # (3, 1, L, D)
35-
self.position_sin = model._sin_cached[position_ids.unsqueeze(1)] # (3, 1, L, D)
34+
self.position_cos = model._cos_cached[position_ids] # (3, L, D)
35+
self.position_sin = model._sin_cached[position_ids] # (3, L, D)
3636
if get_env_start_args().enable_fa3:
3737
self.max_seq_len = self.max_kv_seq_len
3838
self.q_max_seq_len = self.max_q_seq_len

lightllm/models/qwen2_vl/triton_kernel/mrope.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def mrope_kernel(
2121
HALF: tl.constexpr,
2222
s_tok: tl.int32,
2323
s_ax: tl.int32,
24+
s_d: tl.int32,
2425
q_sb: tl.int32,
2526
q_sh: tl.int32,
2627
q_sl: tl.int32,
@@ -77,7 +78,8 @@ def mrope_kernel(
7778
rot_vals = tl.where(offs < HALF, -rot_vals, rot_vals)
7879

7980
axis_id = tl.load(AXIS_ptr + offs, mask=mask, other=0) # 0,1,2
80-
cos_idx = pid_l * s_tok + axis_id * s_ax + offs
81+
idx_d = tl.where(offs < HALF, offs, offs - HALF)
82+
cos_idx = pid_l * s_tok + axis_id * s_ax + idx_d * s_d
8183
c = tl.load(COS_ptr + cos_idx, mask=mask, other=0.0)
8284
s = tl.load(SIN_ptr + cos_idx, mask=mask, other=0.0)
8385

@@ -101,12 +103,11 @@ def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch
101103
qo_sb, qo_sh, qo_sl, qo_sd = map(int, q_out.stride())
102104
ko_sb, ko_sh, ko_sl, ko_sd = map(int, k_out.stride())
103105

104-
assert len(cos.shape) == 4
105-
token_dim = 2
106-
axis_dim = 0
106+
assert len(cos.shape) == 3
107107

108-
s_token = int(cos.stride(token_dim))
109-
s_axis = int(cos.stride(axis_dim))
108+
s_axis = int(cos.stride(0))
109+
s_token = int(cos.stride(1))
110+
s_d = int(cos.stride(2))
110111

111112
grid = (B * (H_q + H_k), L)
112113

@@ -126,6 +127,7 @@ def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch
126127
HALF,
127128
s_token,
128129
s_axis,
130+
s_d,
129131
q_sb,
130132
q_sh,
131133
q_sl,

lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@ def rotary_kernel(
3434
partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D)
3535

3636
for pid_l in tl.range(pid_l_start, total_len, step=tl.num_programs(axis=1)):
37-
cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
38-
sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
37+
idx_d = tl.where(d < HALF_D, d, d - HALF_D)
38+
cos_ptr_ = cos_ptr + pid_l * stride_cos_l + idx_d * stride_cos_d
39+
sin_ptr_ = sin_ptr + pid_l * stride_sin_l + idx_d * stride_sin_d
3940
cos = tl.load(cos_ptr_, mask=mask)
4041
sin = tl.load(sin_ptr_, mask=mask)
4142

@@ -52,7 +53,7 @@ def rotary_kernel(
5253

5354
y = x * cos + rotated * sin
5455

55-
out_ptr_ = out_ptr + base + d
56+
out_ptr_ = out_ptr + base + d * stride_d
5657
tl.store(out_ptr_, y, mask=mask)
5758

5859

@@ -66,8 +67,8 @@ def apply_rotary_pos_emb_triton(
6667
orig_dtype = tensor.dtype
6768
x = tensor.float()
6869

69-
cos = cos.repeat(1, 2).view(cos.size(0), -1).contiguous().float()
70-
sin = sin.repeat(1, 2).view(sin.size(0), -1).contiguous().float()
70+
cos = cos.contiguous().float()
71+
sin = sin.contiguous().float()
7172

7273
L, H, D = x.shape
7374
HALF_D = D // 2

lightllm/server/httpserver/manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ async def generate(
270270
start_time = time.time()
271271
request_headers = request.headers if request is not None else {}
272272
group_request_id = self.alloc_req_id(sampling_params, is_health_req)
273+
273274
try:
274275
original_multimodal_params = None
275276
if self.is_multinode_tp_master:

0 commit comments

Comments
 (0)