Commit 12e1719

Merge pull request #14352 from JiabinYang/enhance_hierachical_sigmod_op
Enhance hierarchical sigmoid op
2 parents: 36e26a5 + eda0690

File tree: 9 files changed, +765 -131 lines

paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=
 paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
 paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False))
-paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
+paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False))
 paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
 paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
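
A minimal usage sketch (not part of this commit) of the updated hsigmoid signature; the argument names and defaults come from the ArgSpec above, while the data shapes and num_classes value are illustrative assumptions:

    import paddle.fluid as fluid

    # Default (non-custom) hierarchical sigmoid: only the original arguments
    # are needed; the new path_table/path_code/is_custom/is_sparse arguments
    # keep their defaults (None, None, False, False).
    x = fluid.layers.data(name="x", shape=[128], dtype="float32")
    y = fluid.layers.data(name="y", shape=[1], dtype="int64")
    cost = fluid.layers.hsigmoid(input=x, label=y, num_classes=10)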

paddle/fluid/framework/selected_rows.h

Lines changed: 18 additions & 3 deletions
@@ -120,8 +120,22 @@ class SelectedRows {
    */
   int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);

-  void SyncIndex();
+  /*
+   * @brief Get the index of the key from id_to_index_ map.
+   */
+  inline int64_t GetIndexFromId(int64_t key) {
+    auto iter = id_to_index_.find(key);
+    if (iter == id_to_index_.end()) {
+      return -1;
+    } else {
+      return iter->second;
+    }
+  }

+  void SyncIndex();
+  /*
+   * @brief Get complete Dims before
+   */
   DDim GetCompleteDims() const {
     std::vector<int64_t> dims = vectorize(value_->dims());
     dims[0] = height_;
@@ -133,9 +147,10 @@ class SelectedRows {
   // SelectedRows are simply concated when adding together. Until a
   // SelectedRows add a Tensor, will the duplicate rows be handled.
   Vector<int64_t> rows_;
-  std::unordered_map<int64_t, int64_t> id_to_index_;
+  std::unordered_map<int64_t, int64_t>
+      id_to_index_;  // should not be used when rows_ has duplicate member
   std::unique_ptr<Tensor> value_{nullptr};
-  int64_t height_;
+  int64_t height_;  // height indicates the underline tensor's height
   std::unique_ptr<RWLock> rwlock_{nullptr};
 };


paddle/fluid/operators/hierarchical_sigmoid_op.cc

Lines changed: 89 additions & 21 deletions
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
+#include <string>
 #include <vector>
-
 namespace paddle {
 namespace operators {

@@ -70,13 +70,14 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
     const int64_t batch_size = ctx->GetInputDim("X")[0];
     std::vector<int64_t> output_shape({batch_size, 1});
     ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
+    ctx->ShareLoD("X", /*->*/ "Out");
   }

  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
         ctx.GetPlace());
   }
 };
@@ -86,27 +87,40 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "(Tensor, required) The input tensor with shape [N, D], "
+             "(LoDTensor, required) The input tensor with shape [N, D], "
              "where N is the size of mini-batch, and D is the feature size.");
     AddInput("W",
-             "(Tensor, required), The parameters of hierarchical "
+             "(LoDTensor, required), The parameters of hierarchical "
              "sigmoid operator, each of them is a 2-D tensor, the shape is"
-             "[num_classes - 1, D].");
+             "[K, D]. Which K is the num of non-leaf node in Path Tree");
     AddInput("Label",
-             "(Tensor, required), The labels of training data. It's a"
+             "(LoDTensor, required), The labels of training data. It's a"
              "tensor with shape [N, 1].");
+    AddInput("PTable",
+             "(LoDTensor, optional), The Path Table from root to current word"
+             "it should have shape like [N, L], L is the length of the Path")
+        .AsDispensable();
+    AddInput(
+        "PathCode",
+        "(LoDTensor, optional), The Code on each Node of the Path from root "
+        "to current word"
+        "it should have shape like [N, L], L is the length of the Path")
+        .AsDispensable();
     AddInput("Bias",
-             "(Tensor, optional), The bias is a tensor with shape"
-             "[1, num_classes - 1].");
-    AddOutput("Out",
-              "(Tensor, required) The output of hierarchical sigmoid operator."
-              "The shape is [N, 1].");
+             "(LoDTensor, optional), The bias is a tensor with shape or "
+             "[num_classes, 1]"
+             "[num_classes - 1, 1].")
+        .AsDispensable();
+    AddOutput(
+        "Out",
+        "(LoDTensor, required) The output of hierarchical sigmoid operator."
+        "The shape is [N, 1].");
     AddOutput("PreOut",
-              "(Tensor, required) A intermedia 2-D tensor with shape "
+              "(LoDTensor, required) A intermedia 2-D tensor with shape "
              "[batch_size, code_length], where code_length represents the "
              "maximum path length from root to leaf nodes.")
         .AsIntermediate();
-    AddAttr<AttrType>("num_classes", "(int, required), The number of classes")
+    AddAttr<AttrType>("num_classes", "(int, optional), The number of classes")
         .SetDefault(2);
     AddComment(R"DOC(
 The hierarchical sigmoid operator organize the classes into a binary tree.
@@ -115,6 +129,10 @@ belonging to the right branch. This idea is from
 "F. Morin, Y. Bengio (AISTATS 05):
 Hierarchical Probabilistic Neural Network Language Model."
 )DOC");
+    AddAttr<bool>("is_sparse",
+                  "(boolean, default false) "
+                  "Sparse update.")
+        .SetDefault(false);
   }
 };

@@ -124,36 +142,86 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@Grad) should not be null");
     PADDLE_ENFORCE(ctx->HasInput("PreOut"),
                    "Input(Preout) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
-                   "Output(W@Grad should not be null.)");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")));
-    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
-      ctx->SetOutputDim(framework::GradVarName("Bias"),
-                        ctx->GetInputDim("Bias"));
+                   "Output(W@Grad should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@Grad should not be null.");
+    if (!ctx->Attrs().Get<bool>("is_sparse")) {
+      if (ctx->HasOutput(framework::GradVarName("Bias"))) {
+        ctx->SetOutputDim(framework::GradVarName("Bias"),
+                          ctx->GetInputDim("Bias"));
+      }
+      ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
     }
-    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }

  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
         ctx.GetPlace());
   }
 };

+class HierarchicalSigmoidGradOpGradVarTypeInference
+    : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto w_grad_var_name = op_desc.Output(framework::GradVarName("W")).front();
+    auto bias_grad_var_name_vec =
+        op_desc.Output(framework::GradVarName("Bias"));
+    std::string bias_grad_var_name;
+    bool hasBias = false;
+    if (bias_grad_var_name_vec.size()) {
+      hasBias = true;
+      bias_grad_var_name =
+          op_desc.Output(framework::GradVarName("Bias")).front();
+    }
+    auto attr = op_desc.GetAttr("is_sparse");
+    bool is_sparse = boost::get<bool>(attr);
+    if (is_sparse) {
+      VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
+               << " is set to SelectedRows";
+      block->Var(w_grad_var_name)
+          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      if (hasBias) {
+        VLOG(30) << "hierarchical_sigmoid_grad op "
+                 << framework::GradVarName("Bias") << " is set to SelectedRows";
+        block->Var(bias_grad_var_name)
+            ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      }
+    } else {
+      VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
+               << " is set to LoDTensor";
+      block->Var(w_grad_var_name)
+          ->SetType(framework::proto::VarType::LOD_TENSOR);
+      if (hasBias) {
+        VLOG(30) << "hierarchical_sigmoid_grad op "
+                 << framework::GradVarName("Bias") << " is set to LoDTensor";
+        block->Var(bias_grad_var_name)
+            ->SetType(framework::proto::VarType::LOD_TENSOR);
+      }
+    }
+    block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType());
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
 REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
                   ops::HierarchicalSigmoidOpMaker<int>,
                   paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
+REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp,
+                  ops::HierarchicalSigmoidGradOpGradVarTypeInference);
 REGISTER_OP_CPU_KERNEL(
     hierarchical_sigmoid,
     ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
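
To illustrate the new attributes end to end, a hedged Python sketch follows (assumptions beyond this diff: the data shapes, the path length L, and the num_classes value used with a custom tree). With is_sparse=True, the grad-var-type inference above sets W@GRAD (and Bias@GRAD, when present) to SelectedRows instead of LoDTensor:

    import paddle.fluid as fluid

    # Custom binary tree: each sample carries its own root-to-leaf path.
    # L is an assumed maximum path length; PTable/PathCode are [N, L] per the
    # op maker comments above.
    L = 3
    emb = fluid.layers.data(name="emb", shape=[128], dtype="float32")
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
    path_table = fluid.layers.data(name="path_table", shape=[L], dtype="int64")
    path_code = fluid.layers.data(name="path_code", shape=[L], dtype="int64")

    cost = fluid.layers.hsigmoid(
        input=emb,
        label=label,
        num_classes=6,  # assumption: sized to the custom tree's nodes
        path_table=path_table,
        path_code=path_code,
        is_custom=True,
        is_sparse=True)  # W@GRAD becomes SelectedRows (see inference above)

    avg_cost = fluid.layers.mean(cost)
    fluid.optimizer.SGD(learning_rate=0.01).minimize(avg_cost)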
