@@ -47,17 +47,22 @@ static bool IsFakeValue(const pir::Value& value) {
   return value.impl() == nullptr || !value.type();
 }
 
+static bool IsFakeValueName(const std::string& name) {
+  return name == paddle::framework::kFakeVarName ||
+         name == paddle::framework::kEmptyVarName;
+}
+
 // Filter params without grads in global block. In this case, we will
 // tag its AutogradMeta with stop_gradient = True to avoid fault from
 // reducer while training on multi-cards.
 static void pir_clear_no_grad_edges(
     const std::vector<paddle::Tensor>& params,
-    const std::vector<pir::Value>& backward_params_grad,
+    const std::vector<std::string>& backward_params_grad_names,
     const pir::Block* backward_block,
     egr::GradNodeBase* grad_node,
     size_t slot_id) {
   for (size_t i = 0; i < params.size(); ++i) {
-    if (IsFakeValue(backward_params_grad[i])) {
+    if (IsFakeValueName(backward_params_grad_names[i])) {
       VLOG(3) << "clear edge of " << params[i].name();
       grad_node->MutableOutputMeta()[slot_id][i].GetMutableEdge().Clear();
     }
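
For orientation, here is a minimal standalone sketch of the name-based check this hunk introduces. It is not Paddle code: the constant string values below are assumptions, since the real kFakeVarName / kEmptyVarName are defined in paddle::framework and are not part of this diff.

// Minimal sketch, assumptions only: constant values are placeholders for the
// real paddle::framework::kFakeVarName / kEmptyVarName.
#include <iostream>
#include <string>
#include <vector>

constexpr char kFakeVarName[] = "@FAKE_VAR@";   // assumed value
constexpr char kEmptyVarName[] = "@EMPTY@";     // assumed value

static bool IsFakeValueName(const std::string& name) {
  return name == kFakeVarName || name == kEmptyVarName;
}

int main() {
  // A grad whose recorded name is fake/empty marks a param with no real
  // gradient, so its edge would be cleared (as in pir_clear_no_grad_edges).
  std::vector<std::string> grad_names = {"linear_0.w_0@GRAD", kFakeVarName};
  for (size_t i = 0; i < grad_names.size(); ++i) {
    std::cout << "param " << i << ": "
              << (IsFakeValueName(grad_names[i]) ? "clear edge" : "keep edge")
              << "\n";
  }
}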
@@ -86,10 +91,9 @@ static void clear_unused_out_var_in_backward(
 }
 
 static void pir_clear_unused_out_var_in_backward(
-    const std::vector<pir::Value>& fo,
+    const std::vector<std::string>& out_names,
     const pir::Block* backward_block,
     paddle::framework::Scope* scope) {
-  auto out_names = details::GetNameFromValue(fo);
   std::deque<std::shared_ptr<paddle::memory::Allocation>>* garbages =
       new std::deque<std::shared_ptr<paddle::memory::Allocation>>();
   for (auto out_name : out_names) {
@@ -124,13 +128,12 @@ static std::vector<paddle::Tensor> filter_unused_input_var_in_backward(
 
 static std::vector<paddle::Tensor> pir_filter_unused_input_var_in_backward(
     const std::vector<paddle::Tensor>& x,
-    const std::string x_key_name,
     const paddle::framework::AttributeMap& attrs) {
-  auto values =
-      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at(x_key_name));
+  const auto& names =
+      PADDLE_GET_CONST(std::vector<std::string>, attrs.at("bx_names"));
   auto filter_x = std::vector<paddle::Tensor>(x);
   for (size_t i = 0; i < x.size(); i++) {
-    if (values[i].impl() == nullptr) {
+    if (IsFakeValueName(names[i])) {
       auto fake = paddle::Tensor(std::make_shared<phi::DenseTensor>());
       fake.set_name(paddle::framework::kFakeVarName);
       filter_x[i] = fake;
@@ -143,17 +146,17 @@ static std::vector<paddle::Tensor>
 pir_filter_no_need_buffer_input_var_in_backward(
     const std::vector<paddle::Tensor>& x,
     const paddle::framework::AttributeMap& attrs) {
-  auto forward_inputs_values =
-      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fx"));
-  auto no_need_buffers_values =
-      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("no_need_buffers"));
+  const auto& forward_inputs_names =
+      PADDLE_GET_CONST(std::vector<std::string>, attrs.at("fx_names"));
+  const auto& no_need_buffers_names = PADDLE_GET_CONST(
+      std::vector<std::string>, attrs.at("no_need_buffers_names"));
   auto filter_x = std::vector<paddle::Tensor>(x);
   std::deque<std::shared_ptr<paddle::memory::Allocation>>* garbages =
       new std::deque<std::shared_ptr<paddle::memory::Allocation>>();
   for (size_t i = 0; i < x.size(); i++) {
-    if (std::find(no_need_buffers_values.begin(),
-                  no_need_buffers_values.end(),
-                  forward_inputs_values[i]) != no_need_buffers_values.end()) {
+    if (std::find(no_need_buffers_names.begin(),
+                  no_need_buffers_names.end(),
+                  forward_inputs_names[i]) != no_need_buffers_names.end()) {
       auto& tensor = filter_x[i];
       if (tensor.has_allocation() && tensor.is_dense_tensor()) {
         auto copied_dense_tensor = std::make_shared<phi::DenseTensor>(
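
As a rough illustrative sketch of the lookup-and-filter pattern the two filter functions above now share (names are read from the attribute map, and inputs with fake/empty names are swapped for placeholders): AttrMap, Tensor, FilterUnusedInputs, and the constant values below are stand-ins for paddle::framework::AttributeMap, paddle::Tensor, PADDLE_GET_CONST, and the real framework constants; none of this is Paddle's API.

// Illustrative only; all types, keys, and constant values are stand-ins.
#include <any>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Tensor {
  std::string name;
  bool is_placeholder = false;
};

using AttrMap = std::map<std::string, std::any>;

static bool IsFakeValueName(const std::string& name) {
  return name == "@FAKE_VAR@" || name == "@EMPTY@";  // assumed values
}

// Inputs whose recorded backward names are fake/empty become placeholders,
// mirroring pir_filter_unused_input_var_in_backward above.
std::vector<Tensor> FilterUnusedInputs(const std::vector<Tensor>& x,
                                       const AttrMap& attrs) {
  const auto& names =
      std::any_cast<const std::vector<std::string>&>(attrs.at("bx_names"));
  auto filtered = x;
  for (size_t i = 0; i < x.size(); ++i) {
    if (IsFakeValueName(names[i])) {
      filtered[i] = Tensor{"@FAKE_VAR@", /*is_placeholder=*/true};
    }
  }
  return filtered;
}

int main() {
  AttrMap attrs;
  attrs["bx_names"] = std::vector<std::string>{"x0", "@FAKE_VAR@"};
  for (const auto& t : FilterUnusedInputs({{"x0"}, {"x1"}}, attrs)) {
    std::cout << t.name << (t.is_placeholder ? " (placeholder)" : "") << "\n";
  }
}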
@@ -238,7 +241,7 @@ inline void run_program_ad_func(
     is_test = PADDLE_GET_CONST(bool, attrs.at("is_test"));
   }
   if (!is_test && require_any_grad) {
-    auto x_names =
+    const auto& x_names =
         PADDLE_GET_CONST(std::vector<std::string>, attrs.at("x_names"));
 
     // Create GradOpNode (1 means [out_grad], 2 means [x_grad, paramx_grad])
@@ -304,10 +307,6 @@ inline void pir_run_program_ad_func(
   bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(
       trace_backward, &p_autograd_x, &p_autograd_params);
 
-  // Create Middle Output for GradNode.
-  auto middle_values =
-      PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm"));
-
   auto is_test = false;
   if (attrs.count("is_test")) {
     is_test = PADDLE_GET_CONST(bool, attrs.at("is_test"));
@@ -348,19 +347,19 @@ inline void pir_run_program_ad_func(
     // For the first kind, we can create a empty Tensor to replace it.
     // For the second kind, we need to keep the meta only Tensor.
    auto filter_x = pir_filter_no_need_buffer_input_var_in_backward(
-        pir_filter_unused_input_var_in_backward(x_tmp, "bx", attrs), attrs);
+        pir_filter_unused_input_var_in_backward(x_tmp, attrs), attrs);
     // Set TensorWrappers
     grad_node->SetFwdX(filter_x);
 
     std::shared_ptr<::pir::Program> backward_program = PADDLE_GET_CONST(
         std::shared_ptr<::pir::Program>, attrs.at("backward_program"));
-    auto forward_outputs =
-        PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fo"));
-    auto backward_params_grad =
-        PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bp_g"));
+    const auto& forward_outputs_names =
+        PADDLE_GET_CONST(std::vector<std::string>, attrs.at("fo_names"));
+    const auto& backward_params_grad_names =
+        PADDLE_GET_CONST(std::vector<std::string>, attrs.at("bp_g_names"));
 
     pir_clear_unused_out_var_in_backward(
-        forward_outputs, backward_program->block(), step_scope[0]);
+        forward_outputs_names, backward_program->block(), step_scope[0]);
 
     grad_node->SetFwdParams(params_tmp);
 
@@ -372,7 +371,7 @@ inline void pir_run_program_ad_func(
     // Clear no grad edges
     VLOG(2) << "clear no grad edges.";
     pir_clear_no_grad_edges(params,
-                            backward_params_grad,
+                            backward_params_grad_names,
                             backward_program->block(),
                             grad_node.get(),
                             /*slot id*/ 1);
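
Since the AD function now expects plain name vectors under keys like "fo_names" and "bp_g_names", the caller presumably converts its pir::Value lists into names once before building the attribute map. That call site is not shown in this diff, so the following is only an assumption about the shape of that step; Value, GetNames, and the map type are stand-ins (the old code used details::GetNameFromValue for the same purpose).

// Assumption-only sketch of the producer side; not part of this patch.
#include <any>
#include <map>
#include <string>
#include <vector>

struct Value {            // stand-in for pir::Value
  std::string name;
};

static std::vector<std::string> GetNames(const std::vector<Value>& values) {
  std::vector<std::string> names;
  names.reserve(values.size());
  for (const auto& v : values) names.push_back(v.name);
  return names;
}

int main() {
  // Stand-in for paddle::framework::AttributeMap.
  std::map<std::string, std::any> attrs;
  std::vector<Value> forward_outputs = {{"out_0"}, {"out_1"}};
  std::vector<Value> backward_params_grad = {{"w_0@GRAD"}, {"@FAKE_VAR@"}};

  // Store names instead of pir::Values, matching the keys read in the diff above.
  attrs["fo_names"] = GetNames(forward_outputs);
  attrs["bp_g_names"] = GetNames(backward_params_grad);
}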