Skip to content

Commit da2cc99

Browse files
committed
sampling op optimize
1 parent 4973e07 commit da2cc99

File tree

4 files changed

+49
-18
lines changed

4 files changed

+49
-18
lines changed

paddle/fluid/operators/sampling_id_op.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ class SamplingIdOp : public framework::OperatorWithKernel {
2525

2626
void InferShape(framework::InferShapeContext *ctx) const override {
2727
PADDLE_ENFORCE(ctx->HasInput("X"),
28-
"Input(X) of RowConvOp should not be null.");
28+
"Input(X) of SamplingIdOp should not be null.");
2929
PADDLE_ENFORCE(ctx->HasOutput("Out"),
30-
"Output(Out) of RowConvOp should not be null.");
30+
"Output(Out) of SamplingIdOp should not be null.");
3131

3232
auto input_dims = ctx->GetInputDim("X");
3333

@@ -43,8 +43,7 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker {
4343
AddInput("X",
4444
"The input tensor of softmax. "
4545
"2-D with shape [batch_size, input_feature_dimensions].");
46-
AddOutput("Out", "Sliced data tensor.");
47-
46+
AddOutput("Out", "SamplingId data tensor.");
4847
AddComment(R"DOC(
4948
SamplingId Operator.
5049
@brief A layer for sampling id from multinomial distribution from the

paddle/fluid/operators/sampling_id_op.cu

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,6 @@ limitations under the License. */
1616
#include <vector>
1717
#include "paddle/fluid/operators/sampling_id_op.h"
1818

19-
namespace paddle {
20-
namespace operators {
21-
22-
using Tensor = framework::Tensor;
23-
24-
class SamplingIdOp : public framework::OperatorWithKernel {
25-
public:
26-
using framework::OperatorWithKernel::OperatorWithKernel;
27-
void InferShape(framework::InferShapeContext *ctx) const override {}
28-
}
29-
} // namespace operators
30-
} // namespace paddle
31-
3219
namespace ops = paddle::operators;
3320
REGISTER_OP_CUDA_KERNEL(
3421
sampling_id,

paddle/fluid/operators/sampling_id_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
5050
std::vector<int64_t> out_dim;
5151
out_dim.push_back(static_cast<int64_t>(batch_size));
5252

53-
Tensor* output = context.Output<Tensor>("Output");
53+
Tensor* output = context.Output<Tensor>("Out");
5454
output->Resize(framework::make_ddim(out_dim));
5555
output->mutable_data<T>(context.GetPlace());
5656
framework::TensorFromVector(ids, context.device_context(), output);
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import numpy as np
17+
from op_test import OpTest
18+
19+
import paddle.fluid.core as core
20+
from paddle.fluid.op import Operator
21+
22+
23+
class TestSamplingIdOp(OpTest):
24+
def setUp(self):
25+
self.op_type = "sampling_id"
26+
self.use_mkldnn = False
27+
self.init_kernel_type()
28+
X = np.random.random((3, 4)).astype('float32')
29+
self.inputs = {"X": X}
30+
Y = np.random.random(3).astype('float32')
31+
self.outputs = {'Out': Y}
32+
self.attrs = {'use_mkldnn': self.use_mkldnn}
33+
34+
def test_check_output(self):
35+
self.check_output()
36+
37+
def test_check_grad(self):
38+
self.check_grad(['X'], 'Out')
39+
40+
def init_kernel_type(self):
41+
pass
42+
43+
44+
if __name__ == "__main__":
45+
unittest.main()

0 commit comments

Comments
 (0)