Skip to content

Commit 478f73c

Browse files
committed
merge header in cc
1 parent 64a4925 commit 478f73c

File tree

2 files changed

+52
-84
lines changed

2 files changed

+52
-84
lines changed

paddle/fluid/operators/sampling_id_op.cc

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,68 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/sampling_id_op.h"
15+
#include <algorithm>
16+
#include <iostream>
17+
#include <iterator>
18+
#include <random>
19+
#include <sstream>
20+
#include <vector>
21+
#include "paddle/fluid/framework/op_registry.h"
1622

1723
namespace paddle {
1824
namespace operators {
1925

2026
using Tensor = framework::Tensor;
2127

28+
template <typename DeviceContext, typename T>
29+
class SamplingIdKernel : public framework::OpKernel<T> {
30+
public:
31+
void Compute(const framework::ExecutionContext& context) const override {
32+
const Tensor* input = context.Input<Tensor>("X");
33+
const int batch_size = static_cast<int>(input->dims()[0]);
34+
const int width = static_cast<int>(input->dims()[1]);
35+
36+
std::vector<T> ins_vector;
37+
framework::TensorToVector(*input, context.device_context(), &ins_vector);
38+
39+
std::vector<T> ids(batch_size);
40+
for (size_t i = 0; i < batch_size; ++i) {
41+
double r = getRandReal();
42+
int idx = width - 1;
43+
for (int j = 0; j < width; ++j) {
44+
if ((r -= ins_vector[i * width + j]) < 0) {
45+
idx = j;
46+
break;
47+
}
48+
}
49+
ids[i] = ins_vector[i * width + idx];
50+
}
51+
52+
std::vector<int64_t> out_dim;
53+
out_dim.push_back(static_cast<int64_t>(batch_size));
54+
55+
Tensor* output = context.Output<Tensor>("Out");
56+
output->Resize(framework::make_ddim(out_dim));
57+
output->mutable_data<T>(context.GetPlace());
58+
framework::TensorFromVector(ids, context.device_context(), output);
59+
}
60+
61+
private:
62+
double getRandReal() const {
63+
std::random_device
64+
rd; // Will be used to obtain a seed for the random number engine
65+
std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with
66+
// rd()
67+
std::uniform_real_distribution<> dis(1.0, 2.0);
68+
return dis(gen);
69+
}
70+
};
71+
2272
class SamplingIdOp : public framework::OperatorWithKernel {
2373
public:
2474
using framework::OperatorWithKernel::OperatorWithKernel;
2575

26-
void InferShape(framework::InferShapeContext *ctx) const override {
76+
void InferShape(framework::InferShapeContext* ctx) const override {
2777
PADDLE_ENFORCE(ctx->HasInput("X"),
2878
"Input(X) of SamplingIdOp should not be null.");
2979
PADDLE_ENFORCE(ctx->HasOutput("Out"),

paddle/fluid/operators/sampling_id_op.h

Lines changed: 0 additions & 82 deletions
This file was deleted.

0 commit comments

Comments
 (0)