Skip to content

Commit 20adadc

Browse files
authored
add npu device support for uie (#4401)
* add npu support for uie * update get_device/get_env_device
1 parent be7d123 commit 20adadc

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

paddlenlp/taskflow/task.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ def _prepare_static_mode(self):
171171
if paddle.get_device() == "cpu":
172172
self._config.disable_gpu()
173173
self._config.enable_mkldnn()
174+
elif paddle.get_device().split(":", 1)[0] == "npu":
175+
self._config.disable_gpu()
176+
self._config.enable_npu(self.kwargs["device_id"])
174177
else:
175178
self._config.enable_use_gpu(100, self.kwargs["device_id"])
176179
# TODO(linjieccc): enable after fixed

paddlenlp/utils/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def get_env_device():
122122
"""
123123
if paddle.is_compiled_with_cuda():
124124
return "gpu"
125-
elif paddle.is_compiled_with_npu():
125+
elif "npu" in paddle.device.get_all_custom_device_type():
126126
return "npu"
127127
elif paddle.is_compiled_with_rocm():
128128
return "rocm"

0 commit comments

Comments
 (0)