Skip to content

Commit 470fb7c

Browse files
committed
bug fix
1 parent 60dda7b commit 470fb7c

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

paddle/fluid/operators/sampling_id_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ class SamplingIdKernel : public framework::OpKernel<T> {
3636
std::vector<T> ins_vector;
3737
framework::TensorToVector(*input, context.device_context(), &ins_vector);
3838

39-
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
39+
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
4040
std::minstd_rand engine;
4141
if (seed == 0) {
4242
seed = std::random_device()();
4343
}
4444
engine.seed(seed);
4545
std::uniform_real_distribution<T> dist(
46-
static_cast<T>(ctx.Attr<float>("min")),
47-
static_cast<T>(ctx.Attr<float>("max")));
46+
static_cast<T>(context.Attr<float>("min")),
47+
static_cast<T>(context.Attr<float>("max")));
4848

4949
std::vector<T> ids(batch_size);
5050
for (size_t i = 0; i < batch_size; ++i) {

paddle/fluid/operators/sampling_id_op.cu

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace operators {
3939
using Tensor = framework::Tensor;
4040

4141
template <typename T>
42-
class SamplingIdKernel : public framework::OpKernel<T> {
42+
class SamplingIdGPUKernel : public framework::OpKernel<T> {
4343
public:
4444
void Compute(const framework::ExecutionContext& context) const override {
4545
const Tensor* input = context.Input<Tensor>("X");
@@ -83,5 +83,6 @@ class SamplingIdKernel : public framework::OpKernel<T> {
8383
} // namespace operators
8484
} // namespace paddle
8585

86-
REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>,
87-
paddle::operators::SamplingIdKernel<double>);
86+
REGISTER_OP_CUDA_KERNEL(sampling_id,
87+
paddle::operators::SamplingIdGPUKernel<float>,
88+
paddle::operators::SamplingIdGPUKernel<double>);

0 commit comments

Comments
 (0)