Skip to content

Commit d38cd6c

Browse files
authored
[CINN][Fix] change Flatten to Fuse (#56719)
Change FlattenLoops to Fuse in the elementwise and injective schedules
1 parent a28e6f6 commit d38cd6c

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

paddle/cinn/hlir/pe/ir_schedule_pe.cc

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -46,15 +46,15 @@ void IRElementwiseSchedule(ir::IRSchedule &ir_sch, // NOLINT
4646
<< ir_sch.GetModule().GetExprs().at(0);
4747
if (target == common::DefaultNVGPUTarget()) {
4848
auto blocks = ir_sch.GetAllBlocks();
49-
ir_sch.FlattenLoops(ir_sch.GetLoops(blocks[0]), true);
49+
std::vector<ir::Expr> loops = ir_sch.GetLoops(blocks[0]);
50+
ir::Expr loop = ir_sch.Fuse(loops);
5051

51-
auto loops = ir_sch.GetLoops(blocks[0]);
5252
auto size = std::accumulate(
5353
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
5454
if (size <= target.max_num_threads()) {
55-
ir_sch.Bind(loops[0], "threadIdx.x");
55+
ir_sch.Bind(loop, "threadIdx.x");
5656
} else {
57-
auto splited = ir_sch.Split(loops[0], {-1, target.max_num_threads()});
57+
auto splited = ir_sch.Split(loop, {-1, target.max_num_threads()});
5858
ir_sch.Bind(splited[0], "blockIdx.x");
5959
ir_sch.Bind(splited[1], "threadIdx.x");
6060
}
@@ -74,15 +74,15 @@ void IRInjectiveSchedule(ir::IRSchedule &ir_sch, // NOLINT
7474
<< ir_sch.GetModule().GetExprs().at(0);
7575
if (target == common::DefaultNVGPUTarget()) {
7676
auto blocks = ir_sch.GetAllBlocks();
77-
ir_sch.FlattenLoops(ir_sch.GetLoops(blocks[0]), false);
77+
std::vector<ir::Expr> loops = ir_sch.GetLoops(blocks[0]);
78+
ir::Expr loop = ir_sch.Fuse(loops);
7879

79-
auto loops = ir_sch.GetLoops(blocks[0]);
8080
auto size = std::accumulate(
8181
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
8282
if (size <= target.max_num_threads()) {
83-
ir_sch.Bind(loops[0], "threadIdx.x");
83+
ir_sch.Bind(loop, "threadIdx.x");
8484
} else {
85-
auto splited = ir_sch.Split(loops[0], {-1, target.max_num_threads()});
85+
auto splited = ir_sch.Split(loop, {-1, target.max_num_threads()});
8686
ir_sch.Bind(splited[0], "blockIdx.x");
8787
ir_sch.Bind(splited[1], "threadIdx.x");
8888
}

test/legacy_test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1061,7 +1061,7 @@ set_tests_properties(
10611061
PROPERTIES TIMEOUT 120)
10621062
set_tests_properties(test_conv_nn_grad PROPERTIES TIMEOUT 120)
10631063
set_tests_properties(test_program_prune_backward PROPERTIES TIMEOUT 120)
1064-
set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 300)
1064+
set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 1000)
10651065
set_tests_properties(test_imperative_optimizer PROPERTIES TIMEOUT 250)
10661066
set_tests_properties(test_imperative_optimizer_v2 PROPERTIES TIMEOUT 250)
10671067
set_tests_properties(test_pool2d_op PROPERTIES TIMEOUT 120)

0 commit comments

Comments
 (0)