
Commit 779e174

resolve the attn_mask nan issue
1 parent: 47abe2c · commit: 779e174
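
Background, not part of the diff: the commit message does not state the root cause, but a common way an additive attn_mask produces NaN is a query row that is masked at every position, since softmax over a row of all -inf is undefined. A minimal PyTorch illustration (hypothetical, not taken from this repo):

import torch

# A fully masked row in an additive attn_mask turns every logit into -inf,
# and softmax over a row of -inf is 0/0, i.e. NaN.
scores = torch.randn(1, 1, 2, 4)               # (batch, heads, L, S) attention logits
attn_mask = torch.full((2, 4), float("-inf"))  # start with every key masked
attn_mask[0, :2] = 0.0                         # row 0 keeps two keys; row 1 stays fully masked

probs = torch.softmax(scores + attn_mask, dim=-1)
print(probs[0, 0, 0])  # a valid distribution over the two unmasked keys
print(probs[0, 0, 1])  # tensor([nan, nan, nan, nan])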

File tree

2 files changed: +49, -44 lines


tools/llm/test_trt_sdpa.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ def forward(self, query, key, value, attn_mask):
             enable_flash=False,
             enable_math=False,
             enable_mem_efficient=True,
+            enable_cudnn=False,
         ):
             return torch.nn.functional.scaled_dot_product_attention(
                 query, key, value, attn_mask, 0.0, False, scale=0.0625
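
The hunk only shows the keyword arguments; the enclosing context manager is outside the diff. Assuming it is the torch.backends.cuda.sdp_kernel backend selector (or the newer torch.nn.attention.sdpa_kernel), the test pins SDPA to the memory-efficient backend and, with the added flag, now also excludes the cuDNN backend. A standalone sketch under that assumption, with made-up tensor shapes (only scale=0.0625 is taken from the diff):

import torch

# Sketch only: shapes, dtypes, and the sdp_kernel context manager are assumptions.
q = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
attn_mask = torch.zeros(16, 16, device="cuda", dtype=torch.float16)  # additive bias, broadcast over batch/heads

with torch.backends.cuda.sdp_kernel(
    enable_flash=False,
    enable_math=False,
    enable_mem_efficient=True,
    enable_cudnn=False,  # the added flag: leave only the mem-efficient backend enabled
):
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask, 0.0, False, scale=0.0625
    )
print(out.shape)  # torch.Size([1, 8, 16, 64])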

tools/llm/torchtrt_ext/sdpa_converter.py

Lines changed: 48 additions & 44 deletions
@@ -162,11 +162,7 @@ def scaled_dot_product_attention(
     if S < 0:
         S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
     # generate the mask tensor
-    if is_causal:
-        tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
-    else:
-        # TODO: lan to figure out why attn_mask passed in from transformers is not working
-        # tried both 2d and 4d, but both are not working
+    if not is_causal:
         assert len(attn_mask.shape) in [
             2,
             4,
@@ -183,48 +179,56 @@ def scaled_dot_product_attention(
             attn_mask = impl.squeeze.squeeze(
                 ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1)
             )
-            tril_tensor = attn_mask
-
-    # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this
-    attn_bias_via_where = True
-    if attn_bias_via_where:
-        attn_bias = impl.condition.where(
-            ctx,
-            target,
-            source_ir,
-            name + "_where",
-            torch.tensor(0.0, dtype=torch.float32).cuda(),
-            torch.tensor(-float("inf"), dtype=torch.float32).cuda(),
-            tril_tensor,
-        )
+        attn_bias = attn_mask
     else:
-        temp_mask = impl.unary.logical_not(
-            ctx, target, source_ir, name + "_logical_not", tril_tensor
-        )
+        tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
+        # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this
+        attn_bias_via_where = True
+        if attn_bias_via_where:
+            attn_bias = impl.condition.where(
+                ctx,
+                target,
+                source_ir,
+                name + "_where",
+                torch.tensor(0.0, dtype=torch.float32).cuda(),
+                torch.tensor(-float("inf"), dtype=torch.float32).cuda(),
+                tril_tensor,
+            )
+        else:
+            temp_mask = impl.unary.logical_not(
+                ctx, target, source_ir, name + "_logical_not", tril_tensor
+            )
 
-        # This need_mask determines if we want to use the causal mask or not
-        # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
-        # So need_mask will be all False values in this case.
-        # TODO: Implement more general case where L != 1 and S != L
-        need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
-        temp_mask = impl.elementwise.logical_and(
-            ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
-        )
-        temp_mask_casted = cast_trt_tensor(
-            ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
-        )
+            # This need_mask determines if we want to use the causal mask or not
+            # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
+            # So need_mask will be all False values in this case.
+            # TODO: Implement more general case where L != 1 and S != L
+            need_mask = impl.elementwise.eq(
+                ctx, target, source_ir, name + "_eq", L, S
+            )
+            temp_mask = impl.elementwise.logical_and(
+                ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
+            )
+            temp_mask_casted = cast_trt_tensor(
+                ctx,
+                temp_mask,
+                query_dtype,
+                name + "_casted_bool",
+                target,
+                source_ir,
+            )
 
-        one_minus_temp_mask = impl.elementwise.sub(
-            ctx,
-            target,
-            source_ir,
-            name + "_one_minus_temp_mask",
-            1.0,
-            temp_mask_casted,
-        )
-        attn_bias = impl.unary.log(
-            ctx, target, source_ir, name + "_log", one_minus_temp_mask
-        )
+            one_minus_temp_mask = impl.elementwise.sub(
+                ctx,
+                target,
+                source_ir,
+                name + "_one_minus_temp_mask",
+                1.0,
+                temp_mask_casted,
+            )
+            attn_bias = impl.unary.log(
+                ctx, target, source_ir, name + "_log", one_minus_temp_mask
+            )
 
     scaled_add_attn_bias = impl.elementwise.add(
         ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
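
After this change, an explicitly passed attn_mask (the non-causal path) is used directly as the additive bias, and the tril/where machinery is built only for the causal path. Ignoring the need_mask gating for the KV-cache case (L != S), the two bias constructions kept behind the attn_bias_via_where switch produce the same values on a boolean mask, as this plain-PyTorch sketch (not TensorRT converter code) shows:

import torch

# Plain-PyTorch sketch of the two bias constructions in the causal branch.
L = S = 4
tril_tensor = torch.ones(L, S, dtype=torch.bool).tril()  # True where attention is allowed

# where-path (the branch this commit takes): 0 where allowed, -inf where masked.
bias_where = torch.where(
    tril_tensor, torch.tensor(0.0), torch.tensor(float("-inf"))
)

# logical_not/sub/log path: log(1 - not(mask)) -> log(1) = 0 allowed, log(0) = -inf masked.
bias_log = torch.log(1.0 - (~tril_tensor).to(torch.float32))

assert torch.equal(bias_where, bias_log)
print(bias_where)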
