Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit e3699c0

Browse files
piiswrongcjolivier01
authored andcommitted
fix makenonlossgrad bug (#8508)
1 parent ca3d56f commit e3699c0

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

src/operator/operator_common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,8 @@ inline std::vector<nnvm::NodeEntry> MakeZeroGradNodes(
412412

413413
// check whether all output grads are zero.
414414
inline bool CheckGradAllZero(const std::vector<nnvm::NodeEntry>& ograds) {
415-
const auto zero_op = nnvm::Op::Get("_zeros");
416-
const auto zero_like_op = nnvm::Op::Get("zeros_like");
415+
static const auto zero_op = nnvm::Op::Get("_zeros");
416+
static const auto zero_like_op = nnvm::Op::Get("zeros_like");
417417
if (!ograds.size()) return false;
418418
for (const auto& grad : ograds) {
419419
if (!grad.node) return false;

src/operator/tensor/broadcast_reduce_op_index.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,9 @@ Examples::
154154
.set_attr<FCompute>("FCompute<cpu>", PickOpForward<cpu>)
155155
.set_attr<nnvm::FGradient>("FGradient",
156156
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
157-
auto ret = MakeNonlossGradNode("_backward_pick", n, ograds,
158-
{n->inputs[1]}, n->attrs.dict);
157+
if (CheckGradAllZero(ograds)) return MakeZeroGradNodes(n, ograds);
158+
auto ret = MakeGradNode("_backward_pick", n, {ograds[0], n->inputs[1]},
159+
n->attrs.dict);
159160
auto p = MakeNode("zeros_like", n->attrs.name + "_index_backward",
160161
{n->inputs[1]}, nullptr, &n);
161162
ret.emplace_back(nnvm::NodeEntry{p, 0, 0});

src/operator/tensor/elemwise_unary_op_basic.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,9 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs)
241241
.set_attr<nnvm::FGradient>(
242242
"FGradient", [](const nnvm::NodePtr& n,
243243
const std::vector<nnvm::NodeEntry>& ograds) {
244-
auto lhs = MakeNonlossGradNode(
245-
"_backward_copy", n, ograds, {},
246-
std::unordered_map<std::string, std::string>());
244+
if (CheckGradAllZero(ograds)) return MakeZeroGradNodes(n, ograds);
245+
auto lhs = MakeGradNode("_backward_copy", n, ograds,
246+
std::unordered_map<std::string, std::string>());
247247
auto ng = MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
248248
{n->inputs[1]}, nullptr, &n);
249249
lhs.push_back(nnvm::NodeEntry{ng, 0, 0});
@@ -284,9 +284,9 @@ NNVM_REGISTER_OP(reshape_like)
284284
.set_attr<nnvm::FGradient>(
285285
"FGradient", [](const nnvm::NodePtr& n,
286286
const std::vector<nnvm::NodeEntry>& ograds) {
287-
auto lhs = MakeNonlossGradNode(
288-
"_backward_copy", n, ograds, {},
289-
std::unordered_map<std::string, std::string>());
287+
if (CheckGradAllZero(ograds)) return MakeZeroGradNodes(n, ograds);
288+
auto lhs = MakeGradNode("_backward_copy", n, ograds,
289+
std::unordered_map<std::string, std::string>());
290290
auto ng = MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
291291
{n->inputs[1]}, nullptr, &n);
292292
lhs.push_back(nnvm::NodeEntry{ng, 0, 0});

0 commit comments

Comments
 (0)