@@ -70,13 +70,14 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
     const int64_t batch_size = ctx->GetInputDim("X")[0];
     std::vector<int64_t> output_shape({batch_size, 1});
     ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
         ctx.GetPlace());
   }
 };
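Two things change in this first hunk: the forward op now calls ShareLoD so the sequence (LoD) metadata on X is carried over to Out, and the kernel-type lookup reads X as a LoDTensor, matching the doc-string changes below. For readers new to the InferShape API, a minimal sketch of the share-shape-and-LoD pattern, using only the calls that appear in this hunk (the op itself is hypothetical):

    // Hypothetical identity-style op: output mirrors input shape and LoD.
    void InferShape(framework::InferShapeContext* ctx) const override {
      ctx->SetOutputDim("Out", ctx->GetInputDim("X"));  // same dims as X
      ctx->ShareLoD("X", /*->*/ "Out");                 // same sequence layout
    }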
@@ -86,32 +87,34 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "(Tensor, required) The input tensor with shape [N, D], "
+             "(LoDTensor, required) The input tensor with shape [N, D], "
              "where N is the size of mini-batch, and D is the feature size.");
     AddInput("W",
-             "(Tensor, required), The parameters of hierarchical "
+             "(LoDTensor, required), The parameters of hierarchical "
              "sigmoid operator, each of them is a 2-D tensor, the shape is"
              "[K, D]. Which K is the num of non-leaf node in Path Tree");
     AddInput("Label",
-             "(Tensor, required), The labels of training data. It's a"
+             "(LoDTensor, required), The labels of training data. It's a"
              "tensor with shape [N, 1].");
     AddInput("PTable",
-             "(Tensor, optional), The Path Table from root to current word"
+             "(LoDTensor, optional), The Path Table from root to current word"
              "it should have shape like [N, L], L is the length of the Path")
         .AsDispensable();
-    AddInput("PCode",
-             "(Tensor, optional), The Code on each Node of the Path from root "
-             "to current word"
-             "it should have shape like [N, L], L is the length of the Path")
+    AddInput(
+        "PCode",
+        "(LoDTensor, optional), The Code on each Node of the Path from root "
+        "to current word"
+        "it should have shape like [N, L], L is the length of the Path")
         .AsDispensable();
     AddInput("Bias",
-             "(Tensor, optional), The bias is a tensor with shape"
+             "(LoDTensor, optional), The bias is a tensor with shape"
              "[1, num_classes - 1].");
-    AddOutput("Out",
-              "(Tensor, required) The output of hierarchical sigmoid operator."
-              "The shape is [N, 1].");
+    AddOutput(
+        "Out",
+        "(LoDTensor, required) The output of hierarchical sigmoid operator."
+        "The shape is [N, 1].");
     AddOutput("PreOut",
-              "(Tensor, required) A intermedia 2-D tensor with shape "
+              "(LoDTensor, required) An intermediate 2-D tensor with shape "
               "[batch_size, code_length], where code_length represents the "
               "maximum path length from root to leaf nodes.")
         .AsIntermediate();
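PTable and PCode stay marked .AsDispensable(), meaning callers may omit them (the op then presumably falls back to its default coding derived from the class count rather than a custom tree). A hedged sketch of how a kernel typically guards such optional inputs; ExecutionContext::HasInput is the standard check, and the surrounding kernel body is assumed:

    // Inside a Compute() body (sketch): dispensable inputs may be absent.
    const framework::LoDTensor* path =
        ctx.HasInput("PTable") ? ctx.Input<framework::LoDTensor>("PTable")
                               : nullptr;
    if (path == nullptr) {
      // no custom path table supplied; use the default code layout
    }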
@@ -124,6 +127,10 @@ belonging to the right branch. This idea is from
 "F. Morin, Y. Bengio (AISTATS 05):
 Hierarchical Probabilistic Neural Network Language Model."
 )DOC");
+    AddAttr<bool>("is_sparse",
+                  "(boolean, default false) "
+                  "Sparse update.")
+        .SetDefault(false);
   }
 };
 
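The new is_sparse attribute is the switch the rest of this patch keys off: the grad op's InferShape and the var-type inference below both read it. A sketch of how a grad kernel would typically branch on it (Attr<T> is the standard ExecutionContext accessor; the branch bodies here are placeholders, not the actual kernel):

    // Inside the grad kernel's Compute() (sketch):
    bool is_sparse = ctx.Attr<bool>("is_sparse");
    if (is_sparse) {
      // write W@Grad as SelectedRows: only rows on the visited tree paths
    } else {
      // write W@Grad as a dense tensor with the full [K, D] shape of W
    }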
@@ -133,6 +140,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@Grad) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("PreOut"),
                    "Input(Preout) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
@@ -142,27 +151,52 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
       ctx->SetOutputDim(framework::GradVarName("Bias"),
                         ctx->GetInputDim("Bias"));
     }
-    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
+    if (!ctx->Attrs().Get<bool>("is_sparse")) {
+      ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
+    }
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
         ctx.GetPlace());
   }
 };
 
+class HierarchicalSigmoidGradOpGradVarTypeInference
+    : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
+    auto attr = op_desc.GetAttr("is_sparse");
+    bool is_sparse = boost::get<bool>(attr);
+    if (is_sparse) {
+      VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
+              << " is set to SelectedRows";
+      block->Var(out_var_name)
+          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+    } else {
+      VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
+              << " is set to LoDTensor";
+      block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
+    }
+    block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
                   ops::HierarchicalSigmoidOpMaker<int>,
                   paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
+REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp,
+                  ops::HierarchicalSigmoidGradOpGradVarTypeInference);
 REGISTER_OP_CPU_KERNEL(
     hierarchical_sigmoid,
     ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
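This last hunk explains the InferShape change above: a SelectedRows gradient has no fixed compile-time shape, only a height plus whichever rows the batch touched, so SetOutputDim for W@Grad is skipped when is_sparse is set, and the new VarTypeInference is what tells the static graph that W@Grad is SELECTED_ROWS rather than LOD_TENSOR. As a toy, framework-free illustration of why the sparse representation pays off (the names and layout here are mine, not Paddle's):

    #include <cstdint>
    #include <vector>

    // A sparse row gradient: only the touched rows of W, plus their indices.
    struct SparseRowsGrad {
      std::vector<int64_t> rows;  // indices of non-leaf nodes on visited paths
      std::vector<float> values;  // flattened [rows.size(), D] gradient block
      int64_t height;             // K, the full row count of W
    };

    // Applying it costs O(|rows| * D) instead of the dense O(K * D).
    void ApplySparseGrad(std::vector<float>* w, int64_t d,
                         const SparseRowsGrad& g, float lr) {
      for (size_t i = 0; i < g.rows.size(); ++i) {
        for (int64_t j = 0; j < d; ++j) {
          (*w)[static_cast<size_t>(g.rows[i] * d + j)] -=
              lr * g.values[i * d + j];
        }
      }
    }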