@@ -201,6 +201,8 @@ void beam_search_topk_stage2(
201201 int64_t batch_size,
202202 int32_t num_wg_per_beam) {
203203 int slm_size = (sizeof (int ) + sizeof (scalar_t )) * MAX_K * num_wg_per_beam;
204+ int sub_wg_size = dpcppMaxSubGroupSize ();
205+ int wg_size = (num_wg_per_beam + sub_wg_size - 1 ) / sub_wg_size * sub_wg_size;
204206 auto & dpcpp_queue = dpcppGetCurrentQueue ();
205207 auto cgf = DPCPP_Q_CGF (cgh) {
206208 dpcpp_local_acc_t <char > shared (slm_size, cgh);
@@ -233,7 +235,7 @@ void beam_search_topk_stage2(
233235 TopK<scalar_t , MAX_K>,
234236 decltype (item),
235237 decltype (combine),
236- 1 >(item, num_wg_per_beam , shared, value, combine);
238+ 1 >(item, item. get_local_range ( 0 ) , shared, value, combine);
237239 TopK<scalar_t , MAX_K> total = value[0 ];
238240
239241 if (wi_id == 0 ) {
@@ -249,8 +251,8 @@ void beam_search_topk_stage2(
249251
250252 cgh.parallel_for (
251253 sycl::nd_range<1 >(
252- sycl::range<1 >(batch_size * beam_size * num_wg_per_beam ),
253- sycl::range<1 >(num_wg_per_beam )),
254+ sycl::range<1 >(batch_size * beam_size * wg_size ),
255+ sycl::range<1 >(wg_size )),
254256 kfn);
255257 };
256258 DPCPP_Q_SUBMIT (dpcpp_queue, cgf);
@@ -667,7 +669,9 @@ void finalize(
667669 const int64_t out_sentence_num,
668670 const int64_t pad_token_id) {
669671 auto & dpcpp_queue = dpcppGetCurrentQueue ();
670- int32_t wg_size = 2 * beam_size;
672+ int32_t sub_wg_size = dpcppMaxSubGroupSize ();
673+ int32_t wg_size =
674+ (2 * beam_size + sub_wg_size - 1 ) / sub_wg_size * sub_wg_size;
671675
672676 auto cgf = DPCPP_Q_CGF (cgh) {
673677 auto s_score = dpcpp_local_acc_t <scalar_t >(wg_size, cgh);
@@ -676,7 +680,7 @@ void finalize(
676680 const int32_t wi_id = item.get_local_id (0 );
677681 const int32_t wg_id = item.get_group (0 );
678682 if (wi_id < beam_hyps_num_beams[wg_id]) {
679- s_score[wi_id] = beam_hyps_normed_scores[wg_id * wg_size + wi_id];
683+ s_score[wi_id] = beam_hyps_normed_scores[wg_id * 2 * beam_size + wi_id];
680684 } else {
681685 s_score[wi_id] = std::numeric_limits<scalar_t >::lowest ();
682686 }
0 commit comments