// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/tile_discrete_reduction_tactic.h"
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/common/integer_set.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"

namespace cinn {
namespace ir {

using cinn::ir::analyzer::IsReductionSBlock;
using BoundVariableMap = std::unordered_map<std::string, std::vector<Var>>;

bool UseDiscreteDataTile(const ScheduleConfig& config) {
  // Use the discrete data tile only when the iter space contains a reduce
  // axis and ends with a spatial axis, i.e. the [R, S] pattern.
  for (const auto& iter_space : config.base_info->iter_space_type) {
    if (iter_space.first == "R") {
      if (config.base_info->iter_space_type.back().first == "S") {
        return true;
      }
    }
  }
  return false;
}

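// TileDiscreteReductionTactic tiles a group whose iter space follows the
// [R, S] pattern (reduce axes outer in memory, spatial axes innermost) for
// CUDA, roughly as follows:
//   1. Fuse the reduce axes into one loop and the spatial axes into at most
//      two loops.
//   2. Split the innermost spatial loop by 32 and bind the inner part to
//      threadIdx.x, so adjacent threads cover adjacent spatial elements.
//   3. Split the reduce loop into (grid_reduce_num x 16) lanes bound to
//      blockIdx.y / threadIdx.y, factorizing the reduction accordingly.
// The final nest is roughly [S(blockIdx.x), S(32, threadIdx.x),
// R(blockIdx.y), R(16, threadIdx.y), R(inner)].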
class TileDiscreteReductionTactic final : public ScheduleTactic {
 public:
  void Init(ScheduleContext* context, ir::IRSchedule* sch) override;
  void Apply(ir::IRSchedule* sch, const std::string& block_id) override;
  std::string TacticName() const override {
    return "TileDiscreteReductionTactic";
  }

 private:
  void MergeDiscreteFlattenAxis(ir::IRSchedule* sch,
                                const std::string& block_id);
  void MergeReduceAxis(ir::IRSchedule* sch, const std::string& block_id);
  void SplitSpatialInner(ir::IRSchedule* sch, const std::string& block_id);
  void SplitReduceInner(ir::IRSchedule* sch, const std::string& block_id);
  void VariableTypeAssignment(ir::IRSchedule* sch, const std::string& block_id);
  void SetDiscreteReduceType(ir::IRSchedule* sch, const std::string& block_id);
  void BindCudaInfo(ir::IRSchedule* sch, const std::string& block_id);

 private:
  ScheduleContext* context_;
  bool can_apply_;
  std::vector<int32_t> vec_spatial_axis_first_;
  std::vector<int32_t> vec_spatial_axis_last_;
  std::vector<int32_t> vec_flatten_axis_;
  std::vector<int32_t> vec_reduce_axis_;
  // Maps a reduction block to its rfactor block (and, for grid-level
  // reduction, to the global rfactor block).
  std::unordered_map<std::string, std::string> map_rf_block_;
  std::unordered_map<std::string, std::string> map_global_rf_block_;
};

void TileDiscreteReductionTactic::Init(ScheduleContext* context,
                                       ir::IRSchedule* sch) {
  context_ = context;
  can_apply_ = false;

  // Check whether this group has been tiled by a previous tactic.
  ir::Expr module_root = sch->GetModule().GetExprs().front();
  ir::Expr root_block = ir::analyzer::GetRootSBlock(module_root);
  auto* root_node = root_block.As<ir::ScheduleBlockRealize>()
                        ->schedule_block.As<ir::ScheduleBlock>();
  if (root_node->attrs.count(kTileMethod) > 0) {
    return;
  }
  if (!UseDiscreteDataTile(context_->config)) {
    return;
  }
  can_apply_ = true;
  root_node->attrs[kTileMethod] = TacticName();

  // Reduce axes have been reordered to the end of the loop nest.
  vec_flatten_axis_.clear();
  vec_reduce_axis_.clear();
  int data_rank = context_->config.base_info->loop_ranges.size();
  int32_t reduce_start_idx =
      data_rank - context_->config.base_info->reduce_axis.size();
  for (int32_t i = 0; i < data_rank; ++i) {
    if (i >= reduce_start_idx) {
      vec_reduce_axis_.push_back(i);
    } else {
      vec_flatten_axis_.push_back(i);
    }
  }
  vec_spatial_axis_first_.clear();
  vec_spatial_axis_last_.clear();

  if (!context_->config.base_info->reduce_axis.empty()) {
    // Partition the spatial axes around the reduce axis with the largest
    // loop stride, i.e. the leftmost reduce axis in memory order.
    int64_t first_reduce_axis = context_->config.base_info->reduce_axis.front();
    for (auto axis : context_->config.base_info->reduce_axis) {
      if (context_->config.base_info->loop_strides[axis] >
          context_->config.base_info->loop_strides[first_reduce_axis]) {
        first_reduce_axis = axis;
      }
    }
    for (int32_t i = 0; i < reduce_start_idx; ++i) {
      if (i < first_reduce_axis) {
        vec_spatial_axis_first_.push_back(i);
      } else {
        vec_spatial_axis_last_.push_back(i);
      }
    }
  }

  map_rf_block_.clear();
}
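
// Illustrative example (hypothetical shapes): given loop_ranges
// [S0(8), S1(70), R0(16), R1(8)] with reduce_axis = {2, 3}, Init yields
//   vec_flatten_axis_ = {0, 1}, vec_reduce_axis_ = {2, 3};
// and, assuming axis 2 has the largest loop stride, first_reduce_axis = 2,
//   vec_spatial_axis_first_ = {0, 1}, vec_spatial_axis_last_ = {}.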

void TileDiscreteReductionTactic::Apply(ir::IRSchedule* sch,
                                        const std::string& block_id) {
  if (!can_apply_) return;
  if (ir::IsReduceInitTensorName(block_id)) return;

  MergeReduceAxis(sch, block_id);
  VLOG(6) << "After MergeReduceAxis on block: [" << block_id
          << "], loop nest:\n"
          << sch->GetLoops(block_id)[0];
  MergeDiscreteFlattenAxis(sch, block_id);
  VLOG(6) << "After MergeDiscreteFlattenAxis on block: [" << block_id
          << "], loop nest:\n"
          << sch->GetLoops(block_id)[0];
  SplitSpatialInner(sch, block_id);
  VLOG(6) << "After SplitSpatialInner on block: [" << block_id
          << "], loop nest:\n"
          << sch->GetLoops(block_id)[0];
  SplitReduceInner(sch, block_id);
  VLOG(6) << "After SplitReduceInner on block: [" << block_id
          << "], loop nest:\n"
          << sch->GetLoops(block_id)[0];
  BindCudaInfo(sch, block_id);
  VLOG(6) << "After BindCudaInfo on block: [" << block_id << "], loop nest:\n"
          << sch->GetLoops(block_id)[0];
  VariableTypeAssignment(sch, block_id);
  VLOG(6) << "After VariableTypeAssignment on block: [" << block_id
          << "], loop nest:\n"
          << sch->GetLoops(block_id)[0];
  SetDiscreteReduceType(sch, block_id);
}
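
// Continuing the hypothetical example: MergeReduceAxis fuses axes {2, 3}
// into R(128); MergeDiscreteFlattenAxis fuses axes {0, 1} into S(560);
// SplitSpatialInner then gives [S(ceil(560/32) = 18), S(32), R(128)].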

void TileDiscreteReductionTactic::MergeDiscreteFlattenAxis(
    ir::IRSchedule* sch, const std::string& block_id) {
  // Note: We need to fuse loops from bottom to top, because loop indices
  // change after the upper loops are fused.
  if (vec_spatial_axis_last_.size() >= 2) {
    sch->Fuse(block_id, vec_spatial_axis_last_);
  }
  if (vec_spatial_axis_first_.size() >= 2) {
    sch->Fuse(block_id, vec_spatial_axis_first_);
  }
}

void TileDiscreteReductionTactic::MergeReduceAxis(ir::IRSchedule* sch,
                                                  const std::string& block_id) {
  std::vector<ir::Expr> loops = sch->GetLoops(block_id);
  int32_t max_loop_idx = 0;
  for (int32_t idx : vec_reduce_axis_) {
    max_loop_idx = std::max(max_loop_idx, idx);
    PADDLE_ENFORCE_EQ(
        idx < static_cast<int32_t>(loops.size()) || loops.size() == 1,
        true,
        ::common::errors::InvalidArgument(
            "The reduce axis should meet: axis's idx < loops.size() or "
            "loops.size() == 1, but received idx = %d, loops.size() = %d",
            idx,
            loops.size()));
  }
  // Skip if the loops have already been fully fused into a single loop.
  if (max_loop_idx < static_cast<int32_t>(loops.size()) &&
      vec_reduce_axis_.size() >= 2) {
    sch->Fuse(block_id, vec_reduce_axis_);
  }
}

void TileDiscreteReductionTactic::SplitSpatialInner(
    ir::IRSchedule* sch, const std::string& block_id) {
  auto loops = sch->GetLoops(block_id);
  if (loops.size() == 3) {
    // [S, S', R] => [S, S'(-1), S'(32), R]
    sch->Split(loops[1], std::vector<int>({-1, 32}));
    // [S, S'(-1), S'(32), R] => [S, S'(32), R]
    sch->Fuse(block_id, std::vector<int>{0, 1});
  } else if (loops.size() == 2) {
    // [S, R] => [S(-1), S(32), R]
    sch->Split(loops[0], std::vector<int>({-1, 32}));
  }
}
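
// Loop-nest evolution inside SplitReduceInner for the same example, with
// rd_block = 2 (grid_reduce_num) and rd_thread = 16:
//   [S(18), S(32), R(128)]
//     split R by rd_block*rd_thread => [S(18), S(32), R(4), R(32)]
//     reorder                       => [S(18), S(32), R(32), R(4)]
//     split R(32) into [2, 16]      => [S(18), S(32), R(2), R(16), R(4)]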

void TileDiscreteReductionTactic::SplitReduceInner(
    ir::IRSchedule* sch, const std::string& block_id) {
  const int64_t rd_block = context_->config.tile_config.grid_reduce_num;
  const int64_t rd_thread = 16;
  const int cur_reduce_axis = 2;

  // [ R ] => [ rd_block*rd_thread, rd_inner ]
  auto loops = sch->GetLoops(block_id);
  sch->Split(loops[cur_reduce_axis],
             std::vector<int>{-1, rd_block * rd_thread});
  loops = sch->GetLoops(block_id);
  sch->Reorder({loops[cur_reduce_axis + 1], loops[cur_reduce_axis]});

  // Factorize the reduction so that each thread accumulates a partial
  // result into a local rfactor tensor.
  loops = sch->GetLoops(block_id);
  if (IsReductionSBlock(sch->GetBlock(block_id)) &&
      ir::GetLoopExtent(loops[2]) != 1) {
    ir::Expr rf_tensor =
        sch->FactorizeReduction(loops[cur_reduce_axis],
                                /* rf_axis = */ 0,
                                /* with_write_back_block_init = */ false);
    map_rf_block_[block_id] = rf_tensor.as_tensor_ref()->name;
  }

  // [ rd_block*rd_thread ] => [ rd_block, rd_thread ]
  if (rd_block > 1) {
    loops = sch->GetLoops(block_id);
    sch->Split(loops[cur_reduce_axis], {rd_block, rd_thread});

    if (IsReductionSBlock(sch->GetBlock(block_id))) {
      loops = sch->GetLoops(map_rf_block_[block_id]);
      sch->Split(loops[cur_reduce_axis], {rd_block, rd_thread});

      // Factorize once more across rd_block; the resulting rfactor tensor
      // lives in global memory so partial results can cross thread blocks.
      loops = sch->GetLoops(block_id);
      ir::Expr rf_tensor =
          sch->FactorizeReduction(loops[cur_reduce_axis],
                                  /* rf_axis = */ 0,
                                  /* with_write_back_block_init = */ false);
      std::string tensor_name = rf_tensor.as_tensor_ref()->name;
      map_global_rf_block_[block_id] = tensor_name;
      rf_tensor.as_tensor_ref()->WithBuffer("global", "_" + tensor_name);
    }
  }
}
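
// After SplitReduceInner there are up to three reduction stages: per-thread
// partials (map_rf_block_, kept in local memory below), per-block partials
// (map_global_rf_block_, staged through global memory when grid_reduce_num
// > 1), and the final write-back into the reduction output.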

void TileDiscreteReductionTactic::VariableTypeAssignment(
    ir::IRSchedule* sch, const std::string& block_id) {
  const auto IsOutputTensor = [&](const std::string& tensor_name) -> bool {
    return context_->output_names.count(tensor_name) > 0;
  };
  const auto HasConsumers = [&](const ir::Expr& block) -> bool {
    return !ir::analyzer::GetConsumerSBlocks(block, sch->GetRootBlock(block))
                .empty();
  };

  // Tensors that are consumed only inside the kernel can stay in local
  // (register) memory.
  auto block = sch->GetBlock(block_id);
  if (!IsOutputTensor(block_id) && HasConsumers(block)) {
    sch->SetBuffer(block, "local", false);
  }

  if (map_rf_block_.count(block_id) > 0) {
    auto block = sch->GetBlock(map_rf_block_[block_id]);
    sch->SetBuffer(block, "local", false);
  }
}
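
// Note on the choice below: reduce lanes are bound to threadIdx.y while
// threadIdx.x carries spatial elements, so a cross-thread reduction spans
// different warps; DiscreteReduceMethod is expected to make the backend
// lower it through shared memory rather than warp shuffles.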

void TileDiscreteReductionTactic::SetDiscreteReduceType(
    ir::IRSchedule* sch, const std::string& block_id) {
  if (IsReductionSBlock(sch->GetBlock(block_id))) {
    auto block = sch->GetBlock(block_id)
                     .As<ir::ScheduleBlockRealize>()
                     ->schedule_block.As<ir::ScheduleBlock>();
    block->reduce_method = cinn::ir::DiscreteReduceMethod();
  }
  if (map_global_rf_block_.count(block_id) > 0) {
    auto block = sch->GetBlock(map_global_rf_block_[block_id])
                     .As<ir::ScheduleBlockRealize>()
                     ->schedule_block.As<ir::ScheduleBlock>();
    block->reduce_method = cinn::ir::DiscreteReduceMethod();
  }
}
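
// With the shapes used in the example above, each thread block runs
// 32 x 16 = 512 threads: threadIdx.x indexes 32 spatial elements and
// threadIdx.y indexes 16 reduce lanes; blockIdx.y spans the grid-level
// reduction when grid_reduce_num > 1.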

void TileDiscreteReductionTactic::BindCudaInfo(ir::IRSchedule* sch,
                                               const std::string& block_id) {
  // [S(-1), S(32), R(16), R(-1)] =>
  // [S(blockIdx.x), S(threadIdx.x), R(threadIdx.y), R(inner_loop)],
  // or, when grid reduce is enabled,
  // [S(-1), S(32), R(rd_block), R(16), R(-1)] =>
  // [S(blockIdx.x), S(threadIdx.x), R(blockIdx.y), R(threadIdx.y),
  //  R(inner_loop)]
  const auto DoBind = [&](const std::vector<ir::Expr>& loops) {
    sch->Bind(loops[0], "blockIdx.x");
    sch->Bind(loops[1], "threadIdx.x");
    if (context_->config.tile_config.grid_reduce_num > 1) {
      sch->Bind(loops[2], "blockIdx.y");
      if (loops.size() > 3) {
        sch->Bind(loops[3], "threadIdx.y");
      }
    } else {
      sch->Bind(loops[2], "threadIdx.y");
    }
  };

  DoBind(sch->GetLoops(block_id));

  if (map_rf_block_.count(block_id) > 0) {
    DoBind(sch->GetLoops(map_rf_block_[block_id]));
  }
  if (map_global_rf_block_.count(block_id) > 0) {
    DoBind(sch->GetLoops(map_global_rf_block_[block_id]));
  }
}

std::unique_ptr<ScheduleTactic> CreateTileDiscreteReductionTactic() {
  return std::make_unique<TileDiscreteReductionTactic>();
}

}  // namespace ir
}  // namespace cinn