File tree Expand file tree Collapse file tree 2 files changed +10
-1
lines changed Expand file tree Collapse file tree 2 files changed +10
-1
lines changed Original file line number Diff line number Diff line change @@ -33,6 +33,10 @@ class SamplingIdKernel : public framework::OpKernel<T> {
33
33
const int batch_size = static_cast <int >(input->dims ()[0 ]);
34
34
const int width = static_cast <int >(input->dims ()[1 ]);
35
35
36
+ PADDLE_ENFORCE_GE (batch_size, 0 ,
37
+ " batch_size(dims[0]) must be nonnegative." );
38
+ PADDLE_ENFORCE_GE (width, 0 , " width(dims[1]) must be nonnegative." );
39
+
36
40
std::vector<T> ins_vector;
37
41
framework::TensorToVector (*input, context.device_context (), &ins_vector);
38
42
Original file line number Diff line number Diff line change @@ -46,6 +46,10 @@ class SamplingIdGPUKernel : public framework::OpKernel<T> {
46
46
const int batch_size = static_cast <int >(input->dims ()[0 ]);
47
47
const int width = static_cast <int >(input->dims ()[1 ]);
48
48
49
+ PADDLE_ENFORCE_GE (batch_size, 0 ,
50
+ " batch_size(dims[0]) must be nonnegative." );
51
+ PADDLE_ENFORCE_GE (width, 0 , " width(dims[1]) must be nonnegative." );
52
+
49
53
std::vector<T> ins_vector;
50
54
framework::TensorToVector (*input, context.device_context (), &ins_vector);
51
55
@@ -56,10 +60,11 @@ class SamplingIdGPUKernel : public framework::OpKernel<T> {
56
60
}
57
61
T min = static_cast <T>(context.Attr <float >(" min" ));
58
62
T max = static_cast <T>(context.Attr <float >(" max" ));
63
+ UniformGenerator<T> gen = UniformGenerator<T>(min, max, seed);
59
64
60
65
std::vector<T> ids (batch_size);
61
66
for (size_t i = 0 ; i < batch_size; ++i) {
62
- T r = UniformGenerator<T>(min, max, seed );
67
+ T r = gen ( 0 );
63
68
int idx = width - 1 ;
64
69
for (int j = 0 ; j < width; ++j) {
65
70
if ((r -= ins_vector[i * width + j]) < 0 ) {
You can’t perform that action at this time.
0 commit comments