Skip to content

Commit 9f09d68

Browse files
committed
add enforce
1 parent baa6273 commit 9f09d68

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

paddle/fluid/operators/sampling_id_op.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ class SamplingIdKernel : public framework::OpKernel<T> {
3333
const int batch_size = static_cast<int>(input->dims()[0]);
3434
const int width = static_cast<int>(input->dims()[1]);
3535

36+
PADDLE_ENFORCE_GE(batch_size, 0,
37+
"batch_size(dims[0]) must be nonnegative.");
38+
PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative.");
39+
3640
std::vector<T> ins_vector;
3741
framework::TensorToVector(*input, context.device_context(), &ins_vector);
3842

paddle/fluid/operators/sampling_id_op.cu

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ class SamplingIdGPUKernel : public framework::OpKernel<T> {
4646
const int batch_size = static_cast<int>(input->dims()[0]);
4747
const int width = static_cast<int>(input->dims()[1]);
4848

49+
PADDLE_ENFORCE_GE(batch_size, 0,
50+
"batch_size(dims[0]) must be nonnegative.");
51+
PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative.");
52+
4953
std::vector<T> ins_vector;
5054
framework::TensorToVector(*input, context.device_context(), &ins_vector);
5155

@@ -56,10 +60,11 @@ class SamplingIdGPUKernel : public framework::OpKernel<T> {
5660
}
5761
T min = static_cast<T>(context.Attr<float>("min"));
5862
T max = static_cast<T>(context.Attr<float>("max"));
63+
UniformGenerator<T> gen = UniformGenerator<T>(min, max, seed);
5964

6065
std::vector<T> ids(batch_size);
6166
for (size_t i = 0; i < batch_size; ++i) {
62-
T r = UniformGenerator<T>(min, max, seed);
67+
T r = gen(0);
6368
int idx = width - 1;
6469
for (int j = 0; j < width; ++j) {
6570
if ((r -= ins_vector[i * width + j]) < 0) {

0 commit comments

Comments
 (0)