
Commit 822496f

merge cpu and gpu

committed
1 parent 9f09d68 commit 822496f

File tree

3 files changed (+85 / -133 lines)

paddle/fluid/operators/sampling_id_op.cc

Lines changed: 1 addition & 55 deletions
@@ -12,67 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <algorithm>
-#include <iostream>
-#include <iterator>
-#include <random>
-#include <sstream>
-#include <vector>
-#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/sampling_id_op.h"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
 
-template <typename T>
-class SamplingIdKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("X");
-    const int batch_size = static_cast<int>(input->dims()[0]);
-    const int width = static_cast<int>(input->dims()[1]);
-
-    PADDLE_ENFORCE_GE(batch_size, 0,
-                      "batch_size(dims[0]) must be nonnegative.");
-    PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative.");
-
-    std::vector<T> ins_vector;
-    framework::TensorToVector(*input, context.device_context(), &ins_vector);
-
-    unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
-    std::minstd_rand engine;
-    if (seed == 0) {
-      seed = std::random_device()();
-    }
-    engine.seed(seed);
-    std::uniform_real_distribution<T> dist(
-        static_cast<T>(context.Attr<float>("min")),
-        static_cast<T>(context.Attr<float>("max")));
-
-    std::vector<T> ids(batch_size);
-    for (size_t i = 0; i < batch_size; ++i) {
-      T r = dist(engine);
-      int idx = width - 1;
-      for (int j = 0; j < width; ++j) {
-        if ((r -= ins_vector[i * width + j]) < 0) {
-          idx = j;
-          break;
-        }
-      }
-      ids[i] = ins_vector[i * width + idx];
-    }
-
-    std::vector<int64_t> out_dim;
-    out_dim.push_back(static_cast<int64_t>(batch_size));
-
-    Tensor* output = context.Output<Tensor>("Out");
-    output->Resize(framework::make_ddim(out_dim));
-    output->mutable_data<T>(context.GetPlace());
-    framework::TensorFromVector(ids, context.device_context(), output);
-  }
-};
-
 class SamplingIdOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

paddle/fluid/operators/sampling_id_op.cu

Lines changed: 4 additions & 78 deletions
@@ -11,83 +11,9 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <thrust/random.h>
-#include <thrust/transform.h>
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
 
-template <typename T>
-struct UniformGenerator {
-  T min_, max_;
-  unsigned int seed_;
+#include "paddle/fluid/operators/sampling_id_op.h"
 
-  __host__ __device__ UniformGenerator(T min, T max, int seed)
-      : min_(min), max_(max), seed_(seed) {}
-
-  __host__ __device__ T operator()(const unsigned int n) const {
-    thrust::minstd_rand rng;
-    rng.seed(seed_);
-    thrust::uniform_real_distribution<T> dist(min_, max_);
-    rng.discard(n);
-    return dist(rng);
-  }
-};
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-class SamplingIdGPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("X");
-    const int batch_size = static_cast<int>(input->dims()[0]);
-    const int width = static_cast<int>(input->dims()[1]);
-
-    PADDLE_ENFORCE_GE(batch_size, 0,
-                      "batch_size(dims[0]) must be nonnegative.");
-    PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative.");
-
-    std::vector<T> ins_vector;
-    framework::TensorToVector(*input, context.device_context(), &ins_vector);
-
-    unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
-    if (seed == 0) {
-      std::random_device rd;
-      seed = rd();
-    }
-    T min = static_cast<T>(context.Attr<float>("min"));
-    T max = static_cast<T>(context.Attr<float>("max"));
-    UniformGenerator<T> gen = UniformGenerator<T>(min, max, seed);
-
-    std::vector<T> ids(batch_size);
-    for (size_t i = 0; i < batch_size; ++i) {
-      T r = gen(0);
-      int idx = width - 1;
-      for (int j = 0; j < width; ++j) {
-        if ((r -= ins_vector[i * width + j]) < 0) {
-          idx = j;
-          break;
-        }
-      }
-      ids[i] = ins_vector[i * width + idx];
-    }
-
-    std::vector<int64_t> out_dim;
-    out_dim.push_back(static_cast<int64_t>(batch_size));
-
-    Tensor* output = context.Output<Tensor>("Out");
-    output->Resize(framework::make_ddim(out_dim));
-    output->mutable_data<T>(context.GetPlace());
-    framework::TensorFromVector(ids, context.device_context(), output);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-REGISTER_OP_CUDA_KERNEL(sampling_id,
-                        paddle::operators::SamplingIdGPUKernel<float>,
-                        paddle::operators::SamplingIdGPUKernel<double>);
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>,
+                        paddle::operators::SamplingIdKernel<double>);
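
Note on the deleted GPU path: the removed UniformGenerator produced the n-th uniform variate by reseeding a thrust::minstd_rand, calling rng.discard(n), and sampling once, so each draw is fully determined by (seed, n). As the deleted code shows, SamplingIdGPUKernel passed n = 0 for every row, so each Compute call reused one draw for all rows, whereas the shared kernel introduced below draws a fresh r per row. A minimal host-only sketch of that counter-style draw, written against <random> instead of Thrust (draw_at is an illustrative name, not part of the operator):

#include <random>

// Deterministic "give me the n-th variate" draw: reseed, skip ahead n
// engine states, then sample once. The same (seed, n) always yields the
// same value, mirroring the removed UniformGenerator<T>::operator()(n).
template <typename T>
T draw_at(unsigned int seed, unsigned long long n, T lo, T hi) {
  std::minstd_rand rng(seed);
  rng.discard(n);  // advance the engine by n steps
  std::uniform_real_distribution<T> dist(lo, hi);
  return dist(rng);
}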
paddle/fluid/operators/sampling_id_op.h

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <iostream>
+#include <iterator>
+#include <random>
+#include <sstream>
+#include <vector>
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class SamplingIdKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("X");
+    const int batch_size = static_cast<int>(input->dims()[0]);
+    const int width = static_cast<int>(input->dims()[1]);
+
+    PADDLE_ENFORCE_GE(batch_size, 0,
+                      "batch_size(dims[0]) must be nonnegative.");
+    PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative.");
+
+    std::vector<T> ins_vector;
+    framework::TensorToVector(*input, context.device_context(), &ins_vector);
+
+    unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
+    std::minstd_rand engine;
+    if (seed == 0) {
+      seed = std::random_device()();
+    }
+    engine.seed(seed);
+    std::uniform_real_distribution<T> dist(
+        static_cast<T>(context.Attr<float>("min")),
+        static_cast<T>(context.Attr<float>("max")));
+
+    std::vector<T> ids(batch_size);
+    for (size_t i = 0; i < batch_size; ++i) {
+      T r = dist(engine);
+      int idx = width - 1;
+      for (int j = 0; j < width; ++j) {
+        if ((r -= ins_vector[i * width + j]) < 0) {
+          idx = j;
+          break;
+        }
+      }
+      ids[i] = ins_vector[i * width + idx];
+    }
+
+    std::vector<int64_t> out_dim;
+    out_dim.push_back(static_cast<int64_t>(batch_size));
+
+    Tensor* output = context.Output<Tensor>("Out");
+    output->Resize(framework::make_ddim(out_dim));
+    output->mutable_data<T>(context.GetPlace());
+    framework::TensorFromVector(ids, context.device_context(), output);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
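
The selection loop in the shared SamplingIdKernel is inverse-CDF ("roulette wheel") sampling over one row of X: it draws r uniformly from the min/max attributes (typically [0, 1]), subtracts the row's entries from r left to right, stops at the first index where r drops below zero (falling back to the last column), and writes the element at that position into Out. A minimal standalone sketch of that loop; sample_index and the toy weights are illustrative, not part of the operator:

#include <random>
#include <vector>

// Inverse-CDF selection over one row of weights: subtract entries from r
// until it goes negative; fall back to the last column if it never does.
template <typename T>
int sample_index(const std::vector<T>& row, T r) {
  int idx = static_cast<int>(row.size()) - 1;
  for (size_t j = 0; j < row.size(); ++j) {
    if ((r -= row[j]) < 0) {
      idx = static_cast<int>(j);
      break;
    }
  }
  return idx;
}

int main() {
  std::vector<float> probs = {0.1f, 0.7f, 0.2f};  // one row of "X"
  std::minstd_rand engine(std::random_device{}());
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  return sample_index(probs, dist(engine));  // index 1 roughly 70% of the time
}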
