@@ -46,15 +46,15 @@ void IRElementwiseSchedule(ir::IRSchedule &ir_sch, // NOLINT
46
46
<< ir_sch.GetModule ().GetExprs ().at (0 );
47
47
if (target == common::DefaultNVGPUTarget ()) {
48
48
auto blocks = ir_sch.GetAllBlocks ();
49
- ir_sch.FlattenLoops (ir_sch.GetLoops (blocks[0 ]), true );
49
+ std::vector<ir::Expr> loops = ir_sch.GetLoops (blocks[0 ]);
50
+ ir::Expr loop = ir_sch.Fuse (loops);
50
51
51
- auto loops = ir_sch.GetLoops (blocks[0 ]);
52
52
auto size = std::accumulate (
53
53
output_shape.begin (), output_shape.end (), 1 , std::multiplies<int >());
54
54
if (size <= target.max_num_threads ()) {
55
- ir_sch.Bind (loops[ 0 ] , " threadIdx.x" );
55
+ ir_sch.Bind (loop , " threadIdx.x" );
56
56
} else {
57
- auto splited = ir_sch.Split (loops[ 0 ] , {-1 , target.max_num_threads ()});
57
+ auto splited = ir_sch.Split (loop , {-1 , target.max_num_threads ()});
58
58
ir_sch.Bind (splited[0 ], " blockIdx.x" );
59
59
ir_sch.Bind (splited[1 ], " threadIdx.x" );
60
60
}
@@ -74,15 +74,15 @@ void IRInjectiveSchedule(ir::IRSchedule &ir_sch, // NOLINT
74
74
<< ir_sch.GetModule ().GetExprs ().at (0 );
75
75
if (target == common::DefaultNVGPUTarget ()) {
76
76
auto blocks = ir_sch.GetAllBlocks ();
77
- ir_sch.FlattenLoops (ir_sch.GetLoops (blocks[0 ]), false );
77
+ std::vector<ir::Expr> loops = ir_sch.GetLoops (blocks[0 ]);
78
+ ir::Expr loop = ir_sch.Fuse (loops);
78
79
79
- auto loops = ir_sch.GetLoops (blocks[0 ]);
80
80
auto size = std::accumulate (
81
81
output_shape.begin (), output_shape.end (), 1 , std::multiplies<int >());
82
82
if (size <= target.max_num_threads ()) {
83
- ir_sch.Bind (loops[ 0 ] , " threadIdx.x" );
83
+ ir_sch.Bind (loop , " threadIdx.x" );
84
84
} else {
85
- auto splited = ir_sch.Split (loops[ 0 ] , {-1 , target.max_num_threads ()});
85
+ auto splited = ir_sch.Split (loop , {-1 , target.max_num_threads ()});
86
86
ir_sch.Bind (splited[0 ], " blockIdx.x" );
87
87
ir_sch.Bind (splited[1 ], " threadIdx.x" );
88
88
}
0 commit comments