Skip to content

Commit 984aee4

Browse files
authored
fix errors caused by gpu in conditions (PaddlePaddle#75551)
1 parent 37f7dbe commit 984aee4

File tree

7 files changed

+16
-7
lines changed

7 files changed

+16
-7
lines changed

python/paddle/distributed/fleet/recompute/recompute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def _recompute_without_reentrant(
411411

412412
if preserve_rng_state:
413413
cur_device = paddle.get_device()
414-
if 'gpu:' in cur_device:
414+
if cur_device.startswith('gpu:'):
415415
fw_cuda_rng_state = paddle.get_cuda_rng_state()
416416
elif 'cpu' in cur_device:
417417
fw_cuda_rng_state = paddle.get_rng_state()

python/paddle/incubate/jit/inference_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def create_predictor(self, input_tensor_lists):
393393
config.enable_new_ir(self.enable_new_ir)
394394

395395
device_num = paddle.device.get_device()
396-
if 'gpu' in device_num:
396+
if device_num.startswith('gpu'):
397397
gpu_id = int(device_num.split(':')[1])
398398
config.enable_use_gpu(
399399
self.memory_pool_init_size_mb,

python/paddle/tensor/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4558,7 +4558,7 @@ def lstsq(
45584558
f"Only support valid driver is 'gels', 'gelss', 'gelsd', 'gelsy' or None for CPU inputs. But got {driver}"
45594559
)
45604560
driver = "gelsy" if driver is None else driver
4561-
elif "gpu" in device:
4561+
elif device.startswith('gpu'):
45624562
if driver not in (None, "gels"):
45634563
raise ValueError(
45644564
f"Only support valid driver is 'gels' or None for CUDA inputs. But got {driver}"

test/legacy_test/test_compat_slogdet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ def slogdet_backward(self, x, _, grad_logabsdet):
5656

5757
def test_compat_slogdet(self):
5858
devices = [paddle.device.get_device()]
59-
if "gpu:" in devices and not paddle.device.is_compiled_with_rocm():
59+
if (
60+
any(device.startswith("gpu:") for device in devices)
61+
and not paddle.device.is_compiled_with_rocm()
62+
):
6063
devices.append("cpu")
6164
for device in devices:
6265
with paddle.device.device_guard(device), dygraph_guard():

test/legacy_test/test_div_op.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,10 @@ def test_gpu(self):
737737

738738
def test_infer_symbolic_shape(self):
739739
devices = [paddle.device.get_device()]
740-
if "gpu:" in devices and not paddle.device.is_compiled_with_rocm():
740+
if (
741+
any(device.startswith("gpu:") for device in devices)
742+
and not paddle.device.is_compiled_with_rocm()
743+
):
741744
devices.append("cpu")
742745

743746
for device in devices:

test/legacy_test/test_random_op.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,10 @@ def test_random_update_to(self):
116116

117117
def test_pir_random_(self):
118118
devices = [paddle.device.get_device()]
119-
if "gpu:" in devices and not paddle.device.is_compiled_with_rocm():
119+
if (
120+
any(device.startswith("gpu:") for device in devices)
121+
and not paddle.device.is_compiled_with_rocm()
122+
):
120123
devices.append("cpu")
121124
for device in devices:
122125
with paddle.device.device_guard(device), dygraph_guard():

test/sot/test_sot_place.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def run_diff_logic_by_check_expected_place(x: paddle.Tensor):
4141
expected_place_str = paddle.get_device()
4242
if "cpu" in expected_place_str:
4343
return x + 1
44-
elif "gpu" in expected_place_str:
44+
elif expected_place_str.startswith("gpu"):
4545
return x + 2
4646
elif "xpu" in expected_place_str:
4747
return x + 3

0 commit comments

Comments
 (0)