Skip to content

Commit f2f20a5

Browse files
authored
Fix beam search acc issue in workgroup reduce (#3771) (#3796)
* Fix beam search acc issue in workgroup reduce (#3771) Signed-off-by: majing <[email protected]>
1 parent bd034e7 commit f2f20a5

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

csrc/gpu/aten/operators/BeamSearch.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ void beam_search_topk_stage2(
201201
int64_t batch_size,
202202
int32_t num_wg_per_beam) {
203203
int slm_size = (sizeof(int) + sizeof(scalar_t)) * MAX_K * num_wg_per_beam;
204+
int sub_wg_size = dpcppMaxSubGroupSize();
205+
int wg_size = (num_wg_per_beam + sub_wg_size - 1) / sub_wg_size * sub_wg_size;
204206
auto& dpcpp_queue = dpcppGetCurrentQueue();
205207
auto cgf = DPCPP_Q_CGF(cgh) {
206208
dpcpp_local_acc_t<char> shared(slm_size, cgh);
@@ -233,7 +235,7 @@ void beam_search_topk_stage2(
233235
TopK<scalar_t, MAX_K>,
234236
decltype(item),
235237
decltype(combine),
236-
1>(item, num_wg_per_beam, shared, value, combine);
238+
1>(item, item.get_local_range(0), shared, value, combine);
237239
TopK<scalar_t, MAX_K> total = value[0];
238240

239241
if (wi_id == 0) {
@@ -249,8 +251,8 @@ void beam_search_topk_stage2(
249251

250252
cgh.parallel_for(
251253
sycl::nd_range<1>(
252-
sycl::range<1>(batch_size * beam_size * num_wg_per_beam),
253-
sycl::range<1>(num_wg_per_beam)),
254+
sycl::range<1>(batch_size * beam_size * wg_size),
255+
sycl::range<1>(wg_size)),
254256
kfn);
255257
};
256258
DPCPP_Q_SUBMIT(dpcpp_queue, cgf);
@@ -667,7 +669,9 @@ void finalize(
667669
const int64_t out_sentence_num,
668670
const int64_t pad_token_id) {
669671
auto& dpcpp_queue = dpcppGetCurrentQueue();
670-
int32_t wg_size = 2 * beam_size;
672+
int32_t sub_wg_size = dpcppMaxSubGroupSize();
673+
int32_t wg_size =
674+
(2 * beam_size + sub_wg_size - 1) / sub_wg_size * sub_wg_size;
671675

672676
auto cgf = DPCPP_Q_CGF(cgh) {
673677
auto s_score = dpcpp_local_acc_t<scalar_t>(wg_size, cgh);
@@ -676,7 +680,7 @@ void finalize(
676680
const int32_t wi_id = item.get_local_id(0);
677681
const int32_t wg_id = item.get_group(0);
678682
if (wi_id < beam_hyps_num_beams[wg_id]) {
679-
s_score[wi_id] = beam_hyps_normed_scores[wg_id * wg_size + wi_id];
683+
s_score[wi_id] = beam_hyps_normed_scores[wg_id * 2 * beam_size + wi_id];
680684
} else {
681685
s_score[wi_id] = std::numeric_limits<scalar_t>::lowest();
682686
}

csrc/gpu/aten/operators/Reduce.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ inline at::detail::Array<arg_t, out_vec_sz> group_reduce(
4141
int sg_lid = sg.get_local_linear_id();
4242
int sg_gid = sg.get_group_linear_id();
4343
int sg_range = sg.get_group_range()[0];
44+
// group reduce requests workgroup size is multiple of subgroup size
45+
SYCL_KERNEL_ASSERT(
46+
wg_size % sg_size == 0 && "unsupported workgroup size for group reduce");
4447

4548
for (int offset = 1; offset < sg_size; offset <<= 1) {
4649
#pragma unroll(out_vec_sz)

0 commit comments

Comments
 (0)