File tree Expand file tree Collapse file tree 2 files changed +7
-6
lines changed Expand file tree Collapse file tree 2 files changed +7
-6
lines changed Original file line number Diff line number Diff line change @@ -36,15 +36,15 @@ class SamplingIdKernel : public framework::OpKernel<T> {
36
36
std::vector<T> ins_vector;
37
37
framework::TensorToVector (*input, context.device_context (), &ins_vector);
38
38
39
- unsigned int seed = static_cast <unsigned int >(ctx .Attr <int >(" seed" ));
39
+ unsigned int seed = static_cast <unsigned int >(context .Attr <int >(" seed" ));
40
40
std::minstd_rand engine;
41
41
if (seed == 0 ) {
42
42
seed = std::random_device ()();
43
43
}
44
44
engine.seed (seed);
45
45
std::uniform_real_distribution<T> dist (
46
- static_cast <T>(ctx .Attr <float >(" min" )),
47
- static_cast <T>(ctx .Attr <float >(" max" )));
46
+ static_cast <T>(context .Attr <float >(" min" )),
47
+ static_cast <T>(context .Attr <float >(" max" )));
48
48
49
49
std::vector<T> ids (batch_size);
50
50
for (size_t i = 0 ; i < batch_size; ++i) {
Original file line number Diff line number Diff line change @@ -39,7 +39,7 @@ namespace operators {
39
39
using Tensor = framework::Tensor;
40
40
41
41
template <typename T>
42
- class SamplingIdKernel : public framework ::OpKernel<T> {
42
+ class SamplingIdGPUKernel : public framework ::OpKernel<T> {
43
43
public:
44
44
void Compute (const framework::ExecutionContext& context) const override {
45
45
const Tensor* input = context.Input <Tensor>(" X" );
@@ -83,5 +83,6 @@ class SamplingIdKernel : public framework::OpKernel<T> {
83
83
} // namespace operators
84
84
} // namespace paddle
85
85
86
- REGISTER_OP_CPU_KERNEL (sampling_id, paddle::operators::SamplingIdKernel<float >,
87
- paddle::operators::SamplingIdKernel<double >);
86
+ REGISTER_OP_CUDA_KERNEL (sampling_id,
87
+ paddle::operators::SamplingIdGPUKernel<float >,
88
+ paddle::operators::SamplingIdGPUKernel<double >);
You can’t perform that action at this time.
0 commit comments