
Commit 613186b (parent: 23b4a3b)
SigureMo and Copilot authored

[Dy2St] Support pass cuda graph state and dispatch key in run program op and move get value name to python side (#73417)

Co-authored-by: Copilot <[email protected]>

File tree: 18 files changed, +453 −395 lines

paddle/fluid/eager/to_static/run_program_op_func.h (26 additions, 27 deletions)
@@ -47,17 +47,22 @@ static bool IsFakeValue(const pir::Value& value) {
   return value.impl() == nullptr || !value.type();
 }
 
+static bool IsFakeValueName(const std::string& name) {
+  return name == paddle::framework::kFakeVarName ||
+         name == paddle::framework::kEmptyVarName;
+}
+
 // Filter params without grads in global block. In this case, we will
 // tag its AutogradMeta with stop_gradient = True to avoid fault from
 // reducer while training on multi-cards.
 static void pir_clear_no_grad_edges(
     const std::vector<paddle::Tensor>& params,
-    const std::vector<pir::Value>& backward_params_grad,
+    const std::vector<std::string>& backward_params_grad_names,
     const pir::Block* backward_block,
     egr::GradNodeBase* grad_node,
     size_t slot_id) {
   for (size_t i = 0; i < params.size(); ++i) {
-    if (IsFakeValue(backward_params_grad[i])) {
+    if (IsFakeValueName(backward_params_grad_names[i])) {
       VLOG(3) << "clear edge of " << params[i].name();
       grad_node->MutableOutputMeta()[slot_id][i].GetMutableEdge().Clear();
     }
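
The new IsFakeValueName helper mirrors the existing IsFakeValue check but operates on variable names rather than pir::Value handles, which is what lets the name lookup move to the Python side. Below is a standalone sketch of its behavior; the sentinel strings are placeholders, since the exact spellings of paddle::framework::kFakeVarName and kEmptyVarName are not shown in this diff.

// Standalone sketch of the name-based fake-value check introduced above.
// The sentinel literals are assumed placeholders, not Paddle's real values.
#include <cassert>
#include <string>

namespace sketch {
const std::string kFakeVarName = "@FAKE@";    // assumed placeholder
const std::string kEmptyVarName = "@EMPTY@";  // assumed placeholder

// Mirrors IsFakeValueName from the diff: a variable is "fake" if its
// name matches either sentinel.
bool IsFakeValueName(const std::string& name) {
  return name == kFakeVarName || name == kEmptyVarName;
}
}  // namespace sketch

int main() {
  assert(sketch::IsFakeValueName("@FAKE@"));
  assert(sketch::IsFakeValueName("@EMPTY@"));
  assert(!sketch::IsFakeValueName("linear_0.w_0"));
  return 0;
}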
@@ -86,10 +91,9 @@ static void clear_unused_out_var_in_backward(
 }
 
 static void pir_clear_unused_out_var_in_backward(
-    const std::vector<pir::Value>& fo,
+    const std::vector<std::string>& out_names,
     const pir::Block* backward_block,
     paddle::framework::Scope* scope) {
-  auto out_names = details::GetNameFromValue(fo);
   std::deque<std::shared_ptr<paddle::memory::Allocation>>* garbages =
       new std::deque<std::shared_ptr<paddle::memory::Allocation>>();
   for (auto out_name : out_names) {
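
With this change the function receives the forward output names directly from an attribute instead of recovering them in C++ via details::GetNameFromValue. The rest of the function is truncated in the displayed hunk; it frees the buffers of outputs the backward block never reads. The following is a simplified, self-contained sketch of that deferred-release pattern, with Scope and Allocation as stand-ins for paddle::framework::Scope and paddle::memory::Allocation.

// Sketch of the cleanup pattern: detach the buffers of unused outputs
// into a garbage list, whose destruction releases them.
#include <deque>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

using Allocation = std::vector<char>;  // stand-in for memory::Allocation

struct Scope {  // stand-in: variable name -> buffer
  std::unordered_map<std::string, std::shared_ptr<Allocation>> vars;
};

void ClearUnusedOutVars(const std::vector<std::string>& out_names,
                        const std::set<std::string>& used_by_backward,
                        Scope* scope) {
  // Deferred-release list, mirroring the `garbages` deque in the diff.
  std::deque<std::shared_ptr<Allocation>> garbages;
  for (const auto& name : out_names) {
    if (used_by_backward.count(name)) continue;  // still needed
    auto it = scope->vars.find(name);
    if (it == scope->vars.end()) continue;
    garbages.emplace_back(std::move(it->second));  // detach the buffer
  }
  // `garbages` going out of scope drops the last references.
}

int main() {
  Scope scope;
  scope.vars["out0"] = std::make_shared<Allocation>(16);
  scope.vars["out1"] = std::make_shared<Allocation>(16);
  ClearUnusedOutVars({"out0", "out1"}, {"out1"}, &scope);
  std::cout << (scope.vars["out0"] == nullptr) << "\n";  // 1: released
}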
@@ -124,13 +128,12 @@ static std::vector<paddle::Tensor> filter_unused_input_var_in_backward(
 
 static std::vector<paddle::Tensor> pir_filter_unused_input_var_in_backward(
     const std::vector<paddle::Tensor>& x,
-    const std::string x_key_name,
     const paddle::framework::AttributeMap& attrs) {
-  auto values =
-      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at(x_key_name));
+  const auto& names =
+      PADDLE_GET_CONST(std::vector<std::string>, attrs.at("bx_names"));
   auto filter_x = std::vector<paddle::Tensor>(x);
   for (size_t i = 0; i < x.size(); i++) {
-    if (values[i].impl() == nullptr) {
+    if (IsFakeValueName(names[i])) {
       auto fake = paddle::Tensor(std::make_shared<phi::DenseTensor>());
       fake.set_name(paddle::framework::kFakeVarName);
       filter_x[i] = fake;
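
pir_filter_unused_input_var_in_backward now reads the fixed "bx_names" string-vector attribute rather than resolving pir::Values through a caller-supplied key, so the x_key_name parameter disappears. A minimal sketch of the resulting name-driven filter follows; Tensor and AttributeMap are simplified stand-ins, with std::any standing in for PADDLE_GET_CONST's variant access.

// Sketch: inputs whose backward-graph name is a fake/empty sentinel are
// replaced with empty placeholders before being captured by the grad node.
#include <any>
#include <string>
#include <unordered_map>
#include <vector>

struct Tensor {  // stand-in for paddle::Tensor
  std::string name;
  bool has_data = true;
};

using AttributeMap = std::unordered_map<std::string, std::any>;  // stand-in

bool IsFakeValueName(const std::string& name) {
  return name == "@FAKE@" || name == "@EMPTY@";  // assumed sentinels
}

std::vector<Tensor> FilterUnusedInputs(const std::vector<Tensor>& x,
                                       const AttributeMap& attrs) {
  // "bx_names" is a plain vector<string> attribute filled on the Python
  // side, so no pir::Value -> name translation happens here.
  const auto& names =
      std::any_cast<const std::vector<std::string>&>(attrs.at("bx_names"));
  std::vector<Tensor> filtered(x);
  for (size_t i = 0; i < x.size(); ++i) {
    if (IsFakeValueName(names[i])) {
      filtered[i] = Tensor{"@FAKE@", /*has_data=*/false};  // empty stand-in
    }
  }
  return filtered;
}

int main() {
  AttributeMap attrs;
  attrs["bx_names"] = std::vector<std::string>{"x0", "@FAKE@"};
  auto out = FilterUnusedInputs({{"x0"}, {"x1"}}, attrs);
  return out[1].has_data ? 1 : 0;  // 0: second input was filtered out
}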
@@ -143,17 +146,17 @@ static std::vector<paddle::Tensor>
 pir_filter_no_need_buffer_input_var_in_backward(
     const std::vector<paddle::Tensor>& x,
     const paddle::framework::AttributeMap& attrs) {
-  auto forward_inputs_values =
-      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fx"));
-  auto no_need_buffers_values =
-      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("no_need_buffers"));
+  const auto& forward_inputs_names =
+      PADDLE_GET_CONST(std::vector<std::string>, attrs.at("fx_names"));
+  const auto& no_need_buffers_names = PADDLE_GET_CONST(
+      std::vector<std::string>, attrs.at("no_need_buffers_names"));
   auto filter_x = std::vector<paddle::Tensor>(x);
   std::deque<std::shared_ptr<paddle::memory::Allocation>>* garbages =
       new std::deque<std::shared_ptr<paddle::memory::Allocation>>();
   for (size_t i = 0; i < x.size(); i++) {
-    if (std::find(no_need_buffers_values.begin(),
-                  no_need_buffers_values.end(),
-                  forward_inputs_values[i]) != no_need_buffers_values.end()) {
+    if (std::find(no_need_buffers_names.begin(),
+                  no_need_buffers_names.end(),
+                  forward_inputs_names[i]) != no_need_buffers_names.end()) {
       auto& tensor = filter_x[i];
       if (tensor.has_allocation() && tensor.is_dense_tensor()) {
         auto copied_dense_tensor = std::make_shared<phi::DenseTensor>(
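
The no-need-buffer test is now a plain membership check of the forward input's name in no_need_buffers_names, done with std::find. Both lists are typically short, so a linear scan is reasonable; the sketch below shows the same check plus a hash-set variant as a possible optimization (an aside, not something this patch does).

// The std::find call from the diff, isolated, next to a set-based
// alternative that turns the O(inputs * list) scan into O(inputs).
#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>

bool NeedsBufferLinear(const std::vector<std::string>& no_need_buffers,
                       const std::string& input_name) {
  // Direct translation of the membership check in the diff.
  return std::find(no_need_buffers.begin(), no_need_buffers.end(),
                   input_name) == no_need_buffers.end();
}

bool NeedsBufferHashed(const std::unordered_set<std::string>& no_need,
                       const std::string& input_name) {
  return no_need.count(input_name) == 0;
}

int main() {
  std::vector<std::string> no_need = {"w", "b"};
  std::unordered_set<std::string> no_need_set(no_need.begin(), no_need.end());
  bool a = NeedsBufferLinear(no_need, "x");      // true: buffer kept
  bool b = NeedsBufferHashed(no_need_set, "w");  // false: buffer droppable
  return (a && !b) ? 0 : 1;
}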
@@ -238,7 +241,7 @@ inline void run_program_ad_func(
     is_test = PADDLE_GET_CONST(bool, attrs.at("is_test"));
   }
   if (!is_test && require_any_grad) {
-    auto x_names =
+    const auto& x_names =
         PADDLE_GET_CONST(std::vector<std::string>, attrs.at("x_names"));
 
     // Create GradOpNode (1 means [out_grad], 2 means [x_grad, paramx_grad])
@@ -304,10 +307,6 @@ inline void pir_run_program_ad_func(
   bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(
       trace_backward, &p_autograd_x, &p_autograd_params);
 
-  // Create Middle Output for GradNode.
-  auto middle_values =
-      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm"));
-
   auto is_test = false;
   if (attrs.count("is_test")) {
     is_test = PADDLE_GET_CONST(bool, attrs.at("is_test"));
@@ -348,19 +347,19 @@ inline void pir_run_program_ad_func(
   // For the first kind, we can create a empty Tensor to replace it.
   // For the second kind, we need to keep the meta only Tensor.
   auto filter_x = pir_filter_no_need_buffer_input_var_in_backward(
-      pir_filter_unused_input_var_in_backward(x_tmp, "bx", attrs), attrs);
+      pir_filter_unused_input_var_in_backward(x_tmp, attrs), attrs);
   // Set TensorWrappers
   grad_node->SetFwdX(filter_x);
 
   std::shared_ptr<::pir::Program> backward_program = PADDLE_GET_CONST(
       std::shared_ptr<::pir::Program>, attrs.at("backward_program"));
-  auto forward_outputs =
-      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fo"));
-  auto backward_params_grad =
-      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bp_g"));
+  const auto& forward_outputs_names =
+      PADDLE_GET_CONST(std::vector<std::string>, attrs.at("fo_names"));
+  const auto& backward_params_grad_names =
+      PADDLE_GET_CONST(std::vector<std::string>, attrs.at("bp_g_names"));
 
   pir_clear_unused_out_var_in_backward(
-      forward_outputs, backward_program->block(), step_scope[0]);
+      forward_outputs_names, backward_program->block(), step_scope[0]);
 
   grad_node->SetFwdParams(params_tmp);
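
The context comment in the hunk above distinguishes two kinds of filtered inputs: genuinely unused ones, which can be replaced by empty Tensors, and no-need-buffer ones, which must keep shape/dtype metadata while dropping the allocation. A simplified sketch of that "meta-only" copy, with a stand-in for phi::DenseTensor:

// Sketch: keep metadata, drop the payload, so the grad node records
// shape/dtype without pinning the forward activation's memory.
#include <cstdint>
#include <memory>
#include <vector>

struct DenseTensor {  // stand-in for phi::DenseTensor
  std::vector<int64_t> dims;                  // metadata: kept
  int dtype = 0;                              // metadata: kept
  std::shared_ptr<std::vector<char>> buffer;  // payload: droppable
};

// Copy meta only, analogous in spirit to the copied_dense_tensor
// construction in the diff.
DenseTensor MetaOnlyCopy(const DenseTensor& src) {
  DenseTensor meta_only;
  meta_only.dims = src.dims;
  meta_only.dtype = src.dtype;
  meta_only.buffer = nullptr;  // no allocation carried into the grad node
  return meta_only;
}

int main() {
  DenseTensor t{{2, 3}, 1, std::make_shared<std::vector<char>>(24)};
  DenseTensor m = MetaOnlyCopy(t);
  return (m.dims == t.dims && m.buffer == nullptr) ? 0 : 1;
}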

@@ -372,7 +371,7 @@ inline void pir_run_program_ad_func(
   // Clear no grad edges
   VLOG(2) << "clear no grad edges.";
   pir_clear_no_grad_edges(params,
-                          backward_params_grad,
+                          backward_params_grad_names,
                           backward_program->block(),
                           grad_node.get(),
                           /*slot id*/ 1);
