Commit a65f0f1

add test case
1 parent a58d17b commit a65f0f1

5 files changed: +136 -43 lines changed

tools/llm/run_llm.py

Lines changed: 1 addition & 1 deletion

@@ -116,7 +116,7 @@ def compile_torchtrt(model, input_ids, args):
         use_fp32_acc=use_fp32_acc,
         device=DEVICE,
         disable_tf32=True,
-        use_python_runtime=True,
+        use_python_runtime=False,
         debug=args.debug,
         offload_module_to_cpu=True,
         min_block_size=args.min_block_size,
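The only functional change here flips the deployment runtime: use_python_runtime=False hands the compiled engines to the C++ runtime wrapper instead of executing them through the Python runtime. A minimal, self-contained sketch of the flag in isolation (the toy module, shapes, and min_block_size are illustrative assumptions, not taken from run_llm.py):

# Sketch only: a tiny module compiled with the C++ runtime selected.
import torch
import torch_tensorrt


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0


x = torch.randn(2, 8).cuda()
ep = torch.export.export(Toy().eval().cuda(), (x,))

trt_module = torch_tensorrt.dynamo.compile(
    ep,
    inputs=(x,),
    min_block_size=1,
    # False selects the C++ runtime wrapper for the compiled engines;
    # True keeps execution in the Python runtime, which is often easier
    # to step through when debugging.
    use_python_runtime=False,
)
print(trt_module(x).shape)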

tools/llm/test_trt_sdpa.py (new file)

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
import torch
import torch_tensorrt
from torch.export import Dim
from torchtrt_ext import register_sdpa


class SimpleNetwork(torch.nn.Module):
    def __init__(self):
        super(SimpleNetwork, self).__init__()

    def forward(self, query, key, value, attn_mask):
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False,
            enable_math=False,
            enable_mem_efficient=True,
        ):
            return torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask, 0.0, False, scale=0.0625
            )


dtype = torch.float32

dyn_dim = Dim("dyn_dim", min=3, max=32)

query = torch.randn((1, 4, 13, 256), dtype=dtype).cuda()
key = torch.randn((1, 4, 13, 256), dtype=dtype).cuda()
value = torch.randn((1, 4, 13, 256), dtype=dtype).cuda()
attn_mask = torch.ones((13, 13), dtype=torch.bool).tril(diagonal=0).cuda()
inputs = (query, key, value, attn_mask)

model = SimpleNetwork().eval().cuda()
output_pyt = model(*inputs)
exp_program = torch.export.export(
    model,
    inputs,
    strict=False,
    dynamic_shapes={
        "query": {2: dyn_dim},
        "key": {2: dyn_dim},
        "value": {2: dyn_dim},
        "attn_mask": {0: dyn_dim, 1: dyn_dim},
    },
)
DEBUG_LOGGING_DIR = "./debug_logs"
with torch_tensorrt.dynamo.Debugger(
    "graphs",
    logging_dir=DEBUG_LOGGING_DIR,
    capture_fx_graph_after=["complex_graph_detection"],
    save_engine_profile=True,
    profile_format="trex",
    engine_builder_monitor=True,
):
    trt_model = torch_tensorrt.dynamo.compile(
        exp_program,
        inputs=inputs,
        enabled_precisions={dtype},
        min_block_size=1,
        cache_built_engines=False,
        reuse_cached_engines=False,
        truncate_double=True,
        use_python_runtime=False,
    )
outputs_trt = trt_model(*inputs)
breakpoint()
assert torch.allclose(output_pyt, outputs_trt, rtol=1e-2, atol=1e-2)

print("Done")

tools/llm/torchtrt_ext/register_sdpa.py

Lines changed: 3 additions & 3 deletions

@@ -89,9 +89,9 @@ def replace_variants_of_sdpa(
                 logger.warning(
                     f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations."
                 )
-            # TODO: lan to figure out why is_causal is always False in google/gemma-3-1b-it, as in the config file it should be every 5 sliding window layer followed by a full attention layer
-            # also to figure out why the attn_mask passed in from transformers is not working
-            modified_input_args = (query, key, value, None, dropout_p, is_causal)
+            # TODO: lan to figure out why the attn_mask passed in from transformers is not working
+            # modified_input_args = (query, key, value, None, dropout_p, True)
+            modified_input_args = (query, key, value, attn_mask, dropout_p, is_causal)
             # Create a new node with torch.nn.functional.scaled_dot_product_attention
             # The input args is (query, key, value, is_causal). kwargs has scale
             with gm.graph.inserting_after(node):
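The rewrite now forwards the original attn_mask into the replacement scaled_dot_product_attention node instead of dropping it. As a plain eager-mode illustration of why this is safe for the causal case (the snippet below is illustrative and not part of the commit): with a lower-triangular boolean mask and matching query/key lengths, the explicit-mask path reproduces is_causal=True.

# Illustrative eager-mode check (not from the commit): an explicit
# lower-triangular boolean attn_mask matches is_causal=True when L == S.
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 13, 256)
k = torch.randn(1, 4, 13, 256)
v = torch.randn(1, 4, 13, 256)
causal_mask = torch.ones(13, 13, dtype=torch.bool).tril()

out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out_masked, out_causal, atol=1e-5)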

tools/llm/torchtrt_ext/sdpa_converter.py

Lines changed: 64 additions & 38 deletions

@@ -161,51 +161,77 @@ def scaled_dot_product_attention(
     L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2)
     if S < 0:
         S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
-
     # generate the mask tensor
     if is_causal:
         tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
     else:
-        # hard code the sliding window size to 512 for now
-        tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S, 512)
         # TODO: lan to figure out why attn_mask passed in from transformers is not working
-        # tried both 2d and 4d, but both are not working, hence the following code is commented out
-        # assert len(attn_mask.shape) in [2, 4], f"attn_mask must be 2D or 4D, but got {attn_mask.shape=}"
-        # if len(attn_mask.shape) == 4:
-        #     if attn_mask.shape[0] != 1:
-        #         attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 0, 0, 1, 1)
-        #     if attn_mask.shape[1] != 1:
-        #         attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 1, 0, 1, 1)
-        #     attn_mask = impl.squeeze.squeeze(ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1))
-        # tril_tensor = attn_mask
-
-    temp_mask = impl.unary.logical_not(
-        ctx, target, source_ir, name + "_logical_not", tril_tensor
-    )
+        # tried both 2d and 4d, but both are not working
+        assert len(attn_mask.shape) in [
+            2,
+            4,
+        ], f"attn_mask must be 2D or 4D, but got {attn_mask.shape=}"
+        if len(attn_mask.shape) == 4:
+            if attn_mask.shape[0] != 1:
+                attn_mask = impl.slice.slice_op(
+                    ctx, target, source_ir, name + "_slice", attn_mask, 0, 0, 1, 1
+                )
+            if attn_mask.shape[1] != 1:
+                attn_mask = impl.slice.slice_op(
+                    ctx, target, source_ir, name + "_slice", attn_mask, 1, 0, 1, 1
+                )
+            attn_mask = impl.squeeze.squeeze(
+                ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1)
+            )
+        tril_tensor = attn_mask

-    # This need_mask determines if we want to use the causal mask or not
-    # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
-    # So need_mask will be all False values in this case.
-    # TODO: Implement more general case where L != 1 and S != L
-    need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
-    temp_mask = impl.elementwise.logical_and(
-        ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
-    )
-    temp_mask_casted = cast_trt_tensor(
-        ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
-    )
+    # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this
+    attn_bias_via_where = True
+    if attn_bias_via_where:
+        attn_bias = impl.condition.where(
+            ctx,
+            target,
+            source_ir,
+            name + "_where",
+            torch.tensor(0.0, dtype=torch.float32).cuda(),
+            torch.tensor(-float("inf"), dtype=torch.float32).cuda(),
+            tril_tensor,
+        )
+    else:
+        temp_mask = impl.unary.logical_not(
+            ctx, target, source_ir, name + "_logical_not", tril_tensor
+        )
+        temp_mask = cast_trt_tensor(
+            ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir
+        )
+        temp_mask = impl.elementwise.mul(
+            ctx, target, source_ir, name + "_mul_-inf", temp_mask, float("-inf")
+        )
+        attn_bias = temp_mask

-    one_minus_temp_mask = impl.elementwise.sub(
-        ctx,
-        target,
-        source_ir,
-        name + "_one_minus_temp_mask",
-        1.0,
-        temp_mask_casted,
-    )
-    attn_bias = impl.unary.log(
-        ctx, target, source_ir, name + "_log", one_minus_temp_mask
-    )
+        # This need_mask determines if we want to use the causal mask or not
+        # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
+        # So need_mask will be all False values in this case.
+        # TODO: Implement more general case where L != 1 and S != L
+        need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
+        temp_mask = impl.elementwise.logical_and(
+            ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
+        )
+        temp_mask_casted = cast_trt_tensor(
+            ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
+        )
+
+        one_minus_temp_mask = impl.elementwise.sub(
+            ctx,
+            target,
+            source_ir,
+            name + "_one_minus_temp_mask",
+            1.0,
+            temp_mask_casted,
+        )
+        attn_bias = impl.unary.log(
+            ctx, target, source_ir, name + "_log", one_minus_temp_mask
+        )

     scaled_add_attn_bias = impl.elementwise.add(
         ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias

tools/llm/utils.py

Lines changed: 0 additions & 1 deletion

@@ -179,7 +179,6 @@ def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_tok
     num_tokens_generated = 0
     kv_cache = get_zeroed_dynamic_cache_inputs(model)
     last_position_id = position_ids[-1, -1].item()
-    breakpoint()
     while num_tokens_generated < num_output_tokens:
         is_generate = False if input_seq.shape[1] > 1 else True
         position_ids = (
