Skip to content

Commit b5996fa

Browse files
committed
Fix unstable selected_rows_functor_test.cu
1 parent d402234 commit b5996fa

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

paddle/fluid/operators/math/selected_rows_functor.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
107107
PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);
108108

109109
auto& in1_value = input1.value();
110-
framework::Vector<int64_t> in1_rows(input1.rows());
110+
auto& in1_rows = input1.rows();
111111

112112
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
113113
PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
@@ -206,7 +206,7 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
206206
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
207207

208208
auto& in1_value = input1.value();
209-
framework::Vector<int64_t> in1_rows(input1.rows());
209+
auto& in1_rows = input1.rows();
210210

211211
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
212212
PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);

paddle/fluid/operators/math/selected_rows_functor_test.cu

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ limitations under the License. */
2020
TEST(selected_rows_functor, gpu_add) {
2121
paddle::platform::CUDAPlace gpu_place(0);
2222
paddle::platform::CPUPlace cpu_place;
23-
paddle::platform::CUDADeviceContext ctx(gpu_place);
23+
paddle::platform::CUDADeviceContext& ctx =
24+
*reinterpret_cast<paddle::platform::CUDADeviceContext*>(
25+
paddle::platform::DeviceContextPool::Instance().Get(gpu_place));
2426
paddle::operators::math::SetConstant<paddle::platform::CUDADeviceContext,
2527
float>
2628
functor;
@@ -132,7 +134,9 @@ TEST(selected_rows_functor, gpu_add) {
132134
TEST(selected_rows_functor, gpu_add_to) {
133135
paddle::platform::CUDAPlace gpu_place(0);
134136
paddle::platform::CPUPlace cpu_place;
135-
paddle::platform::CUDADeviceContext ctx(gpu_place);
137+
paddle::platform::CUDADeviceContext& ctx =
138+
*reinterpret_cast<paddle::platform::CUDADeviceContext*>(
139+
paddle::platform::DeviceContextPool::Instance().Get(gpu_place));
136140
paddle::operators::math::SetConstant<paddle::platform::CUDADeviceContext,
137141
float>
138142
functor;

0 commit comments

Comments
 (0)