Skip to content

Commit 7c046ae

Browse files
authored
Merge pull request #12323 from reyoung/feature/polish_reshape_and_lod_tensor_blocking_queue
Feature/polish reshape and lod tensor blocking queue
2 parents 03d70c1 + fa9cbfb commit 7c046ae

File tree

2 files changed

+1
-18
lines changed

2 files changed

+1
-18
lines changed

paddle/fluid/operators/reader/lod_tensor_blocking_queue.h

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,10 @@ class LoDTensorBlockingQueue {

  public:
   bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
-    CheckDims(lod_tensor_vec);
     return queue_.Send(lod_tensor_vec);
   }

   bool Push(std::vector<framework::LoDTensor>&& lod_tensor_vec) {
-    CheckDims(lod_tensor_vec);
     return queue_.Send(std::move(lod_tensor_vec));
   }

@@ -65,21 +63,6 @@ class LoDTensorBlockingQueue {
   inline bool IsClosed() const { return queue_.IsClosed(); }

  private:
-  void CheckDims(
-      const std::vector<framework::LoDTensor>& lod_tensor_vec) const {
-    PADDLE_ENFORCE(dims_.size() == lod_tensor_vec.size(),
-                   "Expect input size is %d but found %s", dims_.size(),
-                   lod_tensor_vec.size());
-    for (size_t i = 0; i < dims_.size(); ++i) {
-      const auto& in_dims = framework::slice_ddim(
-          lod_tensor_vec[i].dims(), 1, lod_tensor_vec[i].dims().size());
-      const auto& expect_dims =
-          framework::slice_ddim(dims_[i], 1, dims_[i].size());
-      PADDLE_ENFORCE(in_dims == expect_dims,
-                     "Dims of the %d-th input tensor do not match", i);
-    }
-  }
-
   BlockingQueue<std::vector<framework::LoDTensor>> queue_;
   std::vector<framework::DDim> dims_;
 };

paddle/fluid/operators/reshape_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ class ReshapeKernel {
     if (shape_tensor) {
       auto *shape_data = shape_tensor->data<int>();
       framework::Tensor cpu_shape_tensor;
-      if (platform::is_gpu_place(ctx.GetPlace())) {
+      if (platform::is_gpu_place(shape_tensor->place())) {
         TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
         shape_data = cpu_shape_tensor.data<int>();
       }

Comments (0)