Skip to content

Commit df92776

Browse files
authored
Merge pull request #7269 from emailweixu/calc_gradient
Calculating gradients for partial graph
2 parents 5f98500 + 6e5eae1 commit df92776

File tree

8 files changed

+302
-45
lines changed

8 files changed

+302
-45
lines changed

paddle/framework/grad_op_desc_maker.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,11 @@ class GradOpDescMakerBase {
8787
auto onames = this->Output(name);
8888
ret_val.reserve(onames.size());
8989
std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val),
90-
GradVarName);
90+
[this](const std::string& fwd_var_name) -> std::string {
91+
auto g_name = GradVarName(fwd_var_name);
92+
(*this->grad_to_var_)[g_name] = fwd_var_name;
93+
return g_name;
94+
});
9195
return ret_val;
9296
}
9397

paddle/framework/op_desc.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ class OpDesc {
129129
}
130130

131131
proto::OpDesc desc_;
132-
// input arg name => output variable names
132+
// input arg name => input variable names
133133
VariableNameMap inputs_;
134134
// output arg name => output variable names
135135
VariableNameMap outputs_;

paddle/operators/norm_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker {
3939
"M = C * H * W");
4040
AddComment(R"DOC(
4141
"Input shape: $(N, C, H, W)$
42-
Sclae shape: $(C, 1)$
42+
Scale shape: $(C, 1)$
4343
Output shape: $(N, C, H, W)$
4444
Where
4545
forward

paddle/operators/norm_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class NormKernel : public framework::OpKernel<T> {
6666
context.GetPlace());
6767
auto tmp = framework::EigenVector<T, Eigen::RowMajor,
6868
Eigen::DenseIndex>::Flatten(tmp_tensor);
69-
// get colsum and sqrt , inverse
69+
// get colsum and sqrt , inverse
7070
auto dim = Eigen::array<int, 1>({{0}});
7171
tmp.device(*place) = x_square_batch_eigen.sum(dim);
7272
tmp.device(*place) = (tmp + epsilon).sqrt().inverse();

0 commit comments

Comments
 (0)