Commit 3e3b5f4

Merge pull request #12675 from Sand3r-/fix-conv-mkldnn-0.15
Update MKLDNN to 0.15, fix convolution integration
2 parents 557be6f + 4a7f069 commit 3e3b5f4

5 files changed: +39 −23 lines

cmake/external/mkldnn.cmake

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ ExternalProject_Add(
     ${EXTERNAL_PROJECT_LOG_ARGS}
     DEPENDS ${MKLDNN_DEPENDS}
     GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
-    GIT_TAG "a29d8487a63afca3d5b8c5bbdbb473cf8ccc6e51"
+    GIT_TAG "64e03a1939e0d526aa8e9f2e3f7dc0ad8d372944"
     PREFIX ${MKLDNN_SOURCES_DIR}
     UPDATE_COMMAND ""
     CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}

paddle/fluid/framework/tensor.cc

Lines changed: 5 additions & 4 deletions
@@ -31,15 +31,16 @@ size_t Tensor::memory_size() const {
   return holder_ == nullptr ? 0UL : holder_->size() - offset_;
 }
 
-void* Tensor::mutable_data(platform::Place place, std::type_index type) {
+void* Tensor::mutable_data(platform::Place place, std::type_index type,
+                           size_t requested_size) {
   if (holder_ != nullptr) {
     holder_->set_type(type);
   }
   PADDLE_ENFORCE_GE(numel(), 0,
                     "When calling this method, the Tensor's numel must be "
                     "equal or larger than zero. "
                     "Please check Tensor::Resize has been called first.");
-  int64_t size = numel() * SizeOfType(type);
+  size_t size = requested_size ? requested_size : numel() * SizeOfType(type);
   /* some versions of boost::variant don't have operator!= */
   if (holder_ == nullptr || !(holder_->place() == place) ||
       holder_->size() < size + offset_) {
@@ -68,10 +69,10 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type) {
                                  offset_);
 }
 
-void* Tensor::mutable_data(platform::Place place) {
+void* Tensor::mutable_data(platform::Place place, size_t requested_size) {
   PADDLE_ENFORCE(this->holder_ != nullptr,
                  "Cannot invoke mutable data if current hold nothing.");
-  return mutable_data(place, holder_->type());
+  return mutable_data(place, holder_->type(), requested_size);
 }
 
 Tensor& Tensor::ShareDataWith(const Tensor& src) {
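
The only behavioural change in tensor.cc is the size rule: a non-zero requested_size (given in bytes) overrides numel() * SizeOfType(type), passing 0 keeps the old allocation size, and reallocation still happens only when the existing holder is too small. A minimal standalone sketch of that rule follows; PickAllocationSize is a hypothetical helper written for illustration, not Paddle code.

#include <cstddef>
#include <cstdint>
#include <iostream>

// Hypothetical helper mirroring the size rule added above: a non-zero
// requested_size (in bytes) wins over the element-count-based size, and
// zero falls back to the old numel * sizeof(element) behaviour.
static size_t PickAllocationSize(int64_t numel, size_t element_size,
                                 size_t requested_size) {
  return requested_size ? requested_size
                        : static_cast<size_t>(numel) * element_size;
}

int main() {
  std::cout << PickAllocationSize(8, sizeof(float), 0) << "\n";   // 32 bytes
  std::cout << PickAllocationSize(8, sizeof(float), 64) << "\n";  // 64 bytes
  return 0;
}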

paddle/fluid/framework/tensor.h

Lines changed: 8 additions & 6 deletions
@@ -89,22 +89,24 @@ class Tensor {
    * @note If not exist, then allocation.
    */
   template <typename T>
-  T* mutable_data(platform::Place place);
+  T* mutable_data(platform::Place place, size_t requested_size = 0);
 
-  void* mutable_data(platform::Place place, std::type_index type);
+  void* mutable_data(platform::Place place, std::type_index type,
+                     size_t requested_size = 0);
 
-  void* mutable_data(platform::Place place);
+  void* mutable_data(platform::Place place, size_t requested_size = 0);
 
   /**
    * @brief Return a pointer to mutable memory block.
    *
-   * @param[in] dims  The dimensions of the memory block.
-   * @param[in] place The place of the memory block.
+   * @param[in] dims           The dimensions of the memory block.
+   * @param[in] place          The place of the memory block.
+   * @param[in] requested_size The size of the block in bytes.
   *
    * @note If not exist, then allocation.
    */
   template <typename T>
-  T* mutable_data(DDim dims, platform::Place place);
+  T* mutable_data(DDim dims, platform::Place place, size_t requested_size = 0);
 
   /*! Return the dimensions of the memory block. */
   const DDim& dims() const;

paddle/fluid/framework/tensor_impl.h

Lines changed: 5 additions & 4 deletions
@@ -46,16 +46,17 @@ inline T* Tensor::data() {
 }
 
 template <typename T>
-inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
+inline T* Tensor::mutable_data(DDim dims, platform::Place place,
+                               size_t requested_size) {
   static_assert(std::is_pod<T>::value, "T must be POD");
   Resize(dims);
-  return mutable_data<T>(place);
+  return mutable_data<T>(place, requested_size);
 }
 
 template <typename T>
-inline T* Tensor::mutable_data(platform::Place place) {
+inline T* Tensor::mutable_data(platform::Place place, size_t requested_size) {
   static_assert(std::is_pod<T>::value, "T must be POD");
-  return reinterpret_cast<T*>(mutable_data(place, typeid(T)));
+  return reinterpret_cast<T*>(mutable_data(place, typeid(T), requested_size));
 }
 
 inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
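
Because every overload in tensor.h defaults requested_size to 0 and these templated implementations simply forward it, existing call sites compile and behave exactly as before, while kernels that know the backend layout can ask for a larger buffer. A toy sketch of that compatibility pattern follows; MiniTensor is illustrative only and is not Paddle's Tensor.

#include <cstddef>
#include <iostream>
#include <vector>

// Toy stand-in showing why the defaulted requested_size keeps old call
// sites valid while allowing callers to over-allocate in bytes.
class MiniTensor {
 public:
  float* mutable_data(size_t numel, size_t requested_size = 0) {
    buffer_.resize(requested_size ? requested_size : numel * sizeof(float));
    return reinterpret_cast<float*>(buffer_.data());
  }
  size_t capacity_bytes() const { return buffer_.size(); }

 private:
  std::vector<unsigned char> buffer_;
};

int main() {
  MiniTensor t;
  t.mutable_data(8);                        // old-style call: 8 * 4 bytes
  std::cout << t.capacity_bytes() << "\n";  // 32
  t.mutable_data(8, 128);                   // caller requests a larger block
  std::cout << t.capacity_bytes() << "\n";  // 128
  return 0;
}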

paddle/fluid/operators/conv_mkldnn_op.cc

Lines changed: 20 additions & 8 deletions
@@ -53,6 +53,18 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
       key_ += "-BWD";
   }
 
+  size_t GetDstMemorySize() const {
+    return conv_pd_->dst_primitive_desc().get_size();
+  }
+
+  size_t GetDiffWeightsMemorySize() const {
+    return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size();
+  }
+
+  size_t GetDiffSourceMemorySize() const {
+    return conv_bwd_data_pd_->diff_src_primitive_desc().get_size();
+  }
+
   std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromWeightsPrimitive(
       const std::shared_ptr<mkldnn::memory> user_memory_p,
       std::vector<mkldnn::primitive>& pipeline) {  // NOLINT
@@ -294,7 +306,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     const T* input_data = input->data<T>();
     const T* filter_data = filter->data<T>();
-    T* output_data = output->mutable_data<T>(ctx.GetPlace());
 
     std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
     std::vector<int> weights_tz =
@@ -354,6 +365,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto user_weights_memory_p = handler.AcquireWeightsMemory(
         user_weights_md, to_void_cast<T>(filter_data));
 
+    T* output_data =
+        output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
     // create reorder primitive if the input format is not the preferred one
     auto src_memory_p =
         handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
@@ -476,13 +489,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     T* input_grad_data = nullptr;
     T* filter_grad_data = nullptr;
 
-    if (input_grad) {
-      input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
-    }
-    if (filter_grad) {
-      filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
-    }
-
     std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
     std::vector<int> weights_tz =
         paddle::framework::vectorize2int(filter->dims());
@@ -568,6 +574,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
           handler.AcquireDiffDstMemoryFromWeightsPrimitive(
               user_diff_dst_memory_p, pipeline);
 
+      const size_t size = handler.GetDiffWeightsMemorySize();
+      filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace(), size);
+
       auto diff_weights_memory_p =
           handler.AcquireDiffWeightsMemoryFromWeightsPrimitive(
               reinterpret_cast<void*>(filter_grad_data));
@@ -590,6 +599,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
           handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p,
                                                         pipeline);
 
+      const size_t size = handler.GetDiffSourceMemorySize();
+      input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace(), size);
+
       auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
           reinterpret_cast<void*>(input_grad_data));