18
18
#include < deque>
19
19
#include < list>
20
20
#include < memory>
21
+ #include < unordered_set>
21
22
22
23
#include " paddle/framework/block_desc.h"
23
24
#include " paddle/framework/op_registry.h"
@@ -285,6 +286,15 @@ static bool AllGradInSet(const std::vector<std::string>& names,
285
286
return true ;
286
287
}
287
288
289
+ static std::string FwdName (const std::string& grad_name) {
290
+ auto pos = grad_name.find (" @GRAD" );
291
+ if (pos == std::string::npos) {
292
+ return " " ;
293
+ } else {
294
+ return grad_name.substr (0 , pos);
295
+ }
296
+ }
297
+
288
298
static void CreateGradVarInBlock (
289
299
size_t grad_op_start_index,
290
300
const std::unordered_map<std::string, std::string>& param_name_map,
@@ -294,15 +304,15 @@ static void CreateGradVarInBlock(
294
304
for (size_t op_index = grad_op_start_index; op_index < ops.size ();
295
305
++op_index) {
296
306
bool need_infer_shape = false ;
307
+ std::unordered_set<std::string> new_vars;
297
308
ForEachVarName (ops[op_index]->Outputs (),
298
309
[&](const std::string& grad_var_name) {
299
310
if (block_desc->HasVar (grad_var_name)) {
300
311
return false ;
301
312
}
302
313
need_infer_shape = true ;
303
314
auto var = block_desc->Var (grad_var_name);
304
- // FIXME(qiao) infer the datatype
305
- var->SetDataType (framework::DataType::FP32);
315
+ new_vars.insert (var->Name ());
306
316
auto it = param_name_map.find (grad_var_name);
307
317
if (it == param_name_map.end ()) {
308
318
return false ;
@@ -316,6 +326,21 @@ static void CreateGradVarInBlock(
316
326
});
317
327
if (need_infer_shape) {
318
328
ops[op_index]->InferVarType (block_desc);
329
+ for (auto & arg : ops[op_index]->OutputArgumentNames ()) {
330
+ if (new_vars.find (arg) == new_vars.end ()) {
331
+ continue ;
332
+ }
333
+ auto pname = FwdName (arg);
334
+ auto * param = block_desc->FindVar (pname);
335
+ auto * grad = block_desc->FindVar (arg);
336
+ if (param == nullptr ) {
337
+ LOG (WARNING) << " Cannot find forward variable of " << arg
338
+ << " . Set its gradient to FP32" ;
339
+ grad->SetDataType (DataType::FP32);
340
+ } else {
341
+ grad->SetDataType (param->GetDataType ());
342
+ }
343
+ }
319
344
ops[op_index]->InferShape (*block_desc);
320
345
}
321
346
}
0 commit comments