1 parent 066452f commit a58d17b
tools/llm/torchtrt_ext/register_sdpa.py
@@ -1,7 +1,6 @@
 import copy
 import logging
 import operator
-from re import I
 from typing import Callable, Sequence, Tuple
 
 import torch
tools/llm/torchtrt_ext/sdpa_converter.py
@@ -74,25 +74,6 @@ def tril(
     # 4 ⬚ ⬚ ■ ■ ■ [False, False, True, True,True]]]])
     return mask
 
-    row_arange_tensor = impl.arange.arange(
-        ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1
-    )
-    row_reshape_tensor = impl.shuffle.reshape(
-        ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1]
-    )
-
-    col_arange_tensor = impl.arange.arange(
-        ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1
-    )
-    col_reshape_tensor = impl.shuffle.reshape(
-        ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col]
-    )
-
-    mask = impl.elementwise.ge(
-        ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor
-    )
-    return mask
-
 
 @torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
     torch.nn.functional.scaled_dot_product_attention,
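
For reference, the block removed from `tril` was unreachable dead code: it sat after the function's `return mask` and rebuilt the same boolean mask by broadcasting row indices (shaped [row, 1]) against column indices (shaped [1, col]) with a `>=` comparison. A minimal PyTorch sketch of that pattern, for illustration only (the `tril_mask` helper name is hypothetical and not part of this commit):

import torch

def tril_mask(row: int, col: int) -> torch.Tensor:
    # Row indices as a [row, 1] column vector, column indices as a [1, col] row vector.
    row_idx = torch.arange(row).reshape(row, 1)
    col_idx = torch.arange(col).reshape(1, col)
    # Broadcasted comparison: mask[i, j] = (i >= j), a lower-triangular boolean mask,
    # equivalent to torch.tril(torch.ones(row, col, dtype=torch.bool)).
    return row_idx >= col_idx

For example, tril_mask(3, 3) yields [[True, False, False], [True, True, False], [True, True, True]].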