
Commit 67dc159

zyfncg and SigureMo authored
Add cuda_graph op and pass (#73393)
* add cuda_graph op and pass
* add cuda_graph instruction
* [Dy2St] Support pass cuda graph state and dispatch key in run program op
* prepare all needed names at run program beginning
* fix bug
* mark override
* pass names to run program op
* save code
* add missing op_function_common.cc
* fix bug
* optimize code
* fix compile bug
* fix compile bug
* refine code
* add hip define
* fix add cuda_graph_dispatch_key to scope cache key
* add more log info
* delete useless include

---------

Co-authored-by: SigureMo <[email protected]>
1 parent 6a5494b commit 67dc159

File tree

15 files changed: +666 -3 lines changed

paddle/common/flags.cc

Lines changed: 13 additions & 0 deletions

@@ -1195,6 +1195,19 @@ PHI_DEFINE_EXPORTED_bool(
     "cudaGraphInstantiateFlagAutoFreeOnLaunch so it would automatically "
     "release graph-owned blocks that have not freed before relaunching.");
 
+/*
+ * CUDA Graph related FLAG
+ * Name: FLAGS_cuda_graph_blacklist
+ * Since Version: 3.1
+ * Value Range: string, default=""
+ * Example: FLAGS_cuda_graph_blacklist="op1,op2,op3" would
+ * blacklist op1, op2, op3 from being captured in CUDA Graph.
+ */
+PHI_DEFINE_EXPORTED_string(
+    cuda_graph_blacklist,
+    "",
+    "CUDA Graph blacklist, split by ',', e.g., 'op1,op2,op3'");
+
 /*
  * Executor related FLAG
  * Name: FLAGS_executor_log_deps_every_microseconds
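
For context, the flag's value is a plain comma-separated list of op names. A minimal illustrative sketch (not the commit's actual parsing code; ParseBlacklist is a hypothetical helper) of turning such a value into a lookup set:

#include <iostream>
#include <sstream>
#include <string>
#include <unordered_set>

// Hypothetical helper (illustration only): split the documented
// comma-separated flag value into a set for O(1) blacklist lookups.
std::unordered_set<std::string> ParseBlacklist(const std::string& flag_value) {
  std::unordered_set<std::string> ops;
  std::stringstream ss(flag_value);
  std::string item;
  while (std::getline(ss, item, ',')) {
    if (!item.empty()) ops.insert(item);
  }
  return ops;
}

int main() {
  const auto blacklist = ParseBlacklist("op1,op2,op3");
  std::cout << (blacklist.count("op2") ? "op2 is blacklisted"
                                       : "op2 is allowed")
            << std::endl;
  return 0;
}

Like other exported PHI flags, this one can typically be set from the environment, e.g. export FLAGS_cuda_graph_blacklist="op1,op2,op3".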

paddle/fluid/eager/to_static/run_program_op_node.h

Lines changed: 20 additions & 0 deletions

@@ -25,6 +25,7 @@
 #include "paddle/fluid/ir_adaptor/translator/program_translator.h"
 #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
 #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
+#include "paddle/fluid/pir/transforms/cuda_graph_extract_pass.h"
 #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
 #include "paddle/fluid/pir/utils/name_analysis.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -35,6 +36,7 @@
 #include "paddle/pir/include/core/builtin_attribute.h"
 #include "paddle/pir/include/core/program.h"
 #include "paddle/pir/include/core/value.h"
+#include "paddle/pir/include/pass/pass_manager.h"
 
 #ifdef PADDLE_WITH_DNNL
 #include "paddle/fluid/platform/onednn_helper.h"
@@ -494,6 +496,18 @@ inline void PirRunProgramAPI(
       }
     }
   }
+
+  auto program = forward_program;
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  if (details::is_use_cuda_graph(cuda_graph_state)) {
+    pir::PassManager pass_pm(::pir::IrContext::Instance(), 3);
+    pass_pm.AddPass(pir::CreateCudaGraphExtractPass());
+    pir::IrMapping ir_mapping;
+    program = forward_program->Clone(ir_mapping);
+    pass_pm.Run(program.get());
+  }
+#endif
+
   auto passed_kernel_program = paddle::framework::ApplyIrPass(
       forward_program.get(), place, no_need_buffer_name_set);
   const auto &new_block = passed_kernel_program->block();
@@ -505,6 +519,9 @@
         global_inner_scope,
         cache_key,
         in_sot_mode);
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    interpreter_core->SetCUDAGraphState(static_cast<uint8_t>(cuda_graph_state));
+#endif
    // Step 4. get all eager gc vars (skip_names = backward_inputs -
    // no_need_buffers + outputs)
    std::vector<std::string> skip_names;
@@ -529,6 +546,9 @@
    // Step 1. get cache interpretercore
    auto &cached_value = cache.GetMutable(cache_key);
    interpreter_core = cached_value.core_;
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    interpreter_core->SetCUDAGraphState(static_cast<uint8_t>(cuda_graph_state));
+#endif
    // Step 2. update scope for cache interpretercore
    details::ShareTensorsIntoScopeWithName(x, input_names, global_inner_scope);
    details::ShareTensorsIntoScopeWithName(
paddle/fluid/framework/new_executor/instruction/cuda_graph_instruction.cc

Lines changed: 240 additions & 0 deletions

@@ -0,0 +1,240 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/new_executor/instruction/cuda_graph_instruction.h"
+
+#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
+#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h"
+#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h"
+#include "paddle/fluid/framework/new_executor/pir_interpreter.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
+#include "paddle/phi/core/platform/collective_helper.h"
+#include "paddle/phi/core/platform/cuda_graph_with_memory_pool.h"
+#include "paddle/phi/core/platform/device_context.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/core/type_defs.h"
+
+#include "paddle/pir/include/core/builtin_attribute.h"
+#include "paddle/pir/include/core/operation.h"
+#include "paddle/pir/include/core/value.h"
+
+#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
+#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+
+namespace paddle::framework {
+
+CudaGraphInstruction::CudaGraphInstruction(
+    size_t id,
+    const phi::Place& place,
+    pir::Operation* op,
+    uint8_t* cuda_graph_state_ref,
+    int64_t cuda_graph_capture_pool_id,
+    ValueExecutionInfo* value_exec_info,
+    interpreter::ExecutionConfig execution_config)
+    : InstructionBase(id, place),
+      op_(op),
+      place_(place),
+      cuda_graph_state_ref_(cuda_graph_state_ref),
+      cuda_graph_capture_pool_id_(cuda_graph_capture_pool_id),
+      name_("cuda_graph_instruction"),
+      input_vars_(),
+      output_vars_(),
+      interpreter_(nullptr),
+      skip_gc_names_() {
+  PADDLE_ENFORCE(op->isa<paddle::dialect::CudaGraphOp>(),
+                 common::errors::PreconditionNotMet(
+                     "CudaGraph instruction only supports cuda_graph op"));
+  op_ = op;
+
+  SetKernelType(OpFuncType::kGpuAsync);
+  VLOG(6) << "finish process analyse kernel type";
+
+  auto cuda_graph_op = op->dyn_cast<paddle::dialect::CudaGraphOp>();
+
+  std::unordered_map<pir::Value, std::vector<int>> inputs;
+  GetInputIds(op, *value_exec_info, &inputs);
+  const auto outside_inputs =
+      GetExternalInputs(cuda_graph_op.block(), *value_exec_info, &inputs);
+  for (size_t i = 0; i < outside_inputs.size(); ++i) {
+    input_vars_.push_back(value_exec_info->GetScope()->GetVar(
+        value_exec_info->GetValue2VarName().at(outside_inputs.at(i))));
+  }
+  VLOG(6) << "finish process input_vars";
+
+  for (size_t i = 0; i < cuda_graph_op.num_results(); ++i) {
+    output_vars_.push_back(value_exec_info->GetScope()->GetVar(
+        value_exec_info->GetValue2VarName().at(cuda_graph_op.result(i))));
+  }
+  VLOG(6) << "finish process output_vars";
+
+  for (auto& item : inputs) {
+    auto& var_vec = item.second;
+    for (auto it = var_vec.begin(); it != var_vec.end();) {
+      if (*it == -1) {
+        it = var_vec.erase(it);
+      } else {
+        ++it;
+      }
+    }
+  }
+  SetInputs(inputs);
+
+  std::unordered_map<pir::Value, std::vector<int>> outputs;
+  bool is_last_op = true;
+  for (size_t i = 0; i < op->num_results(); i++) {
+    pir::Value value = op->result(i);
+    if (value && value.type()) {
+      PADDLE_ENFORCE_EQ(
+          value_exec_info->HasValue(value),
+          true,
+          common::errors::PreconditionNotMet(
+              "input should be in name map, [%d] 'th input of [%s] op",
+              i,
+              "cuda_graph op"));
+      outputs.emplace(value, GetValueIds(value, *value_exec_info));
+    }
+    if (value.use_count() > 0) {
+      VLOG(6) << "value " << i << " use count != 0";
+      is_last_op = false;
+    }
+  }
+
+  InsertInplacedExternalInputsToOuts(
+      cuda_graph_op.block(), outside_inputs, *value_exec_info, &outputs);
+
+  for (auto& item : outputs) {
+    auto& var_vec = item.second;
+    for (auto it = var_vec.begin(); it != var_vec.end();) {
+      if (*it == -1) {
+        it = var_vec.erase(it);
+      } else {
+        ++it;
+      }
+    }
+  }
+  SetOutputs(outputs);
+  VLOG(6) << "finish process inputs outputs index";
+
+  Scope* scope = &(value_exec_info->GetScope()->NewScope());
+  auto skip_gc_vars = execution_config.skip_gc_vars;
+  execution_config.skip_gc_vars.clear();
+  execution_config.create_local_scope = true;
+  interpreter_ = new PirInterpreter(place,
+                                    {},
+                                    cuda_graph_op.block(),
+                                    scope,
+                                    value_exec_info->NewChild(scope),
+                                    execution_config);
+
+  std::set<std::string> skip_gc_names_set;
+  for (auto value : outside_inputs) {
+    skip_gc_names_.push_back(interpreter_->GetNameByValue(value));
+    skip_gc_names_set.insert(interpreter_->GetNameByValue(value));
+  }
+  for (const auto& var_name : skip_gc_vars) {
+    skip_gc_names_.push_back(var_name);
+    skip_gc_names_set.insert(var_name);
+  }
+  interpreter_->SetSkipGcVars(skip_gc_names_set);
+  VLOG(6) << "finish process interpreter";
+}
+
+CudaGraphInstruction::~CudaGraphInstruction() { delete interpreter_; }
+
+void CudaGraphInstruction::SetOutputHooks(
+    const std::vector<PirHookFunc>& hookfuncs) {
+  interpreter_->SetOutputHooks(hookfuncs);
+}
+
+void CudaGraphInstruction::SetInputHooks(
+    const std::vector<PirHookFunc>& hookfuncs) {
+  interpreter_->SetInputHooks(hookfuncs);
+}
+
+void CudaGraphInstruction::Run() {
+  if (cuda_graph_ != nullptr && *cuda_graph_state_ref_ == 3) {
+    VLOG(4) << "Start replaying cuda graph @" << cuda_graph_.get();
+    for (size_t i = 0; i < input_vars_.size(); ++i) {
+      if (input_vars_[i]->IsType<phi::DenseTensor>()) {
+        auto* tensor = input_vars_[i]->GetMutable<phi::DenseTensor>();
+        if (tensor->data() != input_tensors_.at(i).data()) {
+          LOG(WARNING) << "The input [" << i << "] tensor addr for "
+                       << "cuda graph is changed. Pay attention to this!";
+          if (phi::is_gpu_place(tensor->place())) {
+            const auto* dev_ctx =
+                phi::DeviceContextPool::Instance().Get(place_);
+            phi::Copy(*dev_ctx, *tensor, place_, false, &input_tensors_.at(i));
+          }
+        }
+      }
+    }
+
+    cuda_graph_->Replay();
+
+    // set the output tensors into scope
+    for (size_t i = 0; i < output_vars_.size(); ++i) {
+      *(output_vars_[i]->GetMutable<phi::DenseTensor>()) =
+          output_tensors_.at(i);
+    }
+    VLOG(4) << "Finish replaying cuda graph";
+    return;
+  }
+  if (*cuda_graph_state_ref_ == 2 && cuda_graph_ == nullptr) {
+    VLOG(4) << "Warmup before capturing";
+    interpreter_->Run({}, false);
+    VLOG(4) << "Start capturing cuda graph ...";
+    platform::BeginCUDAGraphCapture(
+        place_, cudaStreamCaptureModeRelaxed, cuda_graph_capture_pool_id_);
+
+    auto RecordTensorsForReplay = [&](const std::vector<Variable*>& vars) {
+      std::vector<phi::DenseTensor> record_tensors;
+      record_tensors.reserve(vars.size());
+      for (auto& var : vars) {
+        auto& tensor = var->Get<phi::DenseTensor>();
+        const auto& holder = tensor.Holder();
+        // Note: new_holder only records the memory address of the tensor for
+        // cuda graph; the original tensor memory will be freed to the
+        // allocator after graph capture.
+        auto new_holder = std::make_shared<phi::Allocation>(
+            holder->ptr(), holder->size(), holder->place());
+        record_tensors.emplace_back(new_holder, tensor.meta());
+      }
+      return record_tensors;
+    };
+
+    // record the input tensors for replay
+    input_tensors_ = RecordTensorsForReplay(input_vars_);
+
+    interpreter_->Run({}, false);
+
+    // record the output tensors for replay
+    output_tensors_ = RecordTensorsForReplay(output_vars_);
+
+    cuda_graph_ = platform::EndCUDAGraphCapture();
+    VLOG(4) << "Finish capturing cuda graph @" << cuda_graph_.get();
+
+    // compute the right result
+    cuda_graph_->Replay();
+  } else {
+    VLOG(4) << "Run interpreter without cuda graph";
+    interpreter_->Run({}, false);
+  }
+}
+
+}  // namespace paddle::framework
+
+#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
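
For context, platform::BeginCUDAGraphCapture / EndCUDAGraphCapture and CUDAGraph::Replay above are Paddle wrappers over CUDA's stream-capture workflow. A minimal standalone sketch of that raw workflow, assuming the pre-CUDA-12 five-argument cudaGraphInstantiate signature, with error handling elided:

#include <cuda_runtime.h>

__global__ void scale(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= 2.0f;
}

int main() {
  const int n = 1 << 20;
  float* x;
  cudaMalloc(&x, n * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Capture: kernels launched on the stream are recorded, not executed.
  cudaGraph_t graph;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
  scale<<<(n + 255) / 256, 256, 0, stream>>>(x, n);
  cudaStreamEndCapture(stream, &graph);

  // Instantiate once, then replay cheaply. Buffer addresses baked into the
  // graph must stay stable between replays.
  // (CUDA 12 uses a three-argument cudaGraphInstantiate overload instead.)
  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);
  for (int step = 0; step < 10; ++step) {
    cudaGraphLaunch(exec, stream);  // replay
  }
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaFree(x);
  cudaStreamDestroy(stream);
  return 0;
}

Replay is cheap because launch work is paid once at instantiation, but every buffer address baked into the graph must stay valid across replays; this is why the instruction records the captured tensors' holders and copies fresh inputs into the recorded addresses whenever an input's address changes.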
paddle/fluid/framework/new_executor/instruction/cuda_graph_instruction.h

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+
+#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
+#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
+#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
+#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
+
+namespace ir {
+class Operation;
+}  // namespace ir
+
+namespace paddle {
+namespace framework {
+class Scope;
+class Value;
+class PirInterpreter;
+class ValueExecutionInfo;
+
+class CudaGraphInstruction : public InstructionBase {
+ public:
+  CudaGraphInstruction(size_t id,
+                       const phi::Place& place,
+                       ::pir::Operation* op,
+                       uint8_t* cuda_graph_state_ref,
+                       int64_t cuda_graph_capture_pool_id,
+                       ValueExecutionInfo* value_exe_info,
+                       interpreter::ExecutionConfig execution_config);
+
+  ~CudaGraphInstruction();
+
+  void Run() override;
+
+  const std::string& Name() const override { return name_; }
+
+  ::pir::Operation* Operation() const override { return op_; }
+
+  PirInterpreter* interpreter() const { return interpreter_; }
+
+  void SetOutputHooks(const std::vector<PirHookFunc>& hookfuncs);
+
+  void SetInputHooks(const std::vector<PirHookFunc>& hookfuncs);
+
+ private:
+  const phi::Place& place_;
+  pir::Operation* op_;
+  uint8_t* cuda_graph_state_ref_ = nullptr;
+  int64_t cuda_graph_capture_pool_id_ = -1;
+
+  std::string name_{"cuda_graph_instruction"};
+
+  std::vector<Variable*> input_vars_;
+  std::vector<Variable*> output_vars_;
+
+  PirInterpreter* interpreter_ = nullptr;
+
+  std::vector<std::string> skip_gc_names_;
+
+  std::unique_ptr<phi::backends::gpu::CUDAGraph> cuda_graph_ = nullptr;
+  std::vector<phi::DenseTensor> input_tensors_;
+  std::vector<phi::DenseTensor> output_tensors_;
+};
+
+}  // namespace framework
+}  // namespace paddle
+
+#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
