Skip to content

Commit 13816dd

Browse files
jczajaluotao1
authored andcommitted
[MKL-DNN] Fix to crash of Transformer when mkldnn is to be used (#16233)
* - Fix to crash of Transformer when mkldnn is to be used Desc: TensorCopy was not setting MKLDNN primitive descriptor when layout was to be kMKLDNN test=develop * - Enable transformer for mkl-dnn test=develo * - Compilation fix test=develop * - Removed manual selection of MKL-DNN ops to be used in Transformer test test=develop
1 parent 7e20e76 commit 13816dd

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

paddle/fluid/framework/tensor_util.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
4444
<< dst_place;
4545
return;
4646
}
47+
#ifdef PADDLE_WITH_MKLDNN
48+
if (src.layout() == DataLayout::kMKLDNN) {
49+
dst->set_mkldnn_prim_desc(src.get_mkldnn_prim_desc());
50+
}
51+
#endif
4752
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
4853
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
4954
}

paddle/fluid/inference/tests/api/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ set(TRANSFORMER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/transformer")
110110
download_model_and_data(${TRANSFORMER_INSTALL_DIR} "temp%2Ftransformer_model.tar.gz" "temp%2Ftransformer_data.txt.tar.gz")
111111
inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_tester.cc
112112
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
113-
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8)
113+
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8 SERIAL)
114114

115115
# ocr
116116
set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")

paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,17 +183,25 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
183183
}
184184

185185
// Easy for profiling independently.
186-
TEST(Analyzer_Transformer, profile) {
186+
void profile(bool use_mkldnn = false) {
187187
AnalysisConfig cfg;
188188
SetConfig(&cfg);
189189
std::vector<PaddleTensor> outputs;
190+
if (use_mkldnn) {
191+
cfg.EnableMKLDNN();
192+
}
190193

191194
std::vector<std::vector<PaddleTensor>> input_slots_all;
192195
SetInput(&input_slots_all);
193196
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
194197
input_slots_all, &outputs, FLAGS_num_threads);
195198
}
196199

200+
TEST(Analyzer_Transformer, profile) { profile(); }
201+
#ifdef PADDLE_WITH_MKLDNN
202+
TEST(Analyzer_Transformer, profile_mkldnn) { profile(true); }
203+
#endif
204+
197205
// Check the fuse status
198206
TEST(Analyzer_Transformer, fuse_statis) {
199207
AnalysisConfig cfg;
@@ -206,15 +214,23 @@ TEST(Analyzer_Transformer, fuse_statis) {
206214
}
207215

208216
// Compare result of NativeConfig and AnalysisConfig
209-
TEST(Analyzer_Transformer, compare) {
217+
void compare(bool use_mkldnn = false) {
210218
AnalysisConfig cfg;
211219
SetConfig(&cfg);
220+
if (use_mkldnn) {
221+
cfg.EnableMKLDNN();
222+
}
212223

213224
std::vector<std::vector<PaddleTensor>> input_slots_all;
214225
SetInput(&input_slots_all);
215226
CompareNativeAndAnalysis(
216227
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
217228
}
218229

230+
TEST(Analyzer_Transformer, compare) { compare(); }
231+
#ifdef PADDLE_WITH_MKLDNN
232+
TEST(Analyzer_Transformer, compare_mkldnn) { compare(true /* use_mkldnn */); }
233+
#endif
234+
219235
} // namespace inference
220236
} // namespace paddle

0 commit comments

Comments
 (0)