Commit 066452f (parent: a93266a)

add sliding window support for Gemma3

2 files changed: +66 additions, -3 deletions

tools/llm/torchtrt_ext/register_sdpa.py

Lines changed: 4 additions & 1 deletion
@@ -1,6 +1,7 @@
 import copy
 import logging
 import operator
+from re import I
 from typing import Callable, Sequence, Tuple
 
 import torch
@@ -89,7 +90,9 @@ def replace_variants_of_sdpa(
            logger.warning(
                f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations."
            )
-            modified_input_args = (query, key, value, None, dropout_p, True)
+            # TODO: lan to figure out why is_causal is always False in google/gemma-3-1b-it, as in the config file it should be every 5 sliding window layer followed by a full attention layer
+            # also to figure out why the attn_mask passed in from transformers is not working
+            modified_input_args = (query, key, value, None, dropout_p, is_causal)
            # Create a new node with torch.nn.functional.scaled_dot_product_attention
            # The input args is (query, key, value, is_causal). kwargs has scale
            with gm.graph.inserting_after(node):
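
The TODO above can be sanity-checked outside the lowering pass. The sketch below is not part of the commit; it only reads the Hugging Face config for google/gemma-3-1b-it to see how sliding-window and full-attention layers are declared. The field names (sliding_window, layer_types, sliding_window_pattern) vary across transformers versions, so they are read defensively and may not all exist.

# Illustrative only: inspect how the Gemma3 config declares its attention layer pattern.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("google/gemma-3-1b-it")
print("sliding_window:", getattr(cfg, "sliding_window", None))
# Depending on the transformers version, the pattern is exposed either as a
# per-layer list of attention types or as an "every N-th layer is full attention" value.
print("layer_types:", getattr(cfg, "layer_types", None))
print("sliding_window_pattern:", getattr(cfg, "sliding_window_pattern", None))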

tools/llm/torchtrt_ext/sdpa_converter.py

Lines changed: 62 additions & 2 deletions
@@ -27,7 +27,53 @@ def tril(
     name: str,
     row: TRTTensor,
     col: TRTTensor,
+    sliding_window_size: Optional[int] = None,
 ) -> TRTTensor:
+
+    row_arange_tensor = impl.arange.arange(
+        ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1
+    )
+    col_arange_tensor = impl.arange.arange(
+        ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1
+    )
+    row_arange_tensor = impl.unsqueeze.unsqueeze(
+        ctx, target, source_ir, name + "_unsqueeze_row", row_arange_tensor, -1
+    )
+    col_arange_tensor = impl.unsqueeze.unsqueeze(
+        ctx, target, source_ir, name + "_unsqueeze_col", col_arange_tensor, 0
+    )
+    # sub will return the following mask tensor:
+    # [[0, -1, -2, -3],
+    #  [1,  0, -1, -2],
+    #  [2,  1,  0, -1],
+    #  [3,  2,  1,  0]]
+    mask = impl.elementwise.sub(
+        ctx, target, source_ir, name + "_sub", row_arange_tensor, col_arange_tensor
+    )
+    ge_0_mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge_0", mask, 0.0)
+    if sliding_window_size is None:
+        # return the following lower triangular mask includes the main diagonal:
+        # 0 ■ ⬚ ⬚ ⬚ ⬚    tensor([[[[ True, False, False, False, False],
+        # 1 ■ ■ ⬚ ⬚ ⬚              [ True,  True, False, False, False],
+        # 2 ■ ■ ■ ⬚ ⬚              [ True,  True,  True, False, False],
+        # 3 ■ ■ ■ ■ ⬚              [ True,  True,  True,  True, False],
+        # 4 ■ ■ ■ ■ ■              [ True,  True,  True,  True,  True]]]])
+        return ge_0_mask
+
+    lt_window_mask = impl.elementwise.lt(
+        ctx, target, source_ir, name + "_lt_window_size", mask, sliding_window_size
+    )
+    mask = impl.elementwise.logical_and(
+        ctx, target, source_ir, name + "_logical_and", ge_0_mask, lt_window_mask
+    )
+    # return the following mask if sliding_window_size is 3:
+    # 0 ■ ⬚ ⬚ ⬚ ⬚    tensor([[[[ True, False, False, False, False],
+    # 1 ■ ■ ⬚ ⬚ ⬚              [ True,  True, False, False, False],
+    # 2 ■ ■ ■ ⬚ ⬚              [ True,  True,  True, False, False],
+    # 3 ⬚ ■ ■ ■ ⬚              [False,  True,  True,  True, False],
+    # 4 ⬚ ⬚ ■ ■ ■              [False, False,  True,  True,  True]]]])
+    return mask
+
     row_arange_tensor = impl.arange.arange(
         ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1
     )
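
To make the elementwise chain above easier to follow, here is a plain-PyTorch sketch (illustration only, not part of the converter) that computes the same masks: the row/column arange difference gives i - j, the ">= 0" comparison keeps the causal lower triangle, and "< sliding_window_size" bands it.

import torch

def sliding_window_tril(row: int, col: int, sliding_window_size=None) -> torch.Tensor:
    # diff[i, j] = i - j, mirroring the arange/unsqueeze/sub chain in tril()
    diff = torch.arange(row).unsqueeze(-1) - torch.arange(col).unsqueeze(0)
    mask = diff >= 0  # causal: key position j must not exceed query position i
    if sliding_window_size is not None:
        # band: only the most recent sliding_window_size keys stay visible
        mask = mask & (diff < sliding_window_size)
    return mask

print(sliding_window_tril(5, 5))     # lower-triangular mask, as in the first diagram
print(sliding_window_tril(5, 5, 3))  # banded mask, as in the sliding_window_size = 3 diagram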
@@ -66,7 +112,7 @@ def scaled_dot_product_attention(
     # TODO: remove this once we have a better way to handle the causal mask
     scale = kwargs.get("scale", None)
     source_ir = SourceIR.ATEN
-    is_causal = True
+
     # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     use_fp32_acc = kwargs.get("use_fp32_acc", False)
     query_dtype = query.dtype
@@ -136,7 +182,21 @@ def scaled_dot_product_attention(
     S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
 
     # generate the mask tensor
-    tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
+    if is_causal:
+        tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
+    else:
+        # hard code the sliding window size to 512 for now
+        tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S, 512)
+        # TODO: lan to figure out why attn_mask passed in from transformers is not working
+        # tried both 2d and 4d, but both are not working, hence the following code is commented out
+        # assert len(attn_mask.shape) in [2, 4], f"attn_mask must be 2D or 4D, but got {attn_mask.shape=}"
+        # if len(attn_mask.shape) == 4:
+        #     if attn_mask.shape[0] != 1:
+        #         attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 0, 0, 1, 1)
+        #     if attn_mask.shape[1] != 1:
+        #         attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 1, 0, 1, 1)
+        #     attn_mask = impl.squeeze.squeeze(ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1))
+        # tril_tensor = attn_mask
 
     temp_mask = impl.unary.logical_not(
         ctx, target, source_ir, name + "_logical_not", tril_tensor
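
End to end, the converter negates the boolean tril mask (the temp_mask = logical_not(tril_tensor) context line above) and applies it as an attention bias in the code that follows this hunk. Below is a minimal reference sketch of that semantics in plain PyTorch, assuming the standard "-inf before softmax" formulation; a small window of 3 stands in for the hard-coded 512 used for Gemma3.

import torch
import torch.nn.functional as F

def sdpa_with_sliding_window(query, key, value, sliding_window_size=None):
    L, S = query.shape[-2], key.shape[-2]
    diff = torch.arange(L).unsqueeze(-1) - torch.arange(S).unsqueeze(0)
    keep = diff >= 0  # causal part
    if sliding_window_size is not None:
        keep = keep & (diff < sliding_window_size)  # sliding-window band
    # positions outside the mask get -inf, so softmax assigns them zero weight
    attn_bias = torch.zeros(L, S).masked_fill(~keep, float("-inf"))
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)

q = k = v = torch.randn(1, 4, 16, 8)
# Without a window this reproduces the built-in causal behavior.
assert torch.allclose(
    sdpa_with_sliding_window(q, k, v),
    F.scaled_dot_product_attention(q, k, v, is_causal=True),
    atol=1e-5,
)
windowed = sdpa_with_sliding_window(q, k, v, sliding_window_size=3)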
