Skip to content

Commit 803bdd1

Browse files
committed
PaddleOCR-VL supports FP32 (#4658)
1 parent f6bb816 commit 803bdd1

File tree

3 files changed

+23
-12
lines changed

paddlex/inference/models/common/vlm/transformers/model_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def _load_state_dict_into_model(
275275
model_to_load, state_dict, start_prefix, convert_from_hf
276276
):
277277
# torch will cast dtype in load_state_dict, but paddle strictly check dtype
278-
_convert_state_dict_dtype_and_shape(state_dict, model_to_load)
278+
_convert_state_dict_dtype_and_shape(state_dict, model_to_load, convert_from_hf)
279279

280280
error_msgs = []
281281

@@ -305,12 +305,16 @@ def _load_state_dict_into_model(
305305
return error_msgs
306306

307307

308-
def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
308+
def _convert_state_dict_dtype_and_shape(state_dict, model_to_load, convert_from_hf):
309309
# convert the dtype of state dict
310310
def is_0d_or_1d(tensor):
311311
return len(tensor.shape) == 0 or list(tensor.shape) == [1]
312312

313-
for key, value in model_to_load.state_dict().items():
313+
if convert_from_hf:
314+
model_state_dict = model_to_load.get_hf_state_dict()
315+
else:
316+
model_state_dict = model_to_load.state_dict()
317+
for key, value in model_state_dict.items():
314318
if key in list(state_dict.keys()):
315319
if isinstance(state_dict[key], np.ndarray):
316320
raise ValueError(

paddlex/inference/models/doc_vlm/predictor.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from ....utils import logging
2929
from ....utils.deps import require_genai_client_plugin
3030
from ....utils.device import TemporaryDeviceChanger
31-
from ....utils.env import get_device_type
3231
from ...common.batch_sampler import DocVLMBatchSampler
32+
from ...utils.misc import is_bfloat16_available
3333
from ..base import BasePredictor
3434
from .result import DocVLMResult
3535

@@ -53,15 +53,8 @@ def __init__(self, *args, **kwargs):
5353
super().__init__(*args, **kwargs)
5454

5555
if self._use_local_model:
56-
import paddle
57-
5856
self.device = kwargs.get("device", None)
59-
self.dtype = (
60-
"bfloat16"
61-
if ("npu" in get_device_type() or paddle.amp.is_bfloat16_supported())
62-
and (self.device is None or "cpu" not in self.device)
63-
else "float32"
64-
)
57+
self.dtype = "bfloat16" if is_bfloat16_available(self.device) else "float32"
6558

6659
self.infer, self.processor = self._build(**kwargs)
6760

paddlex/inference/utils/misc.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from ...utils.device import get_default_device, parse_device
16+
from ...utils.env import get_device_type
17+
1518

1619
def is_mkldnn_available():
1720
# XXX: Not sure if this is the best way to check if MKL-DNN is available
1821
from paddle.inference import Config
1922

2023
return hasattr(Config, "set_mkldnn_cache_capacity")
24+
25+
26+
def is_bfloat16_available(device):
27+
import paddle.amp
28+
29+
if device is None:
30+
device = get_default_device()
31+
device_type, _ = parse_device(device)
32+
return (
33+
"npu" in get_device_type() or paddle.amp.is_bfloat16_supported()
34+
) and device_type in ("gpu", "npu", "xpu", "mlu", "dcu")

0 commit comments

Comments (0)