Skip to content

Commit 9ee698e

Browse files
authored
enhance/ditu rnn with fc fuse (#12831)
* make fc fuse work with ditu rnn * add ditu rnn data download to CMAKE
1 parent 78415f3 commit 9ee698e

File tree

10 files changed

+423
-26
lines changed

10 files changed

+423
-26
lines changed

paddle/fluid/framework/ir/graph_helper.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
104104
for (auto &adj_n : var->inputs) {
105105
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
106106
adj_list[n].insert(adj_n);
107-
VLOG(3) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
107+
VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
108108
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
109109
<< " via " << var->Name() << reinterpret_cast<void *>(var);
110110
}

paddle/fluid/inference/analysis/CMakeLists.txt

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function (inference_analysis_test TARGET)
2222
if(WITH_TESTING)
2323
set(options "")
2424
set(oneValueArgs "")
25-
set(multiValueArgs SRCS)
25+
set(multiValueArgs SRCS EXTRA_DEPS)
2626
cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
2727

2828
set(mem_opt "")
@@ -31,22 +31,43 @@ function (inference_analysis_test TARGET)
3131
endif()
3232
cc_test(${TARGET}
3333
SRCS "${analysis_test_SRCS}"
34-
DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass
34+
DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass ${analysis_test_EXTRA_DEPS}
3535
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt})
3636
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
3737
endif(WITH_TESTING)
3838
endfunction(inference_analysis_test)
3939

40-
cc_test(test_analyzer SRCS analyzer_tester.cc DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
41-
# ir
42-
fc_fuse_pass
43-
graph_viz_pass
44-
infer_clean_graph_pass
45-
graph_pattern_detecter
46-
pass
47-
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
48-
#set_tests_properties(test_analyzer PROPERTIES DEPENDS test_word2vec)
49-
#inference_api_test(test_analyzer SRC analyzer_tester.cc ARGS test_word2vec)
40+
set(DITU_RNN_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fmodel.tar.gz")
41+
set(DITU_RNN_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fdata.txt.tar.gz")
42+
set(DITU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/ditu_rnn" CACHE PATH "Ditu RNN model and data root." FORCE)
43+
set(DITU_RNN_MODEL ${DITU_INSTALL_DIR}/model)
44+
set(DITU_RNN_DATA ${DITU_INSTALL_DIR}/data.txt)
45+
46+
function (inference_download_and_uncompress target url gz_filename)
47+
message(STATUS "Download inference test stuff ${gz_filename} from ${url}")
48+
execute_process(COMMAND bash -c "mkdir -p ${DITU_INSTALL_DIR}")
49+
execute_process(COMMAND bash -c "cd ${DITU_INSTALL_DIR} && wget -q ${url}")
50+
execute_process(COMMAND bash -c "cd ${DITU_INSTALL_DIR} && tar xzf ${gz_filename}")
51+
message(STATUS "finish downloading ${gz_filename}")
52+
endfunction(inference_download_and_uncompress)
53+
54+
if (NOT EXISTS ${DITU_INSTALL_DIR})
55+
inference_download_and_uncompress(ditu_rnn_model ${DITU_RNN_MODEL_URL} "ditu_rnn_fluid%2Fmodel.tar.gz")
56+
inference_download_and_uncompress(ditu_rnn_data ${DITU_RNN_DATA_URL} "ditu_rnn_fluid%2Fdata.txt.tar.gz")
57+
endif()
58+
59+
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
60+
EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
61+
# ir
62+
fc_fuse_pass
63+
graph_viz_pass
64+
infer_clean_graph_pass
65+
graph_pattern_detecter
66+
infer_clean_graph_pass
67+
pass
68+
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model
69+
--infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model
70+
--infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
5071

5172
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
5273
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)

paddle/fluid/inference/analysis/analyzer.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
2424
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
2525

26-
namespace paddle {
27-
2826
DEFINE_bool(IA_enable_tensorrt_subgraph_engine, false,
2927
"Enable subgraph to TensorRT engine for acceleration");
3028

@@ -35,6 +33,7 @@ DEFINE_string(IA_graphviz_log_root, "./",
3533

3634
DEFINE_string(IA_output_storage_path, "", "optimized model output path");
3735

36+
namespace paddle {
3837
namespace inference {
3938
namespace analysis {
4039

paddle/fluid/inference/analysis/analyzer.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,14 @@ limitations under the License. */
3939
#include "paddle/fluid/inference/analysis/pass.h"
4040
#include "paddle/fluid/inference/analysis/pass_manager.h"
4141

42-
namespace paddle {
43-
4442
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
4543
// flag if not available.
4644
DECLARE_bool(IA_enable_tensorrt_subgraph_engine);
4745
DECLARE_string(IA_graphviz_log_root);
4846
DECLARE_string(IA_output_storage_path);
4947
DECLARE_bool(IA_enable_ir);
5048

49+
namespace paddle {
5150
namespace inference {
5251
namespace analysis {
5352

0 commit comments

Comments
 (0)