// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/new_executor/instruction/cuda_graph_instruction.h"

#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h"
#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h"
#include "paddle/fluid/framework/new_executor/pir_interpreter.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/phi/core/platform/collective_helper.h"
#include "paddle/phi/core/platform/cuda_graph_with_memory_pool.h"
#include "paddle/phi/core/platform/device_context.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/type_defs.h"

#include "paddle/pir/include/core/builtin_attribute.h"
#include "paddle/pir/include/core/operation.h"
#include "paddle/pir/include/core/value.h"

#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

namespace paddle::framework {

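// CudaGraphInstruction runs the body block of a cuda_graph op with its own
// PirInterpreter. Depending on the shared cuda graph state it either runs the
// block eagerly, captures one execution into a CUDA graph, or replays the
// previously captured graph (see Run() below).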
CudaGraphInstruction::CudaGraphInstruction(
    size_t id,
    const phi::Place& place,
    pir::Operation* op,
    uint8_t* cuda_graph_state_ref,
    int64_t cuda_graph_capture_pool_id,
    ValueExecutionInfo* value_exec_info,
    interpreter::ExecutionConfig execution_config)
    : InstructionBase(id, place),
      op_(op),
      place_(place),
      cuda_graph_state_ref_(cuda_graph_state_ref),
      cuda_graph_capture_pool_id_(cuda_graph_capture_pool_id),
      name_("cuda_graph_instruction"),
      input_vars_(),
      output_vars_(),
      interpreter_(nullptr),
      skip_gc_names_() {
  PADDLE_ENFORCE(op->isa<paddle::dialect::CudaGraphOp>(),
                 common::errors::PreconditionNotMet(
                     "CudaGraphInstruction only supports the cuda_graph op"));

  SetKernelType(OpFuncType::kGpuAsync);
  VLOG(6) << "finish analysing kernel type";

  auto cuda_graph_op = op->dyn_cast<paddle::dialect::CudaGraphOp>();

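  // Inputs are the cuda_graph op's own operands plus any values defined
  // outside the body block but referenced inside it.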
  std::unordered_map<pir::Value, std::vector<int>> inputs;
  GetInputIds(op, *value_exec_info, &inputs);
  const auto outside_inputs =
      GetExternalInputs(cuda_graph_op.block(), *value_exec_info, &inputs);
  for (size_t i = 0; i < outside_inputs.size(); ++i) {
    input_vars_.push_back(value_exec_info->GetScope()->GetVar(
        value_exec_info->GetValue2VarName().at(outside_inputs.at(i))));
  }
  VLOG(6) << "finish processing input_vars";

  for (size_t i = 0; i < cuda_graph_op.num_results(); ++i) {
    output_vars_.push_back(value_exec_info->GetScope()->GetVar(
        value_exec_info->GetValue2VarName().at(cuda_graph_op.result(i))));
  }
  VLOG(6) << "finish processing output_vars";

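  // Drop the sentinel id -1 from every input id list before registering the
  // inputs on this instruction.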
  for (auto& item : inputs) {
    auto& var_vec = item.second;
    for (auto it = var_vec.begin(); it != var_vec.end();) {
      if (*it == -1) {
        it = var_vec.erase(it);
      } else {
        ++it;
      }
    }
  }
  SetInputs(inputs);

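  // Register the op's results (and, below, in-place-updated external inputs)
  // as outputs of this instruction.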
  std::unordered_map<pir::Value, std::vector<int>> outputs;
  bool is_last_op = true;
  for (size_t i = 0; i < op->num_results(); i++) {
    pir::Value value = op->result(i);
    if (value && value.type()) {
      PADDLE_ENFORCE_EQ(
          value_exec_info->HasValue(value),
          true,
          common::errors::PreconditionNotMet(
              "output should be in name map, [%d]'th output of cuda_graph op",
              i));
      outputs.emplace(value, GetValueIds(value, *value_exec_info));
    }
    if (value.use_count() > 0) {
      VLOG(6) << "value " << i << " use count != 0";
      is_last_op = false;
    }
  }

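  // External inputs that are modified in place inside the body block must
  // also be recorded as outputs of this instruction.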
  InsertInplacedExternalInputsToOuts(
      cuda_graph_op.block(), outside_inputs, *value_exec_info, &outputs);

  for (auto& item : outputs) {
    auto& var_vec = item.second;
    for (auto it = var_vec.begin(); it != var_vec.end();) {
      if (*it == -1) {
        it = var_vec.erase(it);
      } else {
        ++it;
      }
    }
  }
  SetOutputs(outputs);
  VLOG(6) << "finish processing input/output indices";

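  // The body block runs in its own child scope with a dedicated interpreter.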
  Scope* scope = &(value_exec_info->GetScope()->NewScope());
  auto skip_gc_vars = execution_config.skip_gc_vars;
  execution_config.skip_gc_vars.clear();
  execution_config.create_local_scope = true;
  interpreter_ = new PirInterpreter(place,
                                    {},
                                    cuda_graph_op.block(),
                                    scope,
                                    value_exec_info->NewChild(scope),
                                    execution_config);

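  // Variables shared with the outer scope (and any externally requested
  // skip-gc vars) must not be garbage-collected by the inner interpreter.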
  std::set<std::string> skip_gc_names_set;
  for (auto value : outside_inputs) {
    const auto& var_name = interpreter_->GetNameByValue(value);
    skip_gc_names_.push_back(var_name);
    skip_gc_names_set.insert(var_name);
  }
  for (const auto& var_name : skip_gc_vars) {
    skip_gc_names_.push_back(var_name);
    skip_gc_names_set.insert(var_name);
  }
  interpreter_->SetSkipGcVars(skip_gc_names_set);
  VLOG(6) << "finish constructing interpreter";
}

CudaGraphInstruction::~CudaGraphInstruction() { delete interpreter_; }

void CudaGraphInstruction::SetOutputHooks(
    const std::vector<PirHookFunc>& hookfuncs) {
  interpreter_->SetOutputHooks(hookfuncs);
}

void CudaGraphInstruction::SetInputHooks(
    const std::vector<PirHookFunc>& hookfuncs) {
  interpreter_->SetInputHooks(hookfuncs);
}

void CudaGraphInstruction::Run() {
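  // The shared cuda graph state drives this instruction: 2 requests capture,
  // 3 requests replay of the already captured graph, and any other value
  // falls back to eager execution of the body block.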
  if (cuda_graph_ != nullptr && *cuda_graph_state_ref_ == 3) {
    VLOG(4) << "Start replaying cuda graph @" << cuda_graph_.get();
    for (size_t i = 0; i < input_vars_.size(); ++i) {
      if (input_vars_[i]->IsType<phi::DenseTensor>()) {
        auto* tensor = input_vars_[i]->GetMutable<phi::DenseTensor>();
        if (tensor->data() != input_tensors_.at(i).data()) {
          LOG(WARNING) << "The address of input [" << i << "] tensor for the "
                       << "cuda graph has changed since capture.";
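          // Copy the new input into the buffer that was recorded at capture
          // time so the replayed graph reads up-to-date data.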
          if (phi::is_gpu_place(tensor->place())) {
            const auto* dev_ctx =
                phi::DeviceContextPool::Instance().Get(place_);
            phi::Copy(*dev_ctx, *tensor, place_, false, &input_tensors_.at(i));
          }
        }
      }
    }

    cuda_graph_->Replay();

    // set the output tensors into scope
    for (size_t i = 0; i < output_vars_.size(); ++i) {
      *(output_vars_[i]->GetMutable<phi::DenseTensor>()) =
          output_tensors_.at(i);
    }
    VLOG(4) << "Finish replaying cuda graph";
    return;
  }
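  // Capture phase: no graph exists yet, so warm the block up once, then
  // capture a second run into a CUDA graph.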
  if (*cuda_graph_state_ref_ == 2 && cuda_graph_ == nullptr) {
    VLOG(4) << "Warmup before capturing";
    interpreter_->Run({}, false);
    VLOG(4) << "Start capturing cuda graph ...";
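    // Relaxed capture mode does not forbid potentially unsafe CUDA API calls
    // from other threads while the capture is in progress.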
    platform::BeginCUDAGraphCapture(
        place_, cudaStreamCaptureModeRelaxed, cuda_graph_capture_pool_id_);

    auto RecordTensorsForReplay = [&](const std::vector<Variable*>& vars) {
      std::vector<phi::DenseTensor> record_tensors;
      record_tensors.reserve(vars.size());
      for (auto& var : vars) {
        auto& tensor = var->Get<phi::DenseTensor>();
        const auto& holder = tensor.Holder();
        // Note: new_holder only records the memory address of the tensor for
        // the cuda graph; the original tensor memory will be freed back to
        // the allocator after graph capture.
        auto new_holder = std::make_shared<phi::Allocation>(
            holder->ptr(), holder->size(), holder->place());
        record_tensors.emplace_back(new_holder, tensor.meta());
      }
      return record_tensors;
    };

    // record the input tensors for replay
    input_tensors_ = RecordTensorsForReplay(input_vars_);

    interpreter_->Run({}, false);

    // record the output tensors for replay
    output_tensors_ = RecordTensorsForReplay(output_vars_);

    cuda_graph_ = platform::EndCUDAGraphCapture();
    VLOG(4) << "Finish capturing cuda graph @" << cuda_graph_.get();

    // Capture only records the work; replay once now to actually compute
    // this step's outputs.
    cuda_graph_->Replay();
  } else {
    VLOG(4) << "Run interpreter without cuda graph";
    interpreter_->Run({}, false);
  }
}

}  // namespace paddle::framework

#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP