Skip to content

Commit c9ecd07

Browse files
BiynXu and Aurelius84 authored
Upgrade pir op lowerer into new group scheduler strategy (#59810)
* [PIR+CINN] Upgrade into new GroupScheduler strategy: upload debug code, fix UT, and switch C++ into the new strategy
* Fix conflict
* Fix do_op_schedule
* Remove code
* Fix env
* Fix UT
* [CINN] Fix group schedule of keep-dim reduction
* Fix group schedule name bug
---------
Co-authored-by: Aurelius84 <[email protected]>
1 parent 70c4d21 commit c9ecd07

File tree

9 files changed

+84
-28
lines changed

9 files changed

+84
-28
lines changed

paddle/cinn/ast_gen_ius/ast_gen.cc

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include "paddle/cinn/lang/compute.h"
2222
#include "paddle/cinn/optim/replace_var_with_expr.h"
2323

24+
PD_DECLARE_bool(cinn_new_group_scheduler);
25+
2426
namespace cinn {
2527
namespace ast_gen_ius {
2628

@@ -92,11 +94,15 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
9294
// for same axis so we re-create objects
9395
std::vector<Var> axis_vars = cinn::common::GenDefaultAxis(axis_len);
9496
for (int i = 0; i < shape.size(); ++i) {
97+
if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
98+
optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0));
99+
continue;
100+
}
95101
block_vars.push_back(Var(Expr(0),
96102
shape[i],
97103
cinn::UniqName("i" + std::to_string(i)),
98104
/*is_reduce = */ false));
99-
optim::ReplaceVarWithExpr(&init_body, axis[i], block_vars[i]);
105+
optim::ReplaceVarWithExpr(&init_body, axis[i], block_vars.back());
100106
axis_vars[i]->is_reduce_axis = false;
101107
if (shape[i] == Expr(1)) {
102108
iter_values.push_back(Expr(0));
@@ -120,6 +126,10 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
120126
// for same axis so we re-create objects
121127
std::vector<Var> reduce_axis_vars = cinn::common::GenDefaultAxis(axis_len);
122128
for (int i = 0; i < shape.size(); ++i) {
129+
if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
130+
optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0));
131+
continue;
132+
}
123133
reduce_block_vars.push_back(Var(Expr(0),
124134
shape[i],
125135
cinn::UniqName("i" + std::to_string(i)),
@@ -142,12 +152,20 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
142152
reduce_axis_var->is_reduce_axis = true;
143153
reduce_iter_values.push_back(reduce_axis_var);
144154
}
155+
156+
int non_zero_axis_size = 0;
145157
for (int i = 0; i < axis.size(); ++i) {
146-
optim::ReplaceVarWithExpr(&reduce_body, axis[i], reduce_block_vars[i]);
147-
}
148-
for (int i = axis.size(); i < reduce_block_vars.size(); ++i) {
158+
if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
159+
continue;
160+
}
149161
optim::ReplaceVarWithExpr(
150-
&reduce_body, reduce_axis[i - axis.size()], reduce_block_vars[i]);
162+
&reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]);
163+
++non_zero_axis_size;
164+
}
165+
for (int i = non_zero_axis_size; i < reduce_block_vars.size(); ++i) {
166+
optim::ReplaceVarWithExpr(&reduce_body,
167+
reduce_axis[i - non_zero_axis_size],
168+
reduce_block_vars[i]);
151169
}
152170

153171
reduce_body = ir::ScheduleBlockRealize::Make(
@@ -166,6 +184,9 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
166184
// Put the two parts together
167185
ir::Expr body = ir::Block::Make({init_body, reduce_body});
168186
for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) {
187+
if (shape[i] == Expr(1)) {
188+
continue;
189+
}
169190
ir::Var loop_var = axis[i];
170191
ir::Expr loop_extent = shape[i];
171192
body = ir::For::Make(

paddle/cinn/hlir/framework/pir/op_lowering_impl.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ PD_DECLARE_bool(cinn_use_cuda_vectorize);
3636
PD_DECLARE_bool(cinn_enable_map_expr);
3737
PD_DECLARE_bool(cinn_enable_map_expr_schedule);
3838
PD_DECLARE_bool(cinn_bucket_compile);
39+
PD_DECLARE_bool(cinn_new_group_scheduler);
3940

4041
namespace cinn {
4142
namespace hlir {
@@ -659,6 +660,20 @@ ir::Expr OpLowererImpl::DoGroupSchedule(
659660
const GroupPtr& group,
660661
const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
661662
const std::unordered_map<std::string, ir::Tensor>& tmp_tensor_info) {
663+
if (FLAGS_cinn_new_group_scheduler) {
664+
VLOG(3) << "using StaticShapeGroupScheduler to schedule group.";
665+
std::unordered_set<std::string> output_tensor_names;
666+
std::transform(
667+
group->output_ops.begin(),
668+
group->output_ops.end(),
669+
std::inserter(output_tensor_names, output_tensor_names.begin()),
670+
[&](::pir::Operation* op) { return ValueName(op->result(0)); });
671+
std::unique_ptr<ir::GroupScheduler> group_scheduler =
672+
ir::GroupScheduler::Make(
673+
&ir_sch, output_tensor_names, target_, /* is_dy_shape = */ false);
674+
group_scheduler->Schedule();
675+
return ir_sch.GetModule().GetExprs().at(0);
676+
}
662677
// topological order.
663678
auto ops_set = group->OpSet();
664679
auto v_consumers = BuildVirtualConsumer(group);

paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,16 +256,19 @@ void StaticShapeGroupScheduler::DoLoopAlignment() {
256256
std::vector<ir::Expr> indices =
257257
reduce_loads.begin()->As<ir::Load>()->indices;
258258
for (ir::Expr index : indices) {
259+
if (index.is_constant()) continue;
259260
CHECK_NOTNULL(index.as_var());
260261
int idx = 0;
261262
bool is_reduce_var = false;
262-
for (const ir::Var& iter_var : master_iter_vars) {
263+
for (int iter_idx = 0; iter_idx < master_iter_vars.size(); ++iter_idx) {
264+
auto& iter_var = master_iter_vars[iter_idx];
263265
if (iter_var->name == index.as_var_ref()->name) {
264266
is_reduce_var = iter_var->is_reduce_axis;
265267
break;
266268
}
267269
++idx;
268270
}
271+
if (master_iter_values[idx].is_constant()) continue;
269272
std::vector<ir::Var> loop_vars_in_order;
270273
ir::ir_utils::CollectIRNodesInOrder(
271274
master_iter_values[idx], [&](const ir::Expr* x) {

paddle/cinn/ir/schedule/factorize_reduction.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace ir {
3333
Tensor CreateRFTensor(const Tensor& original_tensor,
3434
const Expr& rf_loop,
3535
int rf_axis) {
36-
std::string name = cinn::common::UniqName(original_tensor->name + "_rf");
36+
std::string name = original_tensor->name + "_rf";
3737
std::vector<Expr> new_shape = original_tensor->shape;
3838
new_shape.insert(new_shape.begin() + rf_axis, rf_loop.As<For>()->extent);
3939
Tensor rf_tensor = _Tensor_::Make(name,

paddle/cinn/ir/schedule/impl/compute_location.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ void StScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
135135
std::vector<Var> replaced_var;
136136
std::vector<Expr> substitute_expr;
137137
for (int i = 0; i < loops.size(); ++i) {
138+
VLOG(3) << i << "-th loop is:\n " << loops[i];
139+
VLOG(3) << i << "-th block_loop:\n" << block_loops[i];
138140
CHECK_EQ(GetLoopExtent(loops[i]), GetLoopExtent(block_loops[i]));
139141
if (block_loops[i].As<ir::For>()->bind_info().valid() &&
140142
!loops[i].As<ir::For>()->bind_info().valid()) {

paddle/cinn/ir/schedule/ir_schedule_util.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,9 @@ struct FindBlocksVisitor {
10611061
if (expr->As<ir::For>()) {
10621062
Visit(&(expr->As<ir::For>()->body));
10631063
} else if (expr->As<ir::ScheduleBlockRealize>()) {
1064-
if (!expr->As<ir::ScheduleBlockRealize>()->iter_values.empty()) {
1064+
if (expr->As<ir::ScheduleBlockRealize>()
1065+
->schedule_block.As<ScheduleBlock>()
1066+
->name.substr(0, 4) != "root") {
10651067
auto* schedule_block = expr->As<ir::ScheduleBlockRealize>()
10661068
->schedule_block.As<ir::ScheduleBlock>();
10671069
if (block_name_.empty() || schedule_block->name == block_name_) {

test/cpp/pir/cinn/CMakeLists.txt

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,12 @@ add_subdirectory(adt)
33
if(WITH_TESTING AND WITH_CINN)
44
paddle_test(test_pir_compiler SRCS pir_compiler_test.cc DEPS pir_compiler
55
cinn_runtime_dialect)
6-
set_tests_properties(test_pir_compiler PROPERTIES LABELS "RUN_TYPE=CINN")
76

87
paddle_test(test_jit_instruction SRCS jit_instruction_test.cc DEPS
98
cinn_runtime_dialect pir_compiler)
10-
set_tests_properties(test_jit_instruction PROPERTIES LABELS "RUN_TYPE=CINN")
119

1210
paddle_test(
13-
dialect_convert_test
11+
test_dialect_convert
1412
SRCS
1513
dialect_convert_test.cc
1614
DEPS
@@ -19,10 +17,9 @@ if(WITH_TESTING AND WITH_CINN)
1917
op_dialect_vjp
2018
cinn_op_dialect
2119
pir)
22-
set_tests_properties(dialect_convert_test PROPERTIES LABELS "RUN_TYPE=CINN")
2320

2421
paddle_test(
25-
add_broadcast_to_elementwise_test
22+
test_add_broadcast_to_elementwise
2623
SRCS
2724
add_broadcast_to_elementwise_test.cc
2825
DEPS
@@ -32,7 +29,6 @@ if(WITH_TESTING AND WITH_CINN)
3229
cinn_op_dialect
3330
add_broadcast_to_elementwise_pass
3431
pir)
35-
set_tests_properties(dialect_convert_test PROPERTIES LABELS "RUN_TYPE=CINN")
3632

3733
paddle_test(
3834
test_sub_graph_extract
@@ -45,17 +41,15 @@ if(WITH_TESTING AND WITH_CINN)
4541
op_dialect_vjp
4642
pir_transforms
4743
pir)
48-
set_tests_properties(test_sub_graph_extract PROPERTIES LABELS "RUN_TYPE=CINN")
4944

5045
paddle_test(
51-
ir_op_fusion_test
46+
test_ir_op_fusion
5247
SRCS
5348
ir_op_fusion_test.cc
5449
DEPS
5550
op_with_group_merge_pass
5651
cinn_op_dialect
5752
pir)
58-
set_tests_properties(ir_op_fusion_test PROPERTIES LABELS "RUN_TYPE=CINN")
5953

6054
paddle_test(
6155
test_pir_all_path
@@ -67,7 +61,6 @@ if(WITH_TESTING AND WITH_CINN)
6761
cinn_op_dialect
6862
pd_to_cinn_pass
6963
add_broadcast_to_elementwise_pass)
70-
set_tests_properties(test_pir_all_path PROPERTIES LABELS "RUN_TYPE=CINN")
7164

7265
paddle_test(
7366
test_group_op
@@ -79,13 +72,34 @@ if(WITH_TESTING AND WITH_CINN)
7972
op_with_group_merge_pass
8073
cinn_op_dialect
8174
pir_transforms)
82-
set_tests_properties(test_group_op PROPERTIES LABELS "RUN_TYPE=CINN")
8375

8476
paddle_test(test_pir_build_cinn_pass SRCS build_cinn_pass_test.cc DEPS
8577
pir_transforms pir)
86-
set_tests_properties(test_pir_build_cinn_pass PROPERTIES LABELS
87-
"RUN_TYPE=CINN")
8878

8979
paddle_test(test_compilation_task SRCS compilation_task_test.cc DEPS pir)
90-
set_tests_properties(test_compilation_task PROPERTIES LABELS "RUN_TYPE=CINN")
80+
81+
# DO NOT forget add test name here, otherwise it will not be executed in
82+
# CINN CI.
83+
set(cinn_unit_tests
84+
test_pir_compiler
85+
test_jit_instruction
86+
test_dialect_convert
87+
test_add_broadcast_to_elementwise
88+
test_sub_graph_extract
89+
test_ir_op_fusion
90+
test_pir_all_path
91+
test_group_op
92+
test_pir_build_cinn_pass
93+
test_compilation_task)
94+
95+
foreach(test_name ${cinn_unit_tests})
96+
get_property(
97+
env
98+
TEST ${test_name}
99+
PROPERTY ENVIRONMENT)
100+
set_property(TEST ${test_name}
101+
PROPERTY ENVIRONMENT "FLAGS_cinn_new_group_scheduler=1" ${env})
102+
set_tests_properties(${test_name} PROPERTIES LABELS "RUN_TYPE=CINN")
103+
endforeach()
104+
91105
endif()

test/cpp/pir/cinn/add_broadcast_to_elementwise_test.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ TEST(PatternRewrite, broadcast_elementwise) {
9898
it++;
9999
CHECK_EQ(it->isa<paddle::dialect::FullOp>(), true);
100100
it++;
101-
CHECK_EQ(it->isa<cinn::dialect::BroadcastOp>(), true);
101+
CHECK_EQ(it->isa<paddle::dialect::FullOp>(), true);
102102
it++;
103103
CHECK_EQ(it->isa<paddle::dialect::AddOp>(), true);
104104
}
@@ -124,9 +124,9 @@ TEST(PatternRewrite, broadcast_elementwise_both) {
124124
it++;
125125
CHECK_EQ(it->isa<paddle::dialect::FullOp>(), true);
126126
it++;
127-
CHECK_EQ(it->isa<cinn::dialect::BroadcastOp>(), true);
127+
CHECK_EQ(it->isa<paddle::dialect::FullOp>(), true);
128128
it++;
129-
CHECK_EQ(it->isa<cinn::dialect::BroadcastOp>(), true);
129+
CHECK_EQ(it->isa<paddle::dialect::FullOp>(), true);
130130
it++;
131131
CHECK_EQ(it->isa<paddle::dialect::AddOp>(), true);
132132
}
@@ -152,9 +152,9 @@ TEST(PatternRewrite, broadcast_elementwise_sub_both) {
152152
it++;
153153
CHECK_EQ(it->isa<paddle::dialect::FullOp>(), true);
154154
it++;
155-
CHECK_EQ(it->isa<cinn::dialect::BroadcastOp>(), true);
155+
CHECK_EQ(it->isa<paddle::dialect::FullOp>(), true);
156156
it++;
157-
CHECK_EQ(it->isa<cinn::dialect::BroadcastOp>(), true);
157+
CHECK_EQ(it->isa<paddle::dialect::FullOp>(), true);
158158
it++;
159159
CHECK_EQ(it->isa<paddle::dialect::SubtractOp>(), true);
160160
}

test/cpp/pir/cinn/pir_compiler_test.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ TEST(PirCompier, CompileSoftmax) {
207207
executor.Run({}, true);
208208
auto out_tensor =
209209
executor.local_scope()->FindVar("out@fetch")->Get<phi::DenseTensor>();
210-
211210
bool res0 = simple_cmp(out_tensor.data<float>()[0], 1.0 / 16);
212211
EXPECT_EQ(res0, true);
213212
}

0 commit comments

Comments (0)