Skip to content

Commit e228e70

Browse files
authored
fix ZeroCopyTensor::mutable_data(), test=release/1.6 (#21581)
1 parent 0a4002f commit e228e70

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

paddle/fluid/inference/api/details/zero_copy_tensor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
5252
return tensor->mutable_data<T>(platform::CPUPlace());
5353
}
5454
case static_cast<int>(PaddlePlace::kGPU): {
55-
return tensor->mutable_data<T>(platform::CUDAPlace());
55+
return tensor->mutable_data<T>(platform::CUDAPlace(device_));
5656
}
5757
default:
5858
PADDLE_THROW("Unsupported place: %d", static_cast<int>(place));

paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ TEST(ZeroCopyTensor, uint8) {
5151
input_t->Reshape({batch_size, length});
5252
input_t->copy_from_cpu(input);
5353
input_t->type();
54+
input_t->mutable_data<uint8_t>(PaddlePlace::kGPU);
5455

5556
ASSERT_TRUE(predictor->ZeroCopyRun());
5657
}

0 commit comments

Comments
 (0)