Skip to content

Commit 271a388

Browse files
authored
Fix FMHA BWD NaN issue when sequence length is odd (#370)
1 parent 18e0e04 commit 271a388

File tree

2 files changed

+30
-21
lines changed

2 files changed

+30
-21
lines changed

xla/service/gpu/xetla/sdp/fmha_backward.h

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,17 @@ class fmha_backward_t {
518518
bias_op(*rP, ctx.mem_desc_Bij.coord, bias_args);
519519
}
520520

521+
// Mask the logits of the QK matrix that exceed the sequence length of K by
522+
// setting them to -inf. This ensures that the attention scores for these
523+
// logits will be 0.
524+
using tile_mask = tile_mask_t<tile_P_t>;
525+
uint32_t sg_startT = ctx.startT + tile_offset_x;
526+
uint32_t remainT =
527+
std::max(static_cast<int>(args.uT) - static_cast<int>(sg_startT), 0);
528+
if (remainT < kSgBc) {
529+
tile_mask::padding_mask(*rP, remainT);
530+
}
531+
521532
subgroup::tile_broadcast_op<subgroup::tile_minus, tile_P_t>(*rP,
522533
l_load.reg);
523534
rP->reg = xetla_exp<accum_t>(rP->reg);
@@ -868,9 +879,9 @@ void fmha_backward_impl(sycl::queue& q, T* query, T* key, T* value, T* out,
868879
cgh.parallel_for<
869880
class FmhaBackwardDotDOO<fmha_policy, T, kUseBias, kIsDropout>>(
870881
NdRange0, [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL {
871-
sycl::nd_item<3> ei(item);
872-
fmha_bwd_dot_do_op_t fmha_bwd_dot_do_o_op;
873-
fmha_bwd_dot_do_o_op(ei, args);
882+
sycl::nd_item<3> ei(item);
883+
fmha_bwd_dot_do_op_t fmha_bwd_dot_do_o_op;
884+
fmha_bwd_dot_do_o_op(ei, args);
874885
});
875886
});
876887

@@ -882,9 +893,9 @@ void fmha_backward_impl(sycl::queue& q, T* query, T* key, T* value, T* out,
882893
cgh.parallel_for<
883894
class FmhaBackwardKernel<fmha_policy, T, kUseBias, kIsDropout>>(
884895
NdRange1, [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL {
885-
sycl::nd_item<3> ei(item);
886-
fmha_backward_op_t fmha_bwd_op;
887-
fmha_bwd_op(ei, args);
896+
sycl::nd_item<3> ei(item);
897+
fmha_backward_op_t fmha_bwd_op;
898+
fmha_bwd_op(ei, args);
888899
});
889900
});
890901

@@ -896,9 +907,9 @@ void fmha_backward_impl(sycl::queue& q, T* query, T* key, T* value, T* out,
896907
cgh.parallel_for<
897908
class FmhaBackwardConvertDQ<fmha_policy, T, kUseBias, kIsDropout>>(
898909
NdRange2, [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL {
899-
sycl::nd_item<3> ei(item);
900-
fmha_bwd_convert_dq_op_t fmha_bwd_convert_dq_op;
901-
fmha_bwd_convert_dq_op(ei, args);
910+
sycl::nd_item<3> ei(item);
911+
fmha_bwd_convert_dq_op_t fmha_bwd_convert_dq_op;
912+
fmha_bwd_convert_dq_op(ei, args);
902913
});
903914
});
904915
}

xla/service/gpu/xetla/sdp/fmha_utils.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -207,41 +207,39 @@ struct group_row_reduce_t {
207207
}
208208
};
209209

210+
// Set the default load type to block_2d because only the block_2d load/store
211+
// ensure boundary safety.
210212
template <typename scalar_t, typename tile_desc_t, typename mem_desc_t>
211213
void store_tile(subgroup::tile_t<scalar_t, tile_desc_t>* src, mem_desc_t dst) {
212-
using store_t = subgroup::mem_payload_t<
213-
mem_desc_t, tile_desc_t,
214-
subgroup::msg_type_v<tile_desc_t, mem_desc_t::space>, gpu_arch::Xe>;
214+
using store_t = subgroup::mem_payload_t<mem_desc_t, tile_desc_t,
215+
msg_type::block_2d, gpu_arch::Xe>;
215216
store_t store(dst);
216217
subgroup::tile_store(*src, store);
217218
}
218219

219220
template <typename scalar_t, typename tile_desc_t, typename mem_desc_t>
220221
void store_tile(subgroup::tile_t<scalar_t, tile_desc_t>* src, mem_desc_t dst,
221222
int32_t tile_offset_x, int32_t tile_offset_y) {
222-
using store_t = subgroup::mem_payload_t<
223-
mem_desc_t, tile_desc_t,
224-
subgroup::msg_type_v<tile_desc_t, mem_desc_t::space>, gpu_arch::Xe>;
223+
using store_t = subgroup::mem_payload_t<mem_desc_t, tile_desc_t,
224+
msg_type::block_2d, gpu_arch::Xe>;
225225
dst.update_coord(tile_offset_x, tile_offset_y);
226226
store_t store(dst);
227227
subgroup::tile_store(*src, store);
228228
}
229229

230230
template <typename scalar_t, typename tile_desc_t, typename mem_desc_t>
231231
void load_tile(subgroup::tile_t<scalar_t, tile_desc_t>* dst, mem_desc_t src) {
232-
using load_t = subgroup::mem_payload_t<
233-
mem_desc_t, tile_desc_t,
234-
subgroup::msg_type_v<tile_desc_t, mem_desc_t::space>, gpu_arch::Xe>;
232+
using load_t = subgroup::mem_payload_t<mem_desc_t, tile_desc_t,
233+
msg_type::block_2d, gpu_arch::Xe>;
235234
load_t load(src);
236235
subgroup::tile_load(*dst, load);
237236
}
238237

239238
template <typename scalar_t, typename tile_desc_t, typename mem_desc_t>
240239
void load_tile(subgroup::tile_t<scalar_t, tile_desc_t>* dst, mem_desc_t src,
241240
int32_t tile_offset_x, int32_t tile_offset_y) {
242-
using load_t = subgroup::mem_payload_t<
243-
mem_desc_t, tile_desc_t,
244-
subgroup::msg_type_v<tile_desc_t, mem_desc_t::space>, gpu_arch::Xe>;
241+
using load_t = subgroup::mem_payload_t<mem_desc_t, tile_desc_t,
242+
msg_type::block_2d, gpu_arch::Xe>;
245243
src.update_coord(tile_offset_x, tile_offset_y);
246244
load_t load(src);
247245
subgroup::tile_load(*dst, load);

0 commit comments

Comments
 (0)