Skip to content

Commit 67c84c4

Browse files
authored
[IR] Support inplace pass (#56672)
* add code
* add code
* refine code
* add code
* fix bug
* fix bug
* fix bug
* add code
* add ut
* polish code
* fix bug
* refine code
* fix bug
* refine code
* fix bug
* refine code
* fix bug
* refine code
* fix bug
* refine code
* add code
* fix bug
* fix bug
* fix bug
* fix bug
* fix bug
* refine code
1 parent 11a526a commit 67c84c4

File tree

14 files changed

+484
-16
lines changed

14 files changed

+484
-16
lines changed

paddle/fluid/framework/new_executor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ set(STANDALONE_EXECUTOR_DEPS
1616
phi_kernel_adaptor
1717
program_translator
1818
instruction_base
19+
pd_inplace_pass
1920
ir)
2021

2122
cc_library(

paddle/fluid/framework/new_executor/standalone_executor.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,24 @@
1616
#include "paddle/fluid/framework/new_executor/feed_fetch_utils.h"
1717
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
1818
#include "paddle/fluid/framework/new_executor/program_interpreter.h"
19+
#include "paddle/fluid/platform/flags.h"
1920
#include "paddle/fluid/platform/profiler/event_tracing.h"
2021

2122
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
2223

24+
#include "paddle/fluid/ir/transforms/inplace_pass.h"
2325
#include "paddle/fluid/ir_adaptor/translator/translate.h"
2426
#include "paddle/ir/core/program.h"
27+
#include "paddle/ir/pass/pass.h"
28+
#include "paddle/ir/pass/pass_manager.h"
2529

2630
PHI_DECLARE_bool(enable_new_ir_in_executor);
2731
PHI_DECLARE_bool(enable_new_ir_api);
2832

33+
PADDLE_DEFINE_EXPORTED_bool(new_ir_apply_inplace_pass,
34+
true,
35+
"new ir kernel program apply inplace pass.");
36+
2937
namespace paddle {
3038
namespace framework {
3139
StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
@@ -101,6 +109,13 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
101109
}
102110
auto kernel_program =
103111
paddle::dialect::PdOpLowerToKernelPass(base_program.get(), place);
112+
113+
if (FLAGS_new_ir_apply_inplace_pass) {
114+
ir::PassManager pm(ir::IrContext::Instance(), 3);
115+
pm.AddPass(ir::CreateInplacePass());
116+
pm.Run(kernel_program.get());
117+
}
118+
104119
interpretercores_.emplace_back(
105120
std::make_shared<InterpreterCore>(place_,
106121
fetch_var_names_,

paddle/fluid/ir/dialect/op_generator/op_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,7 @@ def OpGenerator(
856856
op_infer_meta_map,
857857
muta_attr_is_input=False,
858858
)
859-
if len(op_attribute_name_list) > 1:
859+
if len(op_attribute_name_list) > 0:
860860
(
861861
build_args_with_attr_is_map_for_declare,
862862
build_func_with_attr_is_map,

paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.cc

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,15 @@ const OpRunTimeInfo& OpYamlInfoParser::OpRuntimeInfo() const {
8484
return std::get<3>(op_info_tuple_);
8585
}
8686

87-
const std::map<std::string, int>& OpYamlInfoParser::InputName2Id() const {
87+
const std::map<std::string, uint32_t>& OpYamlInfoParser::InputName2Id() const {
8888
return input_name2id_;
8989
}
9090

91-
const std::map<std::string, int>& OpYamlInfoParser::OutputName2Id() const {
91+
const std::map<std::string, uint32_t>& OpYamlInfoParser::OutputName2Id() const {
9292
return output_name2id_;
9393
}
9494

95-
const std::vector<int>& OpYamlInfoParser::NoNeedBufferIds() const {
95+
const std::vector<uint32_t>& OpYamlInfoParser::NoNeedBufferIds() const {
9696
return no_need_buffer_ids_;
9797
}
9898

@@ -118,6 +118,17 @@ const std::string& OpYamlInfoParser::InplaceName(
118118
"Can not find inplace input of [%s].", out_name));
119119
}
120120

121+
// Builds and returns (by value) a map with one entry per pair in the
// `inplace` list of std::get<3>(op_info_tuple_), the same tuple element that
// OpRuntimeInfo() returns. Key: the id OutputName2Id() holds for the pair's
// first name. Value: the id InputName2Id() holds for the pair's second name.
std::unordered_map<uint32_t, uint32_t> OpYamlInfoParser::GetInplaceIdMap()
122+
const {
123+
std::unordered_map<uint32_t, uint32_t> inplace_id_map;
124+
auto& inplace_info = std::get<3>(op_info_tuple_).inplace;
125+
for (const auto& info : inplace_info) {
126+
// Both lookups use std::map::at(), so a name missing from either table
// raises std::out_of_range instead of inserting a default id.
inplace_id_map[OutputName2Id().at(info.first)] =
127+
InputName2Id().at(info.second);
128+
}
129+
return inplace_id_map;
130+
}
131+
121132
bool OpYamlInfoParser::HasView(const std::string& out_name) const {
122133
auto& view_info = std::get<3>(op_info_tuple_).view;
123134
for (size_t i = 0; i < view_info.size(); i++) {

paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ class OpYamlInfoParser {
3434
const std::vector<std::string>& TensorParams(bool is_kernel = false) const;
3535
const std::vector<std::string>& AttrParams(bool is_kernel = false) const;
3636
const OpRunTimeInfo& OpRuntimeInfo() const;
37-
const std::map<std::string, int>& InputName2Id() const;
38-
const std::map<std::string, int>& OutputName2Id() const;
37+
const std::map<std::string, uint32_t>& InputName2Id() const;
38+
const std::map<std::string, uint32_t>& OutputName2Id() const;
3939

40-
const std::vector<int>& NoNeedBufferIds() const;
40+
const std::vector<uint32_t>& NoNeedBufferIds() const;
4141

4242
const std::vector<std::string>& InputNames() const {
4343
return input_name_list_;
@@ -53,6 +53,8 @@ class OpYamlInfoParser {
5353

5454
const std::string& InplaceName(const std::string& out_name) const;
5555

56+
std::unordered_map<uint32_t, uint32_t> GetInplaceIdMap() const;
57+
5658
bool HasView(const std::string& out_name) const;
5759

5860
const std::string& ViewName(const std::string& out_name) const;
@@ -68,20 +70,20 @@ class OpYamlInfoParser {
6870
OpInfoTuple op_info_tuple_;
6971

7072
// input info
71-
std::map<std::string, int> input_name2id_;
73+
std::map<std::string, uint32_t> input_name2id_;
7274
std::vector<std::string> input_name_list_;
7375
std::map<std::string, OpInputInfo> input_info_;
74-
int input_tensor_number_{0};
76+
uint32_t input_tensor_number_{0};
7577

7678
// no_need_buffer_ids
77-
std::vector<int> no_need_buffer_ids_;
79+
std::vector<uint32_t> no_need_buffer_ids_;
7880

7981
// attribute info
8082
std::vector<std::string> attribute_name_list_;
8183
std::map<std::string, OpAttributeInfo> attr_info_;
8284

8385
// output info
84-
std::map<std::string, int> output_name2id_;
86+
std::map<std::string, uint32_t> output_name2id_;
8587
std::vector<std::string> output_name_list_;
8688
std::map<std::string, OpOutputInfo> output_info_;
8789

paddle/fluid/ir/transforms/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,8 @@ cc_library(
1212
_constant_folding_pass
1313
SRCS constant_folding_pass.cc
1414
DEPS standalone_executor pd_op_to_kernel_pass transform_general_functions)
15+
16+
cc_library(
17+
pd_inplace_pass
18+
SRCS inplace_pass.cc
19+
DEPS pd_dialect_core op_yaml_info_parser)

0 commit comments

Comments
 (0)