Skip to content

Commit 206b170

Browse files
SangChengCsangchengmeng
andauthored
check image tag and image num (#1176)
Co-authored-by: sangchengmeng <sangchengmeng@sensetime.com>
1 parent 2fbd2b8 commit 206b170

File tree

6 files changed

+21
-2
lines changed

6 files changed

+21
-2
lines changed

lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def flash_decode_stage1(
291291
assert k.stride() == v.stride()
292292
NUM_GROUPS = Lk // KV_QUANT_GROUP_SIZE
293293
assert triton.next_power_of_2(NUM_GROUPS) == NUM_GROUPS
294-
294+
295295
assert k.stride() == v.stride()
296296
_fwd_kernel_flash_decode_diverse_stage1[grid](
297297
Q=q,

lightllm/models/internvl/model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
149149
raise ValueError("image token error")
150150
except ValueError:
151151
break
152+
if multimodal_params:
153+
image_cnt = len(multimodal_params.images)
154+
if image_cnt != image_id:
155+
raise ValueError(image_cnt == image_id, f"invalid image tag num: {image_cnt} vs {image_id}!")
152156
input_ids.extend(origin_ids[start_idx:])
153157

154158
# audio
@@ -174,6 +178,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
174178
raise ValueError("audio token error")
175179
except ValueError:
176180
break
181+
if multimodal_params:
182+
audio_cnt = len(multimodal_params.audios)
183+
if audio_cnt != audio_id:
184+
raise ValueError(audio_cnt == audio_id, f"invalid audio tag num: {audio_cnt} vs {audio_id}!")
177185
input_ids.extend(origin_ids[start_idx:])
178186
return input_ids
179187

lightllm/models/qwen2_vl/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
7979
raise ValueError("image token error")
8080
except ValueError:
8181
break
82+
if multimodal_params:
83+
image_cnt = len(multimodal_params.images)
84+
if image_cnt != image_id:
85+
raise ValueError(image_cnt == image_id, f"invalid image tag num: {image_cnt} vs {image_id}!")
8286
input_ids.extend(origin_ids)
8387
return input_ids
8488

lightllm/models/qwen2_vl/vision_process.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
184184
return self._preprocess_bydevice(image, device="cpu")
185185

186186
def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torch.Tensor]:
187+
if image.mode != "RGB":
188+
image = image.convert("RGB")
187189
image_arr = np.asarray(image, dtype=np.uint8)
188190
image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True)
189191

lightllm/models/qwen_vl/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None):
8686
input_ids.extend(origin_ids[end:])
8787
if multimodal_params:
8888
image_cnt = len(multimodal_params.images)
89-
assert image_cnt == image_id, "invalid image tag num: {} vs {}!".format(image_cnt, image_id)
89+
if image_cnt != image_id:
90+
raise ValueError(image_cnt == image_id, f"invalid image tag num: {image_cnt} vs {image_id}!")
9091
return input_ids
9192

9293

lightllm/models/tarsier2/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
7878
raise ValueError("image token error")
7979
except ValueError:
8080
break
81+
if multimodal_params:
82+
image_cnt = len(multimodal_params.images)
83+
if image_cnt != image_id:
84+
raise ValueError(image_cnt == image_id, f"invalid image tag num: {image_cnt} vs {image_id}!")
8185
input_ids.extend(origin_ids[start_idx:])
8286
return input_ids
8387

0 commit comments

Comments
 (0)