Skip to content

Commit a58d17b

Browse files
committed
test
1 parent 066452f commit a58d17b

File tree

2 files changed

+0
-20
lines changed

2 files changed

+0
-20
lines changed

tools/llm/torchtrt_ext/register_sdpa.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import copy
22
import logging
33
import operator
4-
from re import I
54
from typing import Callable, Sequence, Tuple
65

76
import torch

tools/llm/torchtrt_ext/sdpa_converter.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -74,25 +74,6 @@ def tril(
7474
# 4 ⬚ ⬚ ■ ■ ■ [False, False, True, True,True]]]])
7575
return mask
7676

77-
row_arange_tensor = impl.arange.arange(
78-
ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1
79-
)
80-
row_reshape_tensor = impl.shuffle.reshape(
81-
ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1]
82-
)
83-
84-
col_arange_tensor = impl.arange.arange(
85-
ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1
86-
)
87-
col_reshape_tensor = impl.shuffle.reshape(
88-
ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col]
89-
)
90-
91-
mask = impl.elementwise.ge(
92-
ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor
93-
)
94-
return mask
95-
9677

9778
@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
9879
torch.nn.functional.scaled_dot_product_attention,

0 commit comments

Comments
 (0)