@@ -33,31 +33,33 @@ namespace operators {
33
33
34
34
template <typename T>
35
35
__global__ void NormalizeProbability (T* norm_probs, const T* in_data,
36
- T* sum_rows) {
36
+ T* sum_rows, int64_t num_distributions,
37
+ int64_t num_categories) {
37
38
int id = threadIdx .x + blockIdx .x * blockDim .x +
38
39
blockIdx .y * gridDim .x * blockDim .x ;
39
- PADDLE_ENFORCE (
40
- in_data[id] >= 0.0 ,
41
- " The input of multinomial distribution should be >= 0, but got %f." ,
42
- in_data[id]);
43
- PADDLE_ENFORCE (sum_rows[blockIdx .y ] > 0.0 ,
44
- " The sum of one multinomial distribution probability should "
45
- " be > 0, but got %f." ,
46
- sum_rows[blockIdx .y ]);
47
- norm_probs[id] = in_data[id] / sum_rows[blockIdx .y ];
40
+ if (id < num_distributions * num_categories) {
41
+ PADDLE_ENFORCE (
42
+ in_data[id] >= 0.0 ,
43
+ " The input of multinomial distribution should be >= 0, but got %f." ,
44
+ in_data[id]);
45
+ int64_t row_id = id / num_categories;
46
+ PADDLE_ENFORCE (sum_rows[row_id] > 0.0 ,
47
+ " The sum of one multinomial distribution probability should "
48
+ " be > 0, but got %f." ,
49
+ sum_rows[row_id]);
50
+ norm_probs[id] = in_data[id] / sum_rows[row_id];
51
+ }
48
52
}
49
53
50
54
template <typename T>
51
55
__global__ void GetCumulativeProbs (T* norm_probs_data,
52
56
int64_t num_distributions,
53
57
int64_t num_categories,
54
58
T* cumulative_probs) {
55
- for (int id = blockIdx .x ; id < num_distributions; id += gridDim .x ) {
56
- thrust::inclusive_scan (thrust::device,
57
- norm_probs_data + id * num_categories,
58
- norm_probs_data + (id + 1 ) * num_categories,
59
- cumulative_probs + id * num_categories);
60
- }
59
+ int id = blockIdx .x ;
60
+ thrust::inclusive_scan (thrust::device, norm_probs_data + id * num_categories,
61
+ norm_probs_data + (id + 1 ) * num_categories,
62
+ cumulative_probs + id * num_categories);
61
63
}
62
64
63
65
template <typename T>
@@ -108,23 +110,19 @@ __global__ void sampleMultinomialWithReplacement(
108
110
// use binary search to get the selected category sample id.
109
111
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id].
110
112
111
- int idx = threadIdx .x + blockIdx .x * blockDim .x +
112
- blockIdx .y * gridDim .x * blockDim .x ;
113
-
114
113
// for every distribution
115
- for (int dist = blockIdx .y ; dist < num_distributions; dist += gridDim .y ) {
116
- // for every sample
117
- for (int sample = blockIdx .x * blockDim .x + threadIdx .x ;
118
- sample < num_samples; sample += blockDim .x * gridDim .x ) {
119
- T rng_number = rng_data[sample + dist * num_samples];
120
-
121
- // Find the bucket that a uniform random number lies in
122
- int selected_category = binarySearchFunctor<T>(
123
- cumulative_probs + dist * num_categories,
124
- norm_probs_data + dist * num_categories, num_categories, rng_number);
125
-
126
- out_data[sample + dist * num_samples] = selected_category;
127
- }
114
+ int dist = blockIdx .y ;
115
+ // for every sample
116
+ int sample = blockIdx .x * blockDim .x + threadIdx .x ;
117
+ if (sample < num_samples) {
118
+ T rng_number = rng_data[sample + dist * num_samples];
119
+
120
+ // Find the bucket that a uniform random number lies in
121
+ int selected_category = binarySearchFunctor<T>(
122
+ cumulative_probs + dist * num_categories,
123
+ norm_probs_data + dist * num_categories, num_categories, rng_number);
124
+
125
+ out_data[sample + dist * num_samples] = selected_category;
128
126
}
129
127
}
130
128
@@ -215,10 +213,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
215
213
216
214
// number of threads in a block is min(num_categories, 512)
217
215
dim3 block_norm (num_categories < 512 ? num_categories : 512 );
218
- dim3 grid_norm ((num_categories - 1 ) / block_norm.x + 1 , num_distributions );
216
+ dim3 grid_norm ((num_distributions * num_categories - 1 ) / block_norm.x + 1 );
219
217
NormalizeProbability<
220
218
T><<<grid_norm, block_norm, 0 , ctx.cuda_device_context().stream()>>> (
221
- norm_probs_data, in_data, sum_rows_data);
219
+ norm_probs_data, in_data, sum_rows_data, num_distributions,
220
+ num_categories);
222
221
223
222
// Get cumulative probability of each distribution. It's the same function
224
223
// of
0 commit comments