
Commit b41bb09

Authored by rogerxfeng8 and WeiZhu
1. Fix NaN in alibi: the kernel assumes alibi has the same dtype as key/value, so passing an fp32 alibi that gets interpreted as fp16 produces NaN (#5068) (#5080)
2. Always use a 2D block load for bias, so the boundary check is applied

Co-authored-by: WeiZhu <[email protected]>
1 parent f5a6a2b · commit b41bb09
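For context, a minimal PyTorch sketch of the failure mode described above (an illustration of the dtype mismatch, not the XeTLA kernel itself): reinterpreting the raw bytes of an fp32 alibi tensor as fp16 yields meaningless values, which is effectively what the kernel did when the bias dtype did not match key/value.

import torch

# Hypothetical repro of the mismatch: an fp32 alibi bias whose bytes are
# reinterpreted as fp16, the way a kernel would read it if it assumes the
# bias shares the key/value dtype.
alibi_f32 = torch.linspace(-1.0, 0.0, steps=8, dtype=torch.float32)

# Byte-level reinterpretation (not a value cast): each fp32 element turns into
# two unrelated fp16 values, so the attention bias becomes garbage and the
# downstream softmax can produce NaN.
print(alibi_f32.view(torch.float16))

# The fix: hand the kernel a bias that is actually stored in the key/value
# dtype, i.e. a real cast rather than a byte reinterpretation.
print(alibi_f32.to(torch.float16))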

File tree

3 files changed: +12 −11 lines

csrc/gpu/aten/operators/xetla/kernels/SDP/fmha_forward.hpp

Lines changed: 6 additions & 5 deletions
@@ -396,15 +396,16 @@ class fmha_forward_t {
       // B, N, 1, T
       // gid * T + startT
       if constexpr (kUseAlibi && !kVarlen) {
-        int32_t batch_start = gid * args.uAT;
-        int32_t start_x = batch_start + startT;
-        uint32_t end_x = startT + kBc;
+        int32_t start_x = startT;
+        uint32_t end_x = start_x + kBc;
         uint32_t boundary_x = args.uT;
         end_x = end_x > boundary_x ? boundary_x : end_x;
-        end_x += batch_start;
+
+        int32_t start_y = gid;
+        uint32_t end_y = start_y + 1;

         mem_desc_Ai.init(
-            args.A_ptr, {end_x, 1, args.uAT * args.uN * args.uB}, {start_x, 0});
+            args.A_ptr, {end_x, end_y, args.uAT}, {start_x, start_y});
       }

       // B, N or N
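A rough NumPy analogy of the addressing change above, assuming the alibi buffer is laid out row-major with one row of uAT elements per (batch, head) pair; the names mirror the kernel variables but this is not the XeTLA memory descriptor itself. The rewritten code keeps the x range inside a single row (clamped to the valid length uT) and selects the row with y = gid, so the 2D descriptor's boundary check applies per row rather than relying on a batch offset folded into a flat index.

import numpy as np

# Made-up sizes; uT <= uAT. gid picks the (batch, head) row, startT/kBc are the
# tile's starting column and width, uT is the valid sequence length.
uB, uN, uAT, uT = 2, 4, 128, 100
gid, startT, kBc = 5, 96, 32
alibi = np.random.rand(uB * uN, uAT).astype(np.float32)

# New scheme: x stays within one row and is clamped to uT; y selects the row.
start_x, end_x = startT, min(startT + kBc, uT)
start_y, end_y = gid, gid + 1
tile_2d = alibi[start_y:end_y, start_x:end_x]

# Old scheme: a flat 1-row view with the row offset folded into x.
flat = alibi.reshape(1, -1)
batch_start = gid * uAT
tile_flat = flat[0:1, batch_start + start_x : batch_start + end_x]

# Both pick the same elements here; the 2D form is what lets the hardware
# descriptor bound the access to the row.
assert np.array_equal(tile_2d, tile_flat)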

csrc/gpu/aten/operators/xetla/kernels/include/subgroup/tile/impl/tile_op_functor.hpp

Lines changed: 1 addition & 1 deletion
@@ -609,7 +609,7 @@ struct bias_add_op_t<
     using bias_payload_t = mem_payload_t<
         mem_desc_bias_t,
         bias_tile_desc_t,
-        msg_type_v<bias_tile_desc_t, mem_desc_bias_t>,
+        msg_type::block_2d,
         arch_tag>;
     coord_t bias_coord(coord.x, 0);
     mem_desc_bias_t mem_desc_bias(args.base, args.shape, bias_coord);
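As a loose analogy for the change above (a sketch, not the XeTLA payload implementation): a block-2D load carries the surface width with it, so columns beyond the valid bias width are handled by the hardware boundary check instead of being read out of bounds. The NumPy version below only emulates that clamping behavior.

import numpy as np

# Emulate a boundary-checked 2D block load of a bias tile: columns past the
# valid width are zero-filled instead of being read past the end of the row.
def load_bias_tile(bias_row, col0, tile_w, valid_w):
    tile = np.zeros(tile_w, dtype=bias_row.dtype)
    hi = min(col0 + tile_w, valid_w)
    if col0 < hi:
        tile[: hi - col0] = bias_row[col0:hi]
    return tile

bias_row = np.arange(100, dtype=np.float32)  # valid width 100
print(load_bias_tile(bias_row, col0=96, tile_w=16, valid_w=100))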

intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/XPUAttentionfp16.py

Lines changed: 5 additions & 5 deletions
@@ -177,7 +177,7 @@ def all_reduce_if_necessary(self, reduce_target):
             dist.all_reduce(reduce_target, group=self.tp_group)
         return

-    def get_blocked_alibi(self, alibi, seq_len):
+    def get_blocked_alibi(self, alibi, seq_len, dtype):
         if self.layer_idx == 0:
             cache_len = (
                 self.max_position
@@ -190,7 +190,7 @@ def get_blocked_alibi(self, alibi, seq_len):
                 cache_len,
             ]  # [beam*num_head, q_len, kv_len]
             IPEXAttention.blocked_alibi = torch.empty(
-                shape, device=alibi.device, dtype=alibi.dtype
+                shape, device=alibi.device, dtype=dtype
             )
             kv_len = alibi.shape[2]
             IPEXAttention.blocked_alibi[:, :, 0:kv_len] = alibi
@@ -228,13 +228,14 @@ def sdp(self, query, key, value, past_key_value, attention_mask, head_mask, alib

         # if attention_mask is not None:
         #     attention_mask = self.get_blocked_attn_mask(attention_mask)
+        # use key/value's data type as alibi's data type
         if alibi is not None:
             if isinstance(past_key_value, IPEXStaticCache):
                 alibi = self.get_blocked_alibi(
-                    alibi, past_key_value.get_seq_length() + key.size(2)
+                    alibi, past_key_value.get_seq_length() + key.size(2), key.dtype
                 )
             else:
-                alibi = self.get_blocked_alibi(alibi, key.size(2))
+                alibi = self.get_blocked_alibi(alibi, key.size(2), key.dtype)
         if (
             self.beam_idx is not None
             and query.size(-2) == 1
@@ -304,7 +305,6 @@ def sdp(self, query, key, value, past_key_value, attention_mask, head_mask, alib
         if not self.is_beam_search() and query.size(-2) == 1:
             key = key.permute(2, 0, 1, 3).contiguous().permute(1, 2, 0, 3)
             value = value.permute(2, 0, 1, 3).contiguous().permute(1, 2, 0, 3)
-
         attention_output = torch.xpu.IpexSDP(
             query,
             key,
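A minimal standalone sketch of the fixed allocation path, following the shapes in the diff; the real method also takes seq_len and caches the buffer on layer 0 with torch.empty, which is omitted here, and cache_len stands in for the self.max_position-derived length. The key change is that the buffer is allocated in the caller-supplied key/value dtype rather than alibi.dtype, so the fused kernel reads bytes in the dtype it expects.

import torch

def get_blocked_alibi(alibi, dtype, cache_len):
    shape = [alibi.shape[0], alibi.shape[1], cache_len]  # [beam*num_head, q_len, kv_len]
    blocked = torch.zeros(shape, device=alibi.device, dtype=dtype)
    kv_len = alibi.shape[2]
    blocked[:, :, 0:kv_len] = alibi  # implicit cast from the fp32 slopes to the kernel dtype
    return blocked

# Callers now pass key.dtype so the bias matches what the fused SDP kernel reads.
alibi = torch.randn(8, 1, 16, dtype=torch.float32)
blocked = get_blocked_alibi(alibi, dtype=torch.float16, cache_len=64)
print(blocked.dtype, blocked.shape)  # torch.float16 torch.Size([8, 1, 64])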
