Skip to content

Commit d5d7fe5

Browse files
authored
[CINN] Add The TileDiscreteReductionTactic (#71489)
1 parent 2ffcf1d commit d5d7fe5

File tree

5 files changed

+347
-180
lines changed

5 files changed

+347
-180
lines changed

paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "paddle/cinn/ir/group_schedule/tactic/compute_at_reduction_tactic.h"
2020
#include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
2121
#include "paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.h"
22+
#include "paddle/cinn/ir/group_schedule/tactic/tile_discrete_reduction_tactic.h"
2223
#include "paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.h"
2324
#include "paddle/cinn/ir/group_schedule/tactic/tile_transpose_tactic.h"
2425
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
@@ -38,6 +39,7 @@ void DynamicShapeGroupScheduler::Init() {
3839
tactics_.emplace_back(CreateAlignIterSpaceTactic());
3940
tactics_.emplace_back(CreateTileBroadcastTactic());
4041
tactics_.emplace_back(CreateTileTransposeTactic());
42+
tactics_.emplace_back(CreateTileDiscreteReductionTactic());
4143
tactics_.emplace_back(CreateTileFirstGeneralTactic());
4244
tactics_.emplace_back(CreateComputeInlineTactic());
4345
tactics_.emplace_back(CreateComputeAtReductionTactic());

paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ gather_srcs(cinnapi_src SRCS bind_cuda_tactic.cc)
99
gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)
1010
gather_srcs(cinnapi_src SRCS tile_broadcast_tactic.cc)
1111
gather_srcs(cinnapi_src SRCS tile_transpose_tactic.cc)
12+
gather_srcs(cinnapi_src SRCS tile_discrete_reduction_tactic.cc)
1213
gather_srcs(cinnapi_src SRCS tile_first_general_tactic.cc)
Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/cinn/ir/group_schedule/tactic/tile_discrete_reduction_tactic.h"
16+
#include "paddle/cinn/adt/adt.h"
17+
#include "paddle/cinn/common/integer_set.h"
18+
#include "paddle/cinn/common/target.h"
19+
#include "paddle/cinn/ir/ir.h"
20+
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
21+
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
22+
23+
namespace cinn {
24+
namespace ir {
25+
26+
using cinn::ir::analyzer::IsReductionSBlock;
27+
using BoundVariableMap = std::unordered_map<std::string, std::vector<Var>>;
28+
29+
bool UseDiscreteDataTile(const ScheduleConfig& config) {
30+
// use discrete data tile for [RS]
31+
for (const auto& iter_space : config.base_info->iter_space_type) {
32+
if (iter_space.first == "R") {
33+
if (config.base_info->iter_space_type.back().first == "S") {
34+
return true;
35+
}
36+
}
37+
}
38+
return false;
39+
}
40+
41+
class TileDiscreteReductionTactic final : public ScheduleTactic {
42+
public:
43+
void Init(ScheduleContext* context, ir::IRSchedule* sch) override;
44+
void Apply(ir::IRSchedule* sch, const std::string& block_id) override;
45+
std::string TacticName() const override {
46+
return "TileDiscreteReductionTactic";
47+
}
48+
49+
private:
50+
void MergeDiscreteFlattenAxis(ir::IRSchedule* sch,
51+
const std::string& block_id);
52+
void MergeReduceAxis(ir::IRSchedule* sch, const std::string& block_id);
53+
void SplitSptialInner(ir::IRSchedule* sch, const std::string& block_id);
54+
void SplitReduceInner(ir::IRSchedule* sch, const std::string& block_id);
55+
void VariableTypeAssignment(ir::IRSchedule* sch, const std::string& block_id);
56+
void SetDiscreteReduceType(ir::IRSchedule* sch, const std::string& block_id);
57+
void BindCudaInfo(ir::IRSchedule* sch, const std::string& block_id);
58+
59+
private:
60+
ScheduleContext* context_;
61+
bool can_apply_;
62+
std::vector<int32_t> vec_spatial_axis_first_;
63+
std::vector<int32_t> vec_spatial_axis_last_;
64+
std::vector<int32_t> vec_flatten_axis_;
65+
std::vector<int32_t> vec_reduce_axis_;
66+
std::unordered_map<std::string, std::string> map_rf_block_;
67+
std::unordered_map<std::string, std::string> map_global_rf_block_;
68+
};
69+
70+
// Decides applicability for the current group and precomputes the axis
// partition used by Apply(). Sets can_apply_ and tags the root block's
// kTileMethod attr so later tactics skip this group.
//
// Fix vs. original: the loop-stride comparison used the raw `context`
// parameter while the rest of the function used `context_`; it now uses
// `context_` consistently (same value, consistent style).
void TileDiscreteReductionTactic::Init(ScheduleContext* context,
                                       ir::IRSchedule* sch) {
  context_ = context;
  can_apply_ = false;

  // Check whether this group has been tiled by previous tactic.
  ir::Expr module_root = sch->GetModule().GetExprs().front();
  ir::Expr root_block = ir::analyzer::GetRootSBlock(module_root);
  auto* root_node = root_block.As<ir::ScheduleBlockRealize>()
                        ->schedule_block.As<ir::ScheduleBlock>();
  if (root_node->attrs.count(kTileMethod) > 0) {
    return;
  }
  if (!UseDiscreteDataTile(context_->config)) {
    return;
  }
  can_apply_ = true;
  root_node->attrs[kTileMethod] = TacticName();

  // Reduce axes have been re-ordered to the last `reduce_axis.size()`
  // loop positions, so indices below reduce_start_idx are spatial.
  vec_flatten_axis_.clear();
  vec_reduce_axis_.clear();
  int data_rank = context_->config.base_info->loop_ranges.size();
  int32_t reduce_start_idx =
      data_rank - context_->config.base_info->reduce_axis.size();
  for (int32_t i = 0; i < data_rank; ++i) {
    if (i >= reduce_start_idx) {
      vec_reduce_axis_.push_back(i);
    } else {
      vec_flatten_axis_.push_back(i);
    }
  }
  vec_spatial_axis_first_.clear();
  vec_spatial_axis_last_.clear();

  if (!context_->config.base_info->reduce_axis.empty()) {
    // Pick the reduce axis with the LARGEST loop stride as the pivot
    // (it starts at front() and is replaced by any larger-stride axis),
    // then split the spatial axes into those before and after it.
    int64_t major_reduce_axis = context_->config.base_info->reduce_axis.front();
    for (auto axis : context_->config.base_info->reduce_axis) {
      if (context_->config.base_info->loop_strides[axis] >
          context_->config.base_info->loop_strides[major_reduce_axis]) {
        major_reduce_axis = axis;
      }
    }
    for (int32_t i = 0; i < reduce_start_idx; ++i) {
      if (i < major_reduce_axis) {
        vec_spatial_axis_first_.push_back(i);
      } else {
        vec_spatial_axis_last_.push_back(i);
      }
    }
  }

  map_rf_block_.clear();
}
124+
125+
// Runs the full discrete-reduction tiling pipeline on one schedule block,
// logging the loop nest after every step at VLOG(6).
void TileDiscreteReductionTactic::Apply(ir::IRSchedule* sch,
                                        const std::string& block_id) {
  if (!can_apply_) return;
  if (ir::IsReduceInitTensorName(block_id)) return;

  // Emits the same per-stage trace message the steps below share.
  const auto TraceLoopNest = [&](const std::string& stage) {
    VLOG(6) << "After " << stage << " on block: [" << block_id
            << "], loop nest:\n"
            << sch->GetLoops(block_id)[0];
  };

  MergeReduceAxis(sch, block_id);
  TraceLoopNest("MergeReduceAxis");
  MergeDiscreteFlattenAxis(sch, block_id);
  TraceLoopNest("MergeDiscreteFlattenAxis");
  SplitSptialInner(sch, block_id);
  TraceLoopNest("SplitSptialInner");
  SplitReduceInner(sch, block_id);
  TraceLoopNest("SplitReduceInner");
  BindCudaInfo(sch, block_id);
  TraceLoopNest("BindCudaInfo");
  VariableTypeAssignment(sch, block_id);
  TraceLoopNest("VariableTypeAssignment");
  SetDiscreteReduceType(sch, block_id);
}
155+
156+
// Fuses each group of spatial axes (the ones after the pivot reduce axis
// first, then the ones before it) into a single loop apiece.
void TileDiscreteReductionTactic::MergeDiscreteFlattenAxis(
    ir::IRSchedule* sch, const std::string& block_id) {
  // Fuse from the innermost group outward, because fusing upper loops
  // first would invalidate the recorded indices of the lower ones.
  const auto FuseGroup = [&](const std::vector<int32_t>& axes) {
    if (axes.size() >= 2) {
      sch->Fuse(block_id, axes);
    }
  };
  FuseGroup(vec_spatial_axis_last_);
  FuseGroup(vec_spatial_axis_first_);
}
167+
168+
// Fuses all reduce axes into one loop, after validating that each reduce
// axis index is inside the current loop nest (or the nest has collapsed
// to a single loop, in which case fusing is skipped).
//
// Fix vs. original: `idx < loops.size()` and `max_loop_idx < loops.size()`
// compared signed int32_t against size_t; the size is now cast once to a
// signed value to avoid the implicit signed/unsigned comparison.
void TileDiscreteReductionTactic::MergeReduceAxis(ir::IRSchedule* sch,
                                                  const std::string& block_id) {
  std::vector<ir::Expr> loops = sch->GetLoops(block_id);
  const int32_t loops_size = static_cast<int32_t>(loops.size());
  int32_t max_loop_idx = 0;
  for (int32_t idx : vec_reduce_axis_) {
    max_loop_idx = std::max(max_loop_idx, idx);
    PADDLE_ENFORCE_EQ(idx < loops_size || loops_size == 1,
                      true,
                      ::common::errors::InvalidArgument(
                          "The reduce axis should meet: axis's idx < "
                          "loops.size() or loops.size() == 1, but received "
                          "idx= %d ,loops.size() = %d",
                          idx,
                          loops.size()));
  }
  // Only fuse when every reduce axis is addressable and there are at
  // least two of them.
  if (max_loop_idx < loops_size && vec_reduce_axis_.size() >= 2) {
    sch->Fuse(block_id, vec_reduce_axis_);
  }
}
187+
188+
// Splits the innermost spatial loop by a fixed factor of 32 (warp width)
// so that consecutive threads cover consecutive spatial elements, then
// re-fuses the outer spatial remainder when there were two spatial loops.
//
// Fix vs. original: the `split_loops` locals captured from Split() were
// never used and have been removed.
void TileDiscreteReductionTactic::SplitSptialInner(
    ir::IRSchedule* sch, const std::string& block_id) {
  auto loops = sch->GetLoops(block_id);
  if (loops.size() == 3) {
    // [S, S', R] => [S, S'(-1), S'(32), R]
    sch->Split(loops[1], std::vector<int>({-1, 32}));
    // [S, S'(-1), S'(32), R] => [S, S'(32), R]
    sch->Fuse(block_id, std::vector<int>{0, 1});
  } else if (loops.size() == 2) {
    // [S, R] => [S(-1), S(32), R]
    sch->Split(loops[0], std::vector<int>({-1, 32}));
  }
  // Other loop counts are left untouched.
}
201+
202+
// Splits the (already fused) reduce loop at index 2 into
// [rd_inner, rd_block * rd_thread], reorders the tile to the front, and
// factorizes the reduction into rf blocks — twice when grid (cross-block)
// reduction is enabled, producing an extra global rf tensor.
//
// NOTE(review): each transform invalidates previously fetched loop Exprs,
// which is why `loops` is refreshed via GetLoops() after every Split /
// Reorder / FactorizeReduction — do not reorder these statements.
void TileDiscreteReductionTactic::SplitReduceInner(
    ir::IRSchedule* sch, const std::string& block_id) {
  // rd_block: number of blocks cooperating on one reduction (grid reduce);
  // rd_thread: fixed threadIdx.y extent for the reduce dimension.
  const int64_t rd_block = context_->config.tile_config.grid_reduce_num;
  const int64_t rd_thread = 16;
  // After MergeReduceAxis/SplitSptialInner the nest is [S, S(32), R],
  // so the fused reduce loop sits at index 2.
  const int cur_reduce_axis = 2;

  // [ R ] => [ rd_block*rd_thread, rd_inner ]
  auto loops = sch->GetLoops(block_id);
  sch->Split(loops[cur_reduce_axis],
             std::vector<int>{-1, rd_block * rd_thread});
  loops = sch->GetLoops(block_id);
  sch->Reorder({loops[cur_reduce_axis + 1], loops[cur_reduce_axis]});

  loops = sch->GetLoops(block_id);
  // Factorize only real reduction blocks whose tiled reduce extent is
  // non-trivial; record the rf tensor for later binding/buffer setup.
  if (IsReductionSBlock(sch->GetBlock(block_id)) &&
      ir::GetLoopExtent(loops[2]) != 1) {
    ir::Expr rf_tensor =
        sch->FactorizeReduction(loops[cur_reduce_axis],
                                /* rf_axis = */ 0,
                                /* with_write_back_block_init = */ false);
    map_rf_block_[block_id] = rf_tensor.as_tensor_ref()->name;
  }

  // [ rd_block*rd_thread ] => [ rd_block, rd_thread ]
  if (rd_block > 1) {
    loops = sch->GetLoops(block_id);
    sch->Split(loops[cur_reduce_axis], {rd_block, rd_thread});

    if (IsReductionSBlock(sch->GetBlock(block_id))) {
      // NOTE(review): this assumes map_rf_block_[block_id] was populated
      // by the branch above (i.e. the reduce extent was != 1); if not,
      // operator[] inserts an empty name here — confirm that extent==1
      // cannot coincide with rd_block > 1.
      loops = sch->GetLoops(map_rf_block_[block_id]);
      sch->Split(loops[cur_reduce_axis], {rd_block, rd_thread});

      // Second factorization yields the cross-block rf tensor, which must
      // live in global memory so all blocks can write their partials.
      loops = sch->GetLoops(block_id);
      ir::Expr rf_tensor =
          sch->FactorizeReduction(loops[cur_reduce_axis],
                                  /* rf_axis = */ 0,
                                  /* with_write_back_block_init = */ false);
      std::string tensor_name = rf_tensor.as_tensor_ref()->name;
      map_global_rf_block_[block_id] = tensor_name;
      rf_tensor.as_tensor_ref()->WithBuffer("global", "_" + tensor_name);
    }
  }
}
245+
246+
// Assigns "local" (register) buffers to tensors that are neither group
// outputs nor consumer-less, and to any rf tensor recorded for this block.
void TileDiscreteReductionTactic::VariableTypeAssignment(
    ir::IRSchedule* sch, const std::string& block_id) {
  auto main_block = sch->GetBlock(block_id);

  // A non-output tensor that still has consumers inside the group can be
  // kept in registers instead of global memory.
  const bool is_output = context_->output_names.count(block_id) > 0;
  if (!is_output &&
      !ir::analyzer::GetConsumerSBlocks(main_block,
                                        sch->GetRootBlock(main_block))
           .empty()) {
    sch->SetBuffer(main_block, "local", false);
  }

  // The intra-block rf tensor is always register-resident.
  const auto rf_it = map_rf_block_.find(block_id);
  if (rf_it != map_rf_block_.end()) {
    auto rf_block = sch->GetBlock(rf_it->second);
    sch->SetBuffer(rf_block, "local", false);
  }
}
266+
267+
// Tags the reduction block (and its global rf block, when present) with
// DiscreteReduceMethod so codegen emits the discrete reduce lowering.
void TileDiscreteReductionTactic::SetDiscreteReduceType(
    ir::IRSchedule* sch, const std::string& block_id) {
  // Shared helper: fetch the schedule block node and set its reduce method.
  const auto MarkDiscrete = [&](const std::string& name) {
    auto* node = sch->GetBlock(name)
                     .As<ir::ScheduleBlockRealize>()
                     ->schedule_block.As<ir::ScheduleBlock>();
    node->reduce_method = cinn::ir::DiscreteReduceMethod();
  };

  if (IsReductionSBlock(sch->GetBlock(block_id))) {
    MarkDiscrete(block_id);
  }
  if (map_global_rf_block_.count(block_id) > 0) {
    MarkDiscrete(map_global_rf_block_[block_id]);
  }
}
282+
283+
// Binds the tiled loop nest to the CUDA thread hierarchy for the main
// block and for any rf blocks produced by SplitReduceInner.
//
// Fix vs. original: the initial `auto loops = sch->GetLoops(block_id);`
// was dead (DoBind always re-fetched the loops) and has been removed.
void TileDiscreteReductionTactic::BindCudaInfo(ir::IRSchedule* sch,
                                               const std::string& block_id) {
  // [S(-1), S(32), R(16), R(-1)] =>
  // [S(blockIdx.x), S(threadIdx.x), R(threadIdx.y), R(inner_loop)]
  const auto DoBind = [&](const std::vector<ir::Expr>& loops) {
    sch->Bind(loops[0], "blockIdx.x");
    sch->Bind(loops[1], "threadIdx.x");
    if (context_->config.tile_config.grid_reduce_num > 1) {
      // Grid reduce: the outer reduce tile spans blocks along y.
      sch->Bind(loops[2], "blockIdx.y");
      if (loops.size() > 3) {
        sch->Bind(loops[3], "threadIdx.y");
      }
    } else {
      sch->Bind(loops[2], "threadIdx.y");
    }
  };

  DoBind(sch->GetLoops(block_id));

  if (map_rf_block_.count(block_id) > 0) {
    DoBind(sch->GetLoops(map_rf_block_[block_id]));
  }
  if (map_global_rf_block_.count(block_id) > 0) {
    DoBind(sch->GetLoops(map_global_rf_block_[block_id]));
  }
}
311+
312+
// Factory: creates a fresh TileDiscreteReductionTactic instance for the
// scheduler's tactic pipeline.
std::unique_ptr<ScheduleTactic> CreateTileDiscreteReductionTactic() {
  return std::make_unique<TileDiscreteReductionTactic>();
}
315+
316+
} // namespace ir
317+
} // namespace cinn
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) 2024 CINN Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once

// Fix vs. original: this header declares std::unique_ptr but only
// included <string>; <memory> is required for self-containment.
#include <memory>
#include <string>

#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

namespace cinn {
namespace ir {

// Creates the tactic that applies discrete data tiling to [R..S]
// reduction groups (defined in tile_discrete_reduction_tactic.cc).
std::unique_ptr<ScheduleTactic> CreateTileDiscreteReductionTactic();

}  // namespace ir
}  // namespace cinn

0 commit comments

Comments
 (0)