@@ -1801,15 +1801,15 @@ def test_value_input_is_scalar(self):
18011801)
18021802class TestSetValueWithStrideError (unittest .TestCase ):
18031803 def test_same_place (self ):
1804- x = paddle .rand ([5 , 10 ], device = paddle . CUDAPlace ( 0 ))
1805- y = paddle .rand ([10 , 5 ], device = paddle . CUDAPlace ( 0 ))
1804+ x = paddle .rand ([5 , 10 ], device = get_device_place ( ))
1805+ y = paddle .rand ([10 , 5 ], device = get_device_place ( ))
18061806 y .transpose_ ([1 , 0 ])
18071807 x .set_value (y )
18081808 assert x .is_contiguous ()
18091809
18101810 def test_different_place1 (self ):
18111811 # src place != dst place && src is not contiguous
1812- x = paddle .rand ([5 , 10 ], device = paddle . CUDAPlace ( 0 ))
1812+ x = paddle .rand ([5 , 10 ], device = get_device_place ( ))
18131813 y = paddle .rand ([10 , 5 ], device = paddle .CPUPlace ())
18141814 y .transpose_ ([1 , 0 ])
18151815 x .set_value (y )
@@ -1818,7 +1818,7 @@ def test_different_place1(self):
18181818 def test_different_place2 (self ):
18191819 # src place != dst place && dst is not contiguous
18201820 with self .assertRaises (SystemError ):
1821- x = paddle .ones ([5 , 4 ], device = paddle . CUDAPlace ( 0 ))
1821+ x = paddle .ones ([5 , 4 ], device = get_device_place ( ))
18221822 x .transpose_ ([1 , 0 ])
18231823 y = paddle .rand ([4 , 2 ], device = paddle .CPUPlace ())
18241824 assert not x .is_contiguous ()
0 commit comments