Skip to content

Commit d8ffb26

Browse files
authored
【Cherry-pick PR 36511】fix out_of_range bug of multinomial op's cuda kernel (#36511) (#36808)
Cherry-pick PR #36511
1 parent e3db65d commit d8ffb26

File tree

2 files changed

+41
-34
lines changed

2 files changed

+41
-34
lines changed

paddle/fluid/operators/multinomial_op.cu

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -33,31 +33,33 @@ namespace operators {
3333

3434
template <typename T>
3535
__global__ void NormalizeProbability(T* norm_probs, const T* in_data,
36-
T* sum_rows) {
36+
T* sum_rows, int64_t num_distributions,
37+
int64_t num_categories) {
3738
int id = threadIdx.x + blockIdx.x * blockDim.x +
3839
blockIdx.y * gridDim.x * blockDim.x;
39-
PADDLE_ENFORCE(
40-
in_data[id] >= 0.0,
41-
"The input of multinomial distribution should be >= 0, but got %f.",
42-
in_data[id]);
43-
PADDLE_ENFORCE(sum_rows[blockIdx.y] > 0.0,
44-
"The sum of one multinomial distribution probability should "
45-
"be > 0, but got %f.",
46-
sum_rows[blockIdx.y]);
47-
norm_probs[id] = in_data[id] / sum_rows[blockIdx.y];
40+
if (id < num_distributions * num_categories) {
41+
PADDLE_ENFORCE(
42+
in_data[id] >= 0.0,
43+
"The input of multinomial distribution should be >= 0, but got %f.",
44+
in_data[id]);
45+
int64_t row_id = id / num_categories;
46+
PADDLE_ENFORCE(sum_rows[row_id] > 0.0,
47+
"The sum of one multinomial distribution probability should "
48+
"be > 0, but got %f.",
49+
sum_rows[row_id]);
50+
norm_probs[id] = in_data[id] / sum_rows[row_id];
51+
}
4852
}
4953

5054
template <typename T>
5155
__global__ void GetCumulativeProbs(T* norm_probs_data,
5256
int64_t num_distributions,
5357
int64_t num_categories,
5458
T* cumulative_probs) {
55-
for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) {
56-
thrust::inclusive_scan(thrust::device,
57-
norm_probs_data + id * num_categories,
58-
norm_probs_data + (id + 1) * num_categories,
59-
cumulative_probs + id * num_categories);
60-
}
59+
int id = blockIdx.x;
60+
thrust::inclusive_scan(thrust::device, norm_probs_data + id * num_categories,
61+
norm_probs_data + (id + 1) * num_categories,
62+
cumulative_probs + id * num_categories);
6163
}
6264

6365
template <typename T>
@@ -108,23 +110,19 @@ __global__ void sampleMultinomialWithReplacement(
108110
// use binary search to get the selected category sample id.
109111
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id].
110112

111-
int idx = threadIdx.x + blockIdx.x * blockDim.x +
112-
blockIdx.y * gridDim.x * blockDim.x;
113-
114113
// for every distribution
115-
for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
116-
// for every sample
117-
for (int sample = blockIdx.x * blockDim.x + threadIdx.x;
118-
sample < num_samples; sample += blockDim.x * gridDim.x) {
119-
T rng_number = rng_data[sample + dist * num_samples];
120-
121-
// Find the bucket that a uniform random number lies in
122-
int selected_category = binarySearchFunctor<T>(
123-
cumulative_probs + dist * num_categories,
124-
norm_probs_data + dist * num_categories, num_categories, rng_number);
125-
126-
out_data[sample + dist * num_samples] = selected_category;
127-
}
114+
int dist = blockIdx.y;
115+
// for every sample
116+
int sample = blockIdx.x * blockDim.x + threadIdx.x;
117+
if (sample < num_samples) {
118+
T rng_number = rng_data[sample + dist * num_samples];
119+
120+
// Find the bucket that a uniform random number lies in
121+
int selected_category = binarySearchFunctor<T>(
122+
cumulative_probs + dist * num_categories,
123+
norm_probs_data + dist * num_categories, num_categories, rng_number);
124+
125+
out_data[sample + dist * num_samples] = selected_category;
128126
}
129127
}
130128

@@ -215,10 +213,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
215213

216214
// number of threads in a block is min(num_categories, 512)
217215
dim3 block_norm(num_categories < 512 ? num_categories : 512);
218-
dim3 grid_norm((num_categories - 1) / block_norm.x + 1, num_distributions);
216+
dim3 grid_norm((num_distributions * num_categories - 1) / block_norm.x + 1);
219217
NormalizeProbability<
220218
T><<<grid_norm, block_norm, 0, ctx.cuda_device_context().stream()>>>(
221-
norm_probs_data, in_data, sum_rows_data);
219+
norm_probs_data, in_data, sum_rows_data, num_distributions,
220+
num_categories);
222221

223222
// Get cumulative probability of each distribution. It's the same function
224223
// of

python/paddle/fluid/tests/unittests/test_multinomial_op.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@ def test_dygraph3(self):
141141
"replacement is False. categories can't be sampled repeatedly")
142142
paddle.enable_static()
143143

144+
def test_dygraph4(self):
145+
paddle.disable_static()
146+
logits = -1 * paddle.ones([2800])
147+
# Categorical.sample API will call multinomial op with replacement=True
148+
cat = paddle.distribution.Categorical(logits.exp())
149+
cat.sample([1])
150+
paddle.enable_static()
151+
144152
def test_static(self):
145153
paddle.enable_static()
146154
startup_program = fluid.Program()

0 commit comments

Comments
 (0)