@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
15
#include " paddle/fluid/operators/sample_logits_op.h"
16
+ #include < memory>
16
17
#include " paddle/fluid/operators/math/sample_prob.h"
17
18
18
19
namespace paddle {
@@ -60,6 +61,10 @@ class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
60
61
" (Tensor, default: Tensor<float>), A 2-D tensor with shape [N, NT + S]."
61
62
" The probabilites of sampled positive and negtive labels." )
62
63
.AsIntermediate ();
64
+ AddOutput (" LogitsDim" , " Store dim information of Logits for gradient op" )
65
+ .AsIntermediate ();
66
+ AddOutput (" LabelsDim" , " Store dim information of Logits for gradient op" )
67
+ .AsIntermediate ();
63
68
AddOutput (" SampledLogits" ,
64
69
" (Tensor, default: Tensor<float>), A 2-D tensor with shape"
65
70
" [N, NT + S]. The outputs value of sampled logits, which will be"
@@ -121,6 +126,10 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
121
126
" Output(SampledLogits) should be not null." );
122
127
PADDLE_ENFORCE (ctx->HasOutput (" SampledLabels" ),
123
128
" Output(SampledLabels) should be not null." );
129
+ PADDLE_ENFORCE (ctx->HasOutput (" LogitsDim" ),
130
+ " Output(LogitsDim) should be not null." );
131
+ PADDLE_ENFORCE (ctx->HasOutput (" LabelsDim" ),
132
+ " Output(LabelsDim) should be not null." );
124
133
125
134
auto logits_dims = ctx->GetInputDim (" Logits" );
126
135
auto labels_dims = ctx->GetInputDim (" Labels" );
@@ -137,6 +146,15 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
137
146
ctx->SetOutputDim (" Probabilities" , {logits_dims[0 ], num_sampled_classes});
138
147
ctx->SetOutputDim (" SampledLogits" , {logits_dims[0 ], num_sampled_classes});
139
148
ctx->SetOutputDim (" SampledLabels" , {logits_dims[0 ], labels_dims[1 ]});
149
+
150
+ // append 0 to shape variable to avoid optimized by memory optimize pass
151
+ auto logits_dim_vec = framework::vectorize (logits_dims);
152
+ logits_dim_vec.push_back (0 );
153
+ ctx->SetOutputDim (" LogitsDim" , framework::make_ddim (logits_dim_vec));
154
+
155
+ auto labels_dim_vec = framework::vectorize (labels_dims);
156
+ labels_dim_vec.push_back (0 );
157
+ ctx->SetOutputDim (" LabelsDim" , framework::make_ddim (labels_dim_vec));
140
158
}
141
159
142
160
protected:
@@ -155,28 +173,27 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
155
173
using framework::OperatorWithKernel::OperatorWithKernel;
156
174
157
175
void InferShape (framework::InferShapeContext* ctx) const override {
158
- PADDLE_ENFORCE (ctx->HasInput (" Logits " ),
159
- " Input(Logits ) should not be null." );
160
- PADDLE_ENFORCE (ctx->HasInput (" Labels " ),
161
- " Input(Labels ) should be not null." );
176
+ PADDLE_ENFORCE (ctx->HasInput (" LogitsDim " ),
177
+ " Input(LogitsDim ) should not be null." );
178
+ PADDLE_ENFORCE (ctx->HasInput (" LabelsDim " ),
179
+ " Input(LabelsDim ) should be not null." );
162
180
PADDLE_ENFORCE (ctx->HasInput (" Samples" ),
163
181
" Input(Samples) should be not null." );
164
- PADDLE_ENFORCE (ctx->HasInput (" SampledLogits" ),
165
- " Input(SampledLogits) should be not null." );
166
182
PADDLE_ENFORCE (ctx->HasInput (framework::GradVarName (" SampledLogits" )),
167
183
" Input(SampledLogits@Grad) should not be null." );
168
184
PADDLE_ENFORCE (ctx->HasOutput (framework::GradVarName (" Logits" )),
169
185
" Output(Logits@Grad) should be not null." );
170
186
171
- auto logit_dims = ctx->GetInputDim (" Logits" );
172
- auto label_dims = ctx->GetInputDim (" Labels" );
173
- PADDLE_ENFORCE_EQ (label_dims.size (), 2UL ,
187
+ auto logits_dims = ctx->GetInputDim (" LogitsDim" );
188
+ logits_dims = framework::DDim (logits_dims.Get (), logits_dims.size () - 1 );
189
+ auto labels_dims = ctx->GetInputDim (" LabelsDim" );
190
+ labels_dims = framework::DDim (labels_dims.Get (), labels_dims.size () - 1 );
191
+ PADDLE_ENFORCE_EQ (labels_dims.size (), 2UL ,
174
192
" The label should be a 2-D tensor." );
175
- PADDLE_ENFORCE_EQ (logit_dims .size (), 2UL ,
193
+ PADDLE_ENFORCE_EQ (logits_dims .size (), 2UL ,
176
194
" The logits should be a 2-D tensor." );
177
195
178
- ctx->SetOutputDim (framework::GradVarName (" Logits" ),
179
- ctx->GetInputDim (" Logits" ));
196
+ ctx->SetOutputDim (framework::GradVarName (" Logits" ), logits_dims);
180
197
}
181
198
182
199
protected:
@@ -199,10 +216,9 @@ class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker {
199
216
std::unique_ptr<framework::OpDesc> Apply () const override {
200
217
auto * grad_op = new framework::OpDesc ();
201
218
grad_op->SetType (" sample_logits_grad" );
202
- grad_op->SetInput (" Logits " , Input ( " Logits " ));
203
- grad_op->SetInput (" Labels " , Input ( " Labels " ));
219
+ grad_op->SetInput (" LogitsDim " , Output ( " LogitsDim " ));
220
+ grad_op->SetInput (" LabelsDim " , Output ( " LabelsDim " ));
204
221
grad_op->SetInput (" Samples" , Output (" Samples" ));
205
- grad_op->SetInput (" SampledLogits" , Output (" SampledLogits" ));
206
222
grad_op->SetInput (framework::GradVarName (" SampledLogits" ),
207
223
OutputGrad (" SampledLogits" ));
208
224
grad_op->SetOutput (framework::GradVarName (" Logits" ), InputGrad (" Logits" ));
0 commit comments