Skip to content

Commit feac69a

Browse files
JiabinYangliupluswei
authored andcommitted
test=release/1.4, fix hsigmoid dereference nullptr (#16770)
* test=release/1.4, fix hsigmoid dereference nullptr * test=release/1.4, refine condition * test=release/1.4, refine comments
1 parent af53eb6 commit feac69a

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

paddle/fluid/operators/hierarchical_sigmoid_op.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
234234
zero(dev_ctx, w_grad, static_cast<T>(0.0));
235235
bit_code->MulGradWeight(pre_out_grad, w_grad, in);
236236
} else {
237+
PADDLE_ENFORCE(path != nullptr,
238+
"Sparse mode should not be used without custom tree!");
237239
framework::Vector<int64_t> real_rows = PathToRows(*path);
238240
auto* w_grad =
239241
ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));

python/paddle/fluid/layers/nn.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5589,12 +5589,21 @@ def hsigmoid(input,
55895589
raise ValueError(
55905590
"num_classes must not be less than 2 with default tree")
55915591

5592+
if (not is_custom) and (is_sparse):
5593+
print("Sparse mode should not be used without custom tree")
5594+
is_sparse = False
5595+
5596+
if (not is_custom) and ((path_table is not None) or
5597+
(path_code is not None)):
5598+
raise ValueError(
5599+
"only num_classes should be passed without custom tree")
5600+
55925601
if (is_custom) and (path_code is None):
5593-
raise ValueError("path_code should not be None with costum tree")
5602+
raise ValueError("path_code should not be None with custom tree")
55945603
elif (is_custom) and (path_table is None):
5595-
raise ValueError("path_table should not be None with costum tree")
5604+
raise ValueError("path_table should not be None with custom tree")
55965605
elif (is_custom) and (num_classes is None):
5597-
raise ValueError("num_classes should not be None with costum tree")
5606+
raise ValueError("num_classes should not be None with custom tree")
55985607
else:
55995608
pass
56005609

0 commit comments

Comments
 (0)