|
22 | 22 | ) |
23 | 23 | from torchvision.transforms.v2 import functional as F |
24 | 24 |
|
| 25 | +from lightllm.utils.log_utils import init_logger |
| 26 | + |
| 27 | +logger = init_logger(__name__) |
| 28 | + |
| 29 | + |
25 | 30 | IMAGE_FACTOR = 28 |
26 | 31 | MIN_PIXELS = 4 * 28 * 28 |
27 | 32 | MAX_PIXELS = 16384 * 28 * 28 |
@@ -110,6 +115,7 @@ def __init__( |
110 | 115 | self.interpolation = interpolation |
111 | 116 | self.data_format = ChannelDimension.FIRST |
112 | 117 | self._fused_cache = {} # key: (do_norm, do_rescale, rescale_factor, device) |
| 118 | + self.free_gpu_mem = 0 |
113 | 119 |
|
114 | 120 | def _get_fused_mean_std( |
115 | 121 | self, |
@@ -162,9 +168,19 @@ def rescale_and_normalize( |
162 | 168 |
|
163 | 169 | def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: |
164 | 170 | image_arr = np.asarray(image, dtype=np.uint8) |
165 | | - # TODO check cuda tensor oom reason |
166 | | - # image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to("cuda", non_blocking=True) |
167 | | - image_data = torch.from_numpy(image_arr).permute(2, 0, 1) |
| 171 | + image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous() |
| 172 | + |
| 173 | + self.free_gpu_mem, _ = torch.cuda.mem_get_info() |
| 174 | + image_size = int(image_arr.size) * 4 # x4 because the data needs to be converted to float32 later |
| 175 | + if image_size < self.free_gpu_mem * 0.9: |
| 176 | + image_data = image_data.to("cuda", non_blocking=True) |
| 177 | + else: |
| 178 | + logger.warning( |
| 179 | + f"[Qwen2VLImageProcessor] preprocess fallback to CPU:" |
| 180 | + f"shape = {tuple(image_arr.shape), }" |
| 181 | + f"image_size = {image_size/(1024**2):.2f} MB," |
| 182 | + f"free_gpu_mem = {self.free_gpu_mem/(1024**2):.2f} MB" |
| 183 | + ) |
168 | 184 |
|
169 | 185 | grouped_images, grouped_images_index = group_images_by_shape( |
170 | 186 | [image_data], disable_grouping=self.disable_grouping |
|
0 commit comments