
[BUG] table_parsing.py: CUDA OOM when parsing a directory with many table images #237

@RubiaPath

Description


Summary

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 28.49 GiB. GPU 0 has a total capacity of 47.54 GiB of which 6.76 GiB is free. Including non-PyTorch memory, this process has 40.77 GiB memory in use. Of the allocated memory 40.37 GiB is allocated by PyTorch, and 85.85 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

When scripts/table_parsing.py processes a directory containing roughly 120 table-crop PNG images, a CUDA out of memory error occurs during inference.

Observed behavior:

  • Single-image inference works fine (GPU memory barely changes)
  • After model initialization, GPU memory usage is only ~1992 MB (not the source of the problem)
  • Directory input (~120 images) hits OOM during inference

Environment

  • GPU: NVIDIA A6000 (48 GB)
  • max_new_tokens: 256
  • flash_attn: False
  • Model: table_parsing_struct_eqtable (InternVL2-1B via struct_eqtable)
  • Input: ~120 table-crop PNGs (cropped from bboxes output by scripts/layout_detection.py)

Steps to Reproduce

python scripts/table_parsing.py --config configs/table_parsing.yaml
# set inputs to a directory path (containing 120 table PNGs)
# error: torch.OutOfMemoryError: CUDA out of memory

Analysis

BaseTask.load_images() returns all image paths at once

File: pdf_extract_kit/tasks/base_task.py:11-32

When input_data is a directory, this method returns the full list of image paths:

def load_images(self, input_data):
    images = []
    if os.path.isdir(input_data):
        for root, dirs, files in os.walk(input_data):
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    image_path = os.path.join(root, file)
                    images.append(image_path)
            images = sorted(images)
            break  # only the top-level directory is scanned
    return images
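For reference, a lazy variant of this helper could yield paths one at a time instead of materializing the whole list, so downstream code never sees all 120 paths at once. This is an illustrative sketch; the name iter_images is hypothetical and not part of PDF-Extract-Kit:

```python
import os

def iter_images(input_data):
    # Hypothetical lazy counterpart of BaseTask.load_images(): yields
    # image paths one at a time, mirroring the original's sorted,
    # top-level-only scan (os.walk + break).
    if os.path.isdir(input_data):
        for name in sorted(os.listdir(input_data)):
            if name.lower().endswith(('.png', '.jpg', '.jpeg')):
                yield os.path.join(input_data, name)
    elif os.path.isfile(input_data):
        yield input_data
```

By itself this does not fix the OOM (the backend still receives the full batch), but it pairs naturally with the chunked predict() proposed below.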

TableParsingStructEqTable.predict() opens all images as PIL Images and passes the entire list to the backend model

File: pdf_extract_kit/tasks/table_parsing/models/struct_eqtable.py:38-52

def predict(self, images, result_path, output_format=None, **kwargs):
    load_images = [Image.open(image_path) for image_path in images]
    results = self.model(load_images, output_format=output_format)
    return results

Issue:

  • All images in the directory (~120) are opened at once and passed to the backend model
  • The backend runs dynamic preprocessing on all of them in a single forward and finishes with a one-shot cat + to(device), producing a memory peak that triggers the OOM (see below)

InternVL.forward() runs dynamic preprocessing on every image, then does a single huge cat and moves it to the GPU

File: .venv/lib/python3.12/site-packages/struct_eqtable/internvl/internvl.py:58-98

def forward(self, images, output_format='latex', **kwargs):
    if not isinstance(images, list):
        images = [images]

    pixel_values_list = []
    for image in images:
        path_images = self.dynamic_preprocess(image, image_size=448, max_num=12)
        pixel_values = self.image_processor(
            path_images,
            return_tensors='pt'
        )['pixel_values'].to(torch.bfloat16)
        pixel_values_list.append(pixel_values)

    batch_size = len(pixel_values_list)  # actual list length (the configured batch_size is not used)

    pixel_values = torch.cat(pixel_values_list, axis=0).to(device)

Issue:

  • dynamic_preprocess() splits each image into up to 12 crops + 1 thumbnail = 13 tiles
  • A directory of ~120 images therefore produces ~120 × 13 = 1560 crop tensors
  • The final torch.cat(...).to(device) concatenates all crop tensors and moves them to the GPU at once; the resulting memory peak triggers the OOM
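A back-of-envelope check of the input batch alone (activation memory during generation is far larger, which is why the allocator ends up requesting tens of GiB), assuming InternVL's 448×448 tile size and bfloat16 pixel values as quoted above:

```python
# Rough size of the concatenated pixel_values tensor for a 120-image directory.
crops_per_image = 12 + 1           # up to 12 dynamic crops + 1 thumbnail
num_images = 120
c, h, w = 3, 448, 448              # tile shape fed to the image processor
bytes_per_elem = 2                 # bfloat16

total_bytes = num_images * crops_per_image * c * h * w * bytes_per_elem
print(f"{total_bytes / 2**30:.2f} GiB")  # ~1.75 GiB for the input batch alone
```

Even before the language model runs, a ~1.75 GiB batch of 1560 tiles is pushed to the GPU in one call; the per-tile vision features and decoder activations for a batch this size are what blow past the 48 GB card.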

Key Findings

batch_size is passed by the wrapper but dropped/ignored by the backend

In the PDF-Extract-Kit wrapper, batch_size is read from the config and passed to build_model(...):

File: pdf_extract_kit/tasks/table_parsing/models/struct_eqtable.py

self.batch_size = config.get('batch_size', 1)

self.model = build_model(
    model_ckpt=self.model_dir,
    max_new_tokens=self.max_new_tokens,
    max_time=self.max_time,
    lmdeploy=self.lmdeploy,
    flash_attn=self.flash_attn,
    batch_size=self.batch_size,
).cuda()

But in the backend InternVL:

File: .venv/lib/python3.12/site-packages/struct_eqtable/internvl/internvl.py:8-15

def __init__(self, model_path='...', max_new_tokens=1024, max_time=30, flash_attn=True, **kwargs):
    # ⚠️ batch_size is swallowed by **kwargs but never saved!
    self.model_path = model_path
    self.max_new_tokens = max_new_tokens
    self.max_generate_time = max_time
    self.flash_attn = flash_attn
    # ❌ missing: self.batch_size = kwargs.get('batch_size', 1)

  • InternVL.__init__() accepts extra parameters via **kwargs but never stores batch_size as self.batch_size (the parameter is silently swallowed)
  • InternVL.forward() has no logic to split the work by batch_size either; it unconditionally processes all images and ends with a single torch.cat(...).to(device)

Therefore, even with batch_size set in the config, a directory input still triggers a single full-batch forward, hence the OOM.
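The swallowed-parameter behavior is easy to demonstrate in isolation. This is a minimal illustration, not the library's actual code:

```python
# A keyword argument absorbed by **kwargs is silently discarded
# unless __init__ explicitly stores it on self.
class Backend:
    def __init__(self, max_new_tokens=1024, **kwargs):
        self.max_new_tokens = max_new_tokens
        # batch_size arrives in kwargs here and is never assigned

m = Backend(max_new_tokens=256, batch_size=4)
print(hasattr(m, 'batch_size'))  # False: the setting never reaches the model
```

No error is raised at construction time, which is why the misconfiguration only surfaces later as an OOM.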


Other task implementations process images one at a time

For example:

  • FormulaDetectionYOLO, LayoutDetectionLayoutlmv3, etc. use a per-image for idx, image in enumerate(images) loop, avoiding piling everything onto the GPU at once

Proposed Fix

Implement chunking in TableParsingStructEqTable.predict()

File: pdf_extract_kit/tasks/table_parsing/models/struct_eqtable.py

Approach:

  • Since the backend InternVL implements no batch_size chunking (the parameter is swallowed by **kwargs and never used), the chunking should be implemented in the PDF-Extract-Kit wrapper, in TableParsingStructEqTable.predict()
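A minimal sketch of what such a chunked predict() could look like, reusing the self.batch_size that the wrapper's __init__ already reads from the config (the try/finally close() is an added suggestion so PIL file handles are released between chunks):

```python
from PIL import Image

def predict(self, images, result_path, output_format=None, **kwargs):
    # Sketch of a chunked wrapper predict(): open and forward at most
    # batch_size images at a time, so the backend never sees the full
    # directory in a single call.
    batch_size = max(1, getattr(self, 'batch_size', 1))
    results = []
    for start in range(0, len(images), batch_size):
        chunk = [Image.open(p) for p in images[start:start + batch_size]]
        try:
            results.extend(self.model(chunk, output_format=output_format))
        finally:
            # Release file handles / pixel buffers before the next chunk.
            for im in chunk:
                im.close()
    return results
```

With batch_size: 1 in the config this degenerates to the per-image loop used by the other tasks; larger values trade memory headroom for throughput.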
