@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
+#include <string>
 #include <vector>
-
 namespace paddle {
 namespace operators {

@@ -70,13 +70,14 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
     const int64_t batch_size = ctx->GetInputDim("X")[0];
     std::vector<int64_t> output_shape({batch_size, 1});
     ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
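+    // Out shares X's LoD so the sequence layout is preserved downstream.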
+    ctx->ShareLoD("X", /*->*/ "Out");
   }

  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
         ctx.GetPlace());
   }
 };
@@ -86,27 +87,40 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "(Tensor, required) The input tensor with shape [N, D], "
+             "(LoDTensor, required) The input tensor with shape [N, D], "
              "where N is the size of mini-batch, and D is the feature size.");
     AddInput("W",
-             "(Tensor, required), The parameters of hierarchical "
+             "(LoDTensor, required), The parameters of hierarchical "
              "sigmoid operator, each of them is a 2-D tensor, the shape is "
-             "[num_classes - 1, D].");
+             "[K, D], where K is the number of non-leaf nodes in the path tree.");
     AddInput("Label",
-             "(Tensor, required), The labels of training data. It's a "
+             "(LoDTensor, required), The labels of training data. It's a "
              "tensor with shape [N, 1].");
+    AddInput("PTable",
+             "(LoDTensor, optional), The path table from the root to the current "
+             "word; it should have shape [N, L], where L is the length of the path.")
+        .AsDispensable();
+    AddInput(
+        "PathCode",
+        "(LoDTensor, optional), The code on each node of the path from the "
+        "root to the current word; "
+        "it should have shape [N, L], where L is the length of the path.")
+        .AsDispensable();
     AddInput("Bias",
-             "(Tensor, optional), The bias is a tensor with shape "
-             "[1, num_classes - 1].");
-    AddOutput("Out",
-              "(Tensor, required) The output of hierarchical sigmoid operator. "
-              "The shape is [N, 1].");
+             "(LoDTensor, optional), The bias is a tensor with shape "
+             "[num_classes, 1] or "
+             "[num_classes - 1, 1].")
+        .AsDispensable();
+    AddOutput(
+        "Out",
+        "(LoDTensor, required) The output of hierarchical sigmoid operator. "
+        "The shape is [N, 1].");
     AddOutput("PreOut",
-              "(Tensor, required) A intermedia 2-D tensor with shape "
+              "(LoDTensor, required) An intermediate 2-D tensor with shape "
               "[batch_size, code_length], where code_length represents the "
               "maximum path length from root to leaf nodes.")
         .AsIntermediate();
-    AddAttr<AttrType>("num_classes", "(int, required), The number of classes")
+    AddAttr<AttrType>("num_classes", "(int, optional), The number of classes")
         .SetDefault(2);
     AddComment(R"DOC(
 The hierarchical sigmoid operator organizes the classes into a binary tree.
@@ -115,6 +129,10 @@ belonging to the right branch. This idea is from
 "F. Morin, Y. Bengio (AISTATS 05):
 Hierarchical Probabilistic Neural Network Language Model."
 )DOC");
+    AddAttr<bool>("is_sparse",
+                  "(boolean, default false) "
+                  "Sparse update.")
+        .SetDefault(false);
   }
 };
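The DOC block above is the heart of the op: each class is a leaf of a binary tree, and its probability is the product of left/right sigmoid decisions along the root-to-leaf path, so evaluating one class costs O(log num_classes) sigmoids instead of a softmax over all classes. Below is a minimal standalone sketch of that computation, assuming a complete binary tree stored in heap order; it is plain C++ for illustration only, not the operator's actual kernel, and all names in it are invented.

// Sketch of the hierarchical-sigmoid idea (NOT the operator's code).
// Assumes a complete binary tree in heap order: internal node i has
// children 2i+1 and 2i+2, and class k sits at leaf k.
#include <cmath>
#include <cstdio>
#include <vector>

double Sigmoid(double z) { return 1.0 / (1.0 + std::exp(-z)); }

// x: feature vector (size D); w: one weight row per internal node
// ([K, D] with K = num_classes - 1 for a complete tree);
// cls: target class in [0, num_classes).
double HsigmoidProb(const std::vector<double>& x,
                    const std::vector<std::vector<double>>& w,
                    int num_classes, int cls) {
  int node = cls + num_classes - 1;  // heap index of the class's leaf
  double prob = 1.0;
  while (node > 0) {
    int parent = (node - 1) / 2;
    double z = 0.0;  // dot(w[parent], x)
    for (size_t d = 0; d < x.size(); ++d) z += w[parent][d] * x[d];
    // A left step contributes sigmoid(z); a right step 1 - sigmoid(z).
    bool is_left = (node == 2 * parent + 1);
    prob *= is_left ? Sigmoid(z) : 1.0 - Sigmoid(z);
    node = parent;
  }
  return prob;  // O(log num_classes) sigmoids per class
}

int main() {
  std::vector<double> x = {0.5, -1.0};
  std::vector<std::vector<double>> w = {{0.1, 0.2}, {-0.3, 0.4}, {0.5, -0.6}};
  double total = 0.0;
  for (int c = 0; c < 4; ++c) total += HsigmoidProb(x, w, 4, c);
  std::printf("sum over classes = %f\n", total);  // ~1.0 by construction
  return 0;
}

In the operator itself, PreOut holds the per-node pre-sigmoid values for the batch, and when the optional PTable/PathCode inputs are given, the path layout and left/right codes come from those tensors instead of the complete-tree indexing assumed above.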
@@ -124,36 +142,86 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@Grad) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("PreOut"),
                    "Input(PreOut) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
-                   "Output(W@Grad should not be null.)");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")));
-    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
-      ctx->SetOutputDim(framework::GradVarName("Bias"),
-                        ctx->GetInputDim("Bias"));
+                   "Output(W@Grad) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@Grad) should not be null.");
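+    // With is_sparse on, W@GRAD (and Bias@GRAD) become SelectedRows (see the
+    // var-type inference below), so dense dims are only set otherwise.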
+    if (!ctx->Attrs().Get<bool>("is_sparse")) {
+      if (ctx->HasOutput(framework::GradVarName("Bias"))) {
+        ctx->SetOutputDim(framework::GradVarName("Bias"),
+                          ctx->GetInputDim("Bias"));
+      }
+      ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
     }
-    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }

  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
         ctx.GetPlace());
   }
 };
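+// Decides the variable type of the weight (and bias) gradients:
+// SelectedRows for sparse updates, dense LoDTensor otherwise.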
+class HierarchicalSigmoidGradOpGradVarTypeInference
+    : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto w_grad_var_name = op_desc.Output(framework::GradVarName("W")).front();
+    auto bias_grad_var_name_vec =
+        op_desc.Output(framework::GradVarName("Bias"));
+    std::string bias_grad_var_name;
+    bool hasBias = false;
+    if (bias_grad_var_name_vec.size()) {
+      hasBias = true;
+      bias_grad_var_name =
+          op_desc.Output(framework::GradVarName("Bias")).front();
+    }
+    auto attr = op_desc.GetAttr("is_sparse");
+    bool is_sparse = boost::get<bool>(attr);
+    if (is_sparse) {
+      VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
+               << " is set to SelectedRows";
+      block->Var(w_grad_var_name)
+          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      if (hasBias) {
+        VLOG(30) << "hierarchical_sigmoid_grad op "
+                 << framework::GradVarName("Bias") << " is set to SelectedRows";
+        block->Var(bias_grad_var_name)
+            ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      }
+    } else {
+      VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
+               << " is set to LoDTensor";
+      block->Var(w_grad_var_name)
+          ->SetType(framework::proto::VarType::LOD_TENSOR);
+      if (hasBias) {
+        VLOG(30) << "hierarchical_sigmoid_grad op "
+                 << framework::GradVarName("Bias") << " is set to LoDTensor";
+        block->Var(bias_grad_var_name)
+            ->SetType(framework::proto::VarType::LOD_TENSOR);
+      }
+    }
+    block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType());
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
 REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
                   ops::HierarchicalSigmoidOpMaker<int>,
                   paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
+REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp,
+                  ops::HierarchicalSigmoidGradOpGradVarTypeInference);
 REGISTER_OP_CPU_KERNEL(
     hierarchical_sigmoid,
     ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,