
Commit d7071ba

Remove linear lowering pass and converter (#3323)
1 parent 283a983 commit d7071ba

7 files changed: 0 additions, 362 deletions

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 0 additions & 20 deletions
@@ -2509,26 +2509,6 @@ def aten_ops_convolution(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.linear.default, supports_dynamic_shapes=True)
-@dynamo_tensorrt_converter(torch.ops.aten.linear, supports_dynamic_shapes=True)
-def aten_ops_linear(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.linear.linear(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        input=args[0],
-        weight=args[1],
-        bias=args_bounds_check(args, 2, None),
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten._cdist_forward.default)
 def aten_ops_cdist_forward(
     ctx: ConversionContext,
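
With the dedicated converter gone, torch.ops.aten.linear reaches the conversion stage in decomposed form; the deleted tests below name aten.permute and aten.addmm as the decomposition products. A minimal sketch of that equivalence in plain PyTorch (shapes borrowed from the deleted tests):

import torch

x = torch.rand(3, 32)   # input
w = torch.rand(64, 32)  # weight, stored as (out_features, in_features)
b = torch.rand(64)      # bias

linear_out = torch.ops.aten.linear.default(x, w, b)
# The decomposition the removal now relies on: addmm over a transposed weight.
decomposed = torch.ops.aten.addmm.default(
    b, x, torch.ops.aten.permute.default(w, [1, 0])
)
torch.testing.assert_close(linear_out, decomposed)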

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -13,7 +13,6 @@
     embedding,
     full,
     grid,
-    linear,
     matmul,
     normalization,
     pad,

py/torch_tensorrt/dynamo/conversion/impl/linear.py

Lines changed: 0 additions & 54 deletions
This file was deleted.
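
The deleted implementation is not reproduced in this diff. For orientation, here is a hedged sketch of how a linear layer is typically built with the TensorRT Python API: a matrix multiply with the weight transposed inside the layer, followed by a broadcast elementwise sum for the bias. This is an illustration, not the removed impl.linear.linear, and it assumes a TensorRT version that still accepts the EXPLICIT_BATCH creation flag:

import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# Shapes borrowed from the deleted tests: (3, 32) input, (64, 32) weight.
x = network.add_input("x", trt.float32, (3, 32))
w = network.add_constant((64, 32), np.ones((64, 32), dtype=np.float32))
b = network.add_constant((1, 64), np.ones((1, 64), dtype=np.float32))

# out = x @ w.T, with the transpose folded into the matmul layer.
mm = network.add_matrix_multiply(
    x, trt.MatrixOperation.NONE, w.get_output(0), trt.MatrixOperation.TRANSPOSE
)
# out = out + b, via a broadcast elementwise sum.
out = network.add_elementwise(
    mm.get_output(0), b.get_output(0), trt.ElementWiseOperation.SUM
)
network.mark_output(out.get_output(0))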

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 0 additions & 2 deletions
@@ -7,7 +7,6 @@
 from .accumulate_fp32_matmul import accumulate_fp32_matmul
 from .constant_folding import constant_fold
 from .fuse_prims_broadcast import fuse_prims_broadcast
-from .lower_linear import lower_linear
 from .pass_manager import DynamoPassManager
 from .remove_assert_scalar import remove_assert_scalar
 from .remove_detach import remove_detach
@@ -22,7 +21,6 @@
     remove_input_alias_fixing_clones,
     constant_fold,
     repair_input_as_output,
-    lower_linear,
     fuse_prims_broadcast,
     replace_max_pool_with_indices,
     replace_full_like_with_full,
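
For context, the removed lower_linear pass re-fused the decomposed pattern back into a single aten.linear node ahead of conversion; the deleted test at the bottom of this diff checks exactly that (expected aten.linear, unexpected aten.permute and aten.addmm). A minimal sketch of that style of FX rewrite, with a hypothetical helper name and assuming the plain addmm(bias, x, permute(w, [1, 0])) pattern:

import torch

def fuse_addmm_into_linear(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Collapse addmm(bias, x, permute(w, [1, 0])) into linear(x, w, bias).
    for node in list(gm.graph.nodes):
        if node.target != torch.ops.aten.addmm.default:
            continue
        bias, inp, mat2 = node.args  # sketch only: ignores beta/alpha kwargs
        if not (
            isinstance(mat2, torch.fx.Node)
            and mat2.target == torch.ops.aten.permute.default
            and list(mat2.args[1]) == [1, 0]
        ):
            continue
        weight = mat2.args[0]
        with gm.graph.inserting_before(node):
            linear = gm.graph.call_function(
                torch.ops.aten.linear.default, args=(inp, weight, bias)
            )
        node.replace_all_uses_with(linear)
        gm.graph.erase_node(node)
        if len(mat2.users) == 0:
            gm.graph.erase_node(mat2)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm

With both the pass and the converter removed, the decomposed permute/addmm form now flows on to the remaining converters (matmul stays imported in impl/__init__.py above) instead of being re-fused.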

py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

tests/py/dynamo/conversion/test_linear_aten.py

Lines changed: 0 additions & 131 deletions
This file was deleted.
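
These tests exercised the converter directly, so they go away with it. A hedged end-to-end parity check in the same spirit (a sketch, not the deleted file): compile a linear layer through the torch_compile backend, the same entry point the lowering tests below use, and compare against eager PyTorch. Assumes a CUDA device; tolerances are loose because TRT kernels are not bit-exact:

import torch
import torch_tensorrt

model = torch.nn.Linear(32, 64).cuda().eval()
inputs = [torch.rand((3, 32)).cuda()]

# Same compile entry point as the deleted lowering tests.
trt_model = torch_tensorrt.compile(model, "torch_compile", inputs, min_block_size=1)

torch.testing.assert_close(
    trt_model(*inputs).detach().cpu(),
    model(*inputs).detach().cpu(),
    rtol=1e-3,
    atol=1e-3,
)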

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 0 additions & 112 deletions
@@ -158,118 +158,6 @@ def forward(self, x):
         torch._dynamo.reset()
 
 
-class TestLowerLinear(TestCase):
-    @unittest.skip(
-        "This test has threshold failures. This is tracked at https://github.com/pytorch/TensorRT/issues/2715",
-    )
-    def test_lower_linear(self):
-        class Linear(torch.nn.Module):
-            def forward(self, input, weight, bias):
-                out = torch.ops.aten.linear.default(input, weight, bias)
-                return out
-
-        inputs = [
-            torch.rand((3, 32)).cuda(),
-            torch.rand((64, 32)).cuda(),
-            torch.rand((64,)).cuda(),
-        ]
-
-        fx_graph = torch.fx.symbolic_trace(Linear())
-        expected_ops = {torch.ops.aten.linear.default}
-        unexpected_ops = {
-            torch.ops.aten.permute.default,
-            torch.ops.aten.addmm.default,
-        }
-
-        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
-            fx_graph,
-            inputs,
-            expected_ops=expected_ops,
-            unexpected_ops=unexpected_ops,
-            min_block_size=1,
-        )
-
-        self.assertEqual(
-            len(unexpected_ops_seen),
-            0,
-            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
-        )
-
-        self.assertEqual(
-            len(expected_ops_unseen),
-            0,
-            f"The following expected ops were not encountered: {expected_ops_unseen}",
-        )
-        torch._dynamo.reset()
-
-        # Validate that the results between Torch and Torch-TRT are similar
-        optimized_model = torch_tensorrt.compile(
-            fx_graph,
-            "torch_compile",
-            inputs,
-            min_block_size=1,
-            pass_through_build_failures=True,
-        )
-        optimized_model_results = torch.cat(
-            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
-        )
-        torch_model_results = torch.cat(
-            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
-        )
-
-        max_diff = float(
-            torch.max(torch.abs(optimized_model_results - torch_model_results))
-        )
-
-        self.assertAlmostEqual(
-            max_diff,
-            0,
-            DECIMALS_OF_AGREEMENT,
-            msg=f"Linear TRT outputs don't match with the original model.",
-        )
-        torch._dynamo.reset()
-
-    def test_lower_linear_batch(self):
-        class Linear(torch.nn.Module):
-            def forward(self, input, weight, bias):
-                out = torch.ops.aten.linear.default(input, weight, bias)
-                return out
-
-        inputs = [
-            torch.rand((2, 2, 32)).cuda(),
-            torch.rand((64, 32)).cuda(),
-            torch.rand((64,)).cuda(),
-        ]
-
-        fx_graph = torch.fx.symbolic_trace(Linear())
-
-        # Validate that the results between Torch and Torch-TRT are similar
-        optimized_model = torch_tensorrt.compile(
-            fx_graph,
-            "torch_compile",
-            inputs,
-            min_block_size=1,
-            pass_through_build_failures=True,
-        )
-        optimized_model_results = torch.cat(
-            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
-        )
-        torch_model_results = torch.cat(
-            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
-        )
-
-        max_diff = float(
-            torch.max(torch.abs(optimized_model_results - torch_model_results))
-        )
-        self.assertAlmostEqual(
-            max_diff,
-            0,
-            DECIMALS_OF_AGREEMENT,
-            msg=f"Linear TRT outputs don't match with the original model.",
-        )
-        torch._dynamo.reset()
-
-
 class TestLowerViewToReshape(TestCase):
     def test_view_to_reshape(self):
         class ViewToReshape(torch.nn.Module):
