Skip to content

Commit 0318f47

Browse files
reyoung authored and QiJune committed
Enhance in backward (#5262)
Set gradient's data type based on its forward variable
1 parent 1363ddb commit 0318f47

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

paddle/framework/backward.cc

Lines changed: 27 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
#include <deque>
1919
#include <list>
2020
#include <memory>
21+
#include <unordered_set>
2122

2223
#include "paddle/framework/block_desc.h"
2324
#include "paddle/framework/op_registry.h"
@@ -285,6 +286,15 @@ static bool AllGradInSet(const std::vector<std::string>& names,
285286
return true;
286287
}
287288

289+
// Maps a gradient variable name to the name of its forward variable by
// keeping only the text that precedes the first "@GRAD" marker.
//
// Returns an empty string when `grad_name` contains no "@GRAD" marker.
static std::string FwdName(const std::string& grad_name) {
  const std::string::size_type marker_pos = grad_name.find("@GRAD");
  if (marker_pos != std::string::npos) {
    // Everything before the marker is the forward variable's name.
    return grad_name.substr(0, marker_pos);
  }
  return "";
}
297+
288298
static void CreateGradVarInBlock(
289299
size_t grad_op_start_index,
290300
const std::unordered_map<std::string, std::string>& param_name_map,
@@ -294,15 +304,15 @@ static void CreateGradVarInBlock(
294304
for (size_t op_index = grad_op_start_index; op_index < ops.size();
295305
++op_index) {
296306
bool need_infer_shape = false;
307+
std::unordered_set<std::string> new_vars;
297308
ForEachVarName(ops[op_index]->Outputs(),
298309
[&](const std::string& grad_var_name) {
299310
if (block_desc->HasVar(grad_var_name)) {
300311
return false;
301312
}
302313
need_infer_shape = true;
303314
auto var = block_desc->Var(grad_var_name);
304-
// FIXME(qiao) infer the datatype
305-
var->SetDataType(framework::DataType::FP32);
315+
new_vars.insert(var->Name());
306316
auto it = param_name_map.find(grad_var_name);
307317
if (it == param_name_map.end()) {
308318
return false;
@@ -316,6 +326,21 @@ static void CreateGradVarInBlock(
316326
});
317327
if (need_infer_shape) {
318328
ops[op_index]->InferVarType(block_desc);
329+
for (auto& arg : ops[op_index]->OutputArgumentNames()) {
330+
if (new_vars.find(arg) == new_vars.end()) {
331+
continue;
332+
}
333+
auto pname = FwdName(arg);
334+
auto* param = block_desc->FindVar(pname);
335+
auto* grad = block_desc->FindVar(arg);
336+
if (param == nullptr) {
337+
LOG(WARNING) << "Cannot find forward variable of " << arg
338+
<< ". Set its gradient to FP32";
339+
grad->SetDataType(DataType::FP32);
340+
} else {
341+
grad->SetDataType(param->GetDataType());
342+
}
343+
}
319344
ops[op_index]->InferShape(*block_desc);
320345
}
321346
}

0 commit comments

Comments
 (0)