Skip to content

Commit 673dc26

Browse files
authored
Merge pull request #7164 from tensor-tang/context
Add MKLDNNDeviceContext
2 parents 894236a + 6177cb5 commit 673dc26

File tree

7 files changed

+171
-6
lines changed

7 files changed

+171
-6
lines changed

cmake/external/mkldnn.cmake

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,17 @@ ExternalProject_Add(
6363
-DMKLROOT:PATH=${MKLML_ROOT}
6464
)
6565

66-
ADD_LIBRARY(mkldnn SHARED IMPORTED GLOBAL)
67-
SET_PROPERTY(TARGET mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIB})
68-
ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT})
66+
ADD_LIBRARY(shared_mkldnn SHARED IMPORTED GLOBAL)
67+
SET_PROPERTY(TARGET shared_mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIB})
68+
ADD_DEPENDENCIES(shared_mkldnn ${MKLDNN_PROJECT})
6969
MESSAGE(STATUS "MKLDNN library: ${MKLDNN_LIB}")
7070
add_definitions(-DPADDLE_WITH_MKLDNN)
71-
LIST(APPEND external_project_dependencies mkldnn)
71+
LIST(APPEND external_project_dependencies shared_mkldnn)
72+
73+
# generate a static dummy target to track mkldnn dependencies
74+
# for cc_library(xxx SRCS xxx.c DEPS mkldnn)
75+
SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/mkldnn_dummy.c)
76+
FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
77+
ADD_LIBRARY(mkldnn STATIC ${dummyfile})
78+
TARGET_LINK_LIBRARIES(mkldnn ${MKLDNN_LIB} ${MKLML_LIB} ${MKLML_IOMP_LIB})
79+
ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT})

paddle/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ device_context)
4141
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
4242
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
4343
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
44-
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute)
44+
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
4545
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
4646
shape_inference data_transform)
4747
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)

paddle/operators/tensor.save

-462 Bytes
Binary file not shown.

paddle/platform/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,16 @@ ELSE()
2121
set(GPU_CTX_DEPS)
2222
ENDIF()
2323

24+
IF(WITH_MKLDNN)
25+
set(MKLDNN_CTX_DEPS mkldnn)
26+
ELSE()
27+
set(MKLDNN_CTX_DEPS)
28+
ENDIF()
29+
2430
# memcpy depends on device_context, here add deps individually for
2531
# avoiding cyclic dependencies
2632
cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator
27-
system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS})
33+
system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
2834
nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
2935

3036
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)

paddle/platform/device_context.cc

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,5 +168,69 @@ cudaStream_t CUDADeviceContext::stream() const { return stream_; }
168168

169169
#endif
170170

171+
#ifdef PADDLE_WITH_MKLDNN
172+
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
173+
: CPUDeviceContext(place), ready_(false) {
174+
stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
175+
engine_.reset(new mkldnn::engine(mkldnn::engine::cpu, 0));
176+
}
177+
178+
template <typename T>
179+
void MKLDNNDeviceContext::AddElement(const std::string& op_key,
180+
const T& value) {
181+
if (GetElement<T>(op_key)) {
182+
return;
183+
}
184+
GetElementPool<T>().emplace(op_key, std::move(value));
185+
}
186+
187+
template <typename T>
188+
const T& MKLDNNDeviceContext::GetElement(const std::string& op_key) const {
189+
auto it = GetElementPool<T>().find(op_key);
190+
return it == GetElementPool<T>().end() ? nullptr : it->second;
191+
}
192+
193+
template <>
194+
const std::unordered_map<const std::string, const MKLDNNMemoryPtr,
195+
std::hash<std::string>>&
196+
MKLDNNDeviceContext::GetElementPool<MKLDNNMemoryPtr>() const {
197+
return memory_pool_;
198+
}
199+
200+
template <>
201+
const std::unordered_map<const std::string, const MKLDNNPrimitivePtr,
202+
std::hash<std::string>>&
203+
MKLDNNDeviceContext::GetElementPool<MKLDNNPrimitivePtr>() const {
204+
return primitive_pool_;
205+
}
206+
207+
template <>
208+
const std::unordered_map<const std::string, const MKLDNNPrimitiveDescPtr,
209+
std::hash<std::string>>&
210+
MKLDNNDeviceContext::GetElementPool<MKLDNNPrimitiveDescPtr>() const {
211+
return primitive_desc_pool_;
212+
}
213+
214+
void MKLDNNDeviceContext::Execute(bool block) {
215+
if (pipeline_.empty()) {
216+
return;
217+
}
218+
ResetStream();
219+
stream_->submit(pipeline_).wait(block);
220+
ready_ = false;
221+
pipeline_.clear();
222+
}
223+
224+
void MKLDNNDeviceContext::ResetStream() {
225+
if (ready_) {
226+
return;
227+
}
228+
// TODO(TJ): change me when mkldnn has a specific method to reset this state
229+
stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
230+
ready_ = true;
231+
}
232+
233+
#endif
234+
171235
} // namespace platform
172236
} // namespace paddle

paddle/platform/device_context.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ limitations under the License. */
2121
#define EIGEN_USE_GPU
2222
#endif
2323

24+
#ifdef PADDLE_WITH_MKLDNN
25+
#include "paddle/platform/mkldnn_helper.h"
26+
#endif
27+
2428
#include "paddle/platform/enforce.h"
2529
#include "paddle/platform/place.h"
2630
#include "unsupported/Eigen/CXX11/Tensor"
@@ -105,6 +109,54 @@ struct DefaultDeviceContextType<platform::CUDAPlace> {
105109

106110
#endif
107111

112+
#ifdef PADDLE_WITH_MKLDNN
113+
class MKLDNNDeviceContext : public CPUDeviceContext {
114+
public:
115+
explicit MKLDNNDeviceContext(CPUPlace place);
116+
117+
/*! \brief Add new element: memory, primitive or primitive desc */
118+
template <typename T>
119+
void AddElement(const std::string& op_key, const T& value);
120+
121+
/*! \brief Get an existing element: memory, primitive or primitive desc */
122+
template <typename T>
123+
const T& GetElement(const std::string& op_key) const;
124+
125+
/*! \brief Get element pool: memory, primitive or primitive desc pool */
126+
template <typename T>
127+
const std::unordered_map<const std::string, const T, std::hash<std::string>>&
128+
GetElementPool() const;
129+
130+
/*! \brief Get the active engine */
131+
const MKLDNNEngine& engine() const { return *engine_; }
132+
133+
/*! \brief Submit primitive to pipeline */
134+
void Submit(const MKLDNNPrimitivePtr& p) { pipeline_.push_back(*p); }
135+
136+
/*! \brief Execute all submitted primitives in pipeline */
137+
void Execute(bool block = true);
138+
139+
protected:
140+
/*! \brief Reset the stream to prepare for the next execute */
141+
void ResetStream();
142+
143+
private:
144+
std::unordered_map<const std::string, const MKLDNNMemoryPtr,
145+
std::hash<std::string>>
146+
memory_pool_;
147+
std::unordered_map<const std::string, const MKLDNNPrimitivePtr,
148+
std::hash<std::string>>
149+
primitive_pool_;
150+
std::unordered_map<const std::string, const MKLDNNPrimitiveDescPtr,
151+
std::hash<std::string>>
152+
primitive_desc_pool_;
153+
std::vector<MKLDNNPrimitive> pipeline_;
154+
MKLDNNStreamPtr stream_;
155+
MKLDNNEnginePtr engine_;
156+
bool ready_;
157+
};
158+
#endif
159+
108160
/*! \brief device context pool singleton */
109161
class DeviceContextPool {
110162
public:

paddle/platform/mkldnn_helper.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <mkldnn.hpp>
18+
19+
namespace paddle {
20+
namespace platform {
21+
22+
using MKLDNNStream = mkldnn::stream;
23+
using MKLDNNEngine = mkldnn::engine;
24+
using MKLDNNMemory = mkldnn::memory;
25+
using MKLDNNPrimitive = mkldnn::primitive;
26+
using MKLDNNPrimitiveDesc = mkldnn::handle<mkldnn_primitive_desc_t>;
27+
28+
// Uniquely-owning pointer to an MKLDNNStream.
using MKLDNNStreamPtr = std::unique_ptr<MKLDNNStream>;
29+
// Uniquely-owning pointer to an MKLDNNEngine.
using MKLDNNEnginePtr = std::unique_ptr<MKLDNNEngine>;
30+
// Uniquely-owning pointer to an MKLDNNMemory.
using MKLDNNMemoryPtr = std::unique_ptr<MKLDNNMemory>;
31+
// Uniquely-owning pointer to an MKLDNNPrimitive.
using MKLDNNPrimitivePtr = std::unique_ptr<MKLDNNPrimitive>;
32+
// Uniquely-owning pointer to an MKLDNNPrimitiveDesc.
using MKLDNNPrimitiveDescPtr = std::unique_ptr<MKLDNNPrimitiveDesc>;
33+
34+
} // namespace platform
35+
} // namespace paddle

0 commit comments

Comments
 (0)