Commit 1cccbb9

Author: sangchengmeng
Commit message: opti qwen2-vl vision_process
Parent commit: 70c6b31

1 file changed: lightllm/models/qwen2_vl/vision_process.py (+24 additions, -9 deletions)
@@ -162,7 +162,8 @@ def rescale_and_normalize(
 
     def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
         image_arr = np.asarray(image, dtype=np.uint8)
-        image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous()
+        image_data = torch.from_numpy(image_arr).permute(2, 0, 1)
+
         grouped_images, grouped_images_index = group_images_by_shape(
             [image_data], disable_grouping=self.disable_grouping
         )
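
Note on this hunk: torch's permute() only rewrites tensor strides, so dropping the .contiguous() call avoids copying the whole image on the CPU; a contiguous layout is materialized later, only where it is actually needed. A minimal standalone sketch of that behavior (not part of the commit):

    import numpy as np
    import torch

    # permute() returns a zero-copy strided view of the HWC image buffer.
    arr = np.zeros((224, 224, 3), dtype=np.uint8)
    t = torch.from_numpy(arr).permute(2, 0, 1)  # CHW view

    assert not t.is_contiguous()                              # no copy was made
    assert t.data_ptr() == torch.from_numpy(arr).data_ptr()   # still the numpy buffer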
@@ -183,27 +184,39 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
                 interpolation=self.interpolation,
             )
             resized_images_grouped[shape] = stacked_images
+
+        grouped_images = None
         resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+        resized_images_grouped = None
 
-        # Group images by size for further processing
-        # Needed in case do_resize is False, or resize returns images with different sizes
         grouped_images, grouped_images_index = group_images_by_shape(
             resized_images, disable_grouping=self.disable_grouping
         )
+        resized_images = None
+
         processed_images_grouped = {}
         processed_grids = {}
+
         for shape, stacked_images in grouped_images.items():
+            stacked_images = stacked_images.to("cuda", non_blocking=True)
+
             resized_height, resized_width = stacked_images.shape[-2:]
-            # Fused rescale and normalize
+
             patches = self.rescale_and_normalize(
-                stacked_images, self.do_rescale, self.rescale_factor, self.do_normalize, self.image_mean, self.image_std
+                stacked_images,
+                self.do_rescale,
+                self.rescale_factor,
+                self.do_normalize,
+                self.image_mean,
+                self.image_std,
             )
             if patches.ndim == 4:
-                # add a temporal dimension if we have images
                 patches = patches.unsqueeze(1)
+
             if patches.shape[1] % self.temporal_patch_size != 0:
                 repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1)
                 patches = torch.cat([patches, repeats], dim=1)
+
             batch_size, grid_t, channel = patches.shape[:3]
             grid_t = grid_t // self.temporal_patch_size
             grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
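
Two notes on this hunk. First, assigning grouped_images, resized_images_grouped, and resized_images to None drops the last Python reference to each intermediate batch, so its memory can be reclaimed before the next stage allocates, lowering peak usage. Second, the .to("cuda", non_blocking=True) call moves each stacked batch to the GPU so rescale_and_normalize runs there; non_blocking=True only returns immediately when the source tensor is in pinned host memory, otherwise the copy is effectively synchronous. A small standalone sketch (shapes assumed, not taken from the commit):

    import torch

    if torch.cuda.is_available():
        # Pinned (page-locked) host memory is what lets non_blocking=True
        # return immediately; from ordinary pageable memory the copy blocks anyway.
        cpu_batch = torch.empty(8, 3, 224, 224).pin_memory()
        gpu_batch = cpu_batch.to("cuda", non_blocking=True)  # async H2D copy
        torch.cuda.synchronize()  # ensure the copy finished before reusing cpu_batch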
@@ -224,8 +237,7 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
                 .permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
                 .contiguous()
             )
-            # Reorder dimensions to group grid and patch information for subsequent flattening.
-            # (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w)
+
            flatten_patches = patches.view(
                 batch_size,
                 grid_t * grid_h * grid_w,
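
The two deleted comments documented the dimension order produced by the preceding permute before flattening. The resulting shapes can still be checked by hand; an illustrative standalone example with assumed Qwen2-VL-style configuration values (not read from this file):

    # Illustrative shape arithmetic (configuration values assumed):
    patch_size, temporal_patch_size, channel = 14, 2, 3
    resized_height = resized_width = 224

    grid_h = resized_height // patch_size   # 16
    grid_w = resized_width // patch_size    # 16
    flat_dim = channel * temporal_patch_size * patch_size * patch_size  # 1176
    # flatten_patches per image: (grid_t * grid_h * grid_w, flat_dim) = (256, 1176)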
@@ -235,9 +247,12 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
             processed_images_grouped[shape] = flatten_patches
             processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
 
+        grouped_images = None
+
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
         processed_grids = reorder_images(processed_grids, grouped_images_index)
-        pixel_values = torch.cat(processed_images, dim=0)  # (num_patches_total, C*T*ps*ps)
+
+        pixel_values = torch.cat(processed_images, dim=0)
         image_grid_thw = torch.as_tensor(processed_grids)
 
         return pixel_values, image_grid_thw
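
For context, a hypothetical end-to-end call of the reworked preprocess(); the processor class name and its construction are assumptions for illustration, and only the return contract comes from the diff:

    from PIL import Image

    processor = VisionProcessor()  # hypothetical name; see vision_process.py
    img = Image.open("demo.jpg").convert("RGB")

    pixel_values, image_grid_thw = processor.preprocess(img)
    # pixel_values: all flattened patches concatenated along dim 0
    # image_grid_thw: one [grid_t, grid_h, grid_w] row per image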
