
Commit da90d61

feat: a lowering pass to re-compose ops into aten.linear (#2411)
1 parent 7029e91 commit da90d61
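
The pass re-composes the permute + addmm sequence that aten export produces for nn.Linear back into a single aten.linear call. A minimal sketch of the identity the rewrite relies on (illustrative only, not part of the commit):

import torch

x = torch.randn(3, 32)        # (batch, in_features)
weight = torch.randn(64, 32)  # (out_features, in_features)
bias = torch.randn(64)

# Decomposed form matched by the pass: permute the weight, then addmm
decomposed = torch.ops.aten.addmm.default(
    bias, x, torch.ops.aten.permute.default(weight, [1, 0])
)
# Re-composed form emitted by the pass
recomposed = torch.ops.aten.linear.default(x, weight, bias)
assert torch.allclose(decomposed, recomposed, atol=1e-6)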

10 files changed (+292 / -117 lines)


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 28 additions & 0 deletions
@@ -1957,3 +1957,31 @@ def aten_ops_argmax(
        dim=args_bounds_check(args, 1),
        keep_dim=args_bounds_check(args, 2, False),
    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.addmm.default)  # type: ignore[misc]
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+        1: (np.ndarray, torch.Tensor, TRTTensor),
+        2: (np.ndarray, torch.Tensor, TRTTensor),
+    }
+)  # type: ignore[misc]
+def aten_ops_addmm(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.addmm.addmm(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        beta=kwargs.get("beta", 1),
+        alpha=kwargs.get("alpha", 1),
+    )

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -2,8 +2,9 @@

from . import (
    activation,
-    attention,
+    addmm,
    argmax,
+    attention,
    cast,
    cat,
    condition,

py/torch_tensorrt/dynamo/conversion/impl/addmm.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def addmm(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    mat1: Union[TRTTensor, torch.Tensor, np.ndarray],
+    mat2: Union[TRTTensor, torch.Tensor, np.ndarray],
+    *,
+    beta: Union[float, int],
+    alpha: Union[float, int],
+) -> TRTTensor:
+    mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2)
+    if alpha != 1:
+        mm = impl.elementwise.mul(
+            ctx, target, SourceIR.ATEN, f"{name}_mul_alpha", mm, alpha
+        )
+    if beta != 1:
+        input = impl.elementwise.mul(
+            ctx, target, SourceIR.ATEN, f"{name}_mul_beta", input, beta
+        )
+
+    return impl.elementwise.add(ctx, target, source_ir, f"{name}_add", input, mm)

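The converter above mirrors the addmm definition, out = beta * input + alpha * (mat1 @ mat2), and skips each scaling multiply when its factor is 1. A quick sanity sketch of that identity in plain PyTorch (not part of the commit):

import torch

inp, mat1, mat2 = torch.randn(2, 3), torch.randn(2, 4), torch.randn(4, 3)
beta, alpha = 3.0, 0.5

expected = torch.ops.aten.addmm.default(inp, mat1, mat2, beta=beta, alpha=alpha)
manual = beta * inp + alpha * (mat1 @ mat2)
assert torch.allclose(expected, manual, atol=1e-6)
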
py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 0 additions & 16 deletions
@@ -105,22 +105,6 @@ def alias_replacement(x: torch.Tensor) -> torch.Tensor:
    return x


-@register_torch_trt_decomposition(
-    torch.ops.aten.addmm, registry=TORCH_TRT_DECOMPOSITIONS
-)
-def addmm_replacement(
-    input_: torch.Tensor,
-    mat1: torch.Tensor,
-    mat2: torch.Tensor,
-    *,
-    beta: int = 1,
-    alpha: int = 1,
-) -> torch.Tensor:
-    return torch.add(
-        torch.mul(input_, beta), torch.mul(torch.matmul(mat1, mat2), alpha)
-    )
-
-
@register_torch_trt_decomposition(
    torch.ops.aten.reciprocal.default, registry=TORCH_TRT_DECOMPOSITIONS
)

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_efficient_attention import lower_efficient_attention
+from .lower_linear import lower_linear
from .pass_manager import DynamoPassManager
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
@@ -17,6 +18,7 @@
    constant_fold,
    repair_input_as_output,
    lower_efficient_attention,
+    lower_linear,
    fuse_prims_broadcast,
    replace_max_pool_with_indices,
]

py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py

Lines changed: 6 additions & 28 deletions
@@ -5,7 +5,6 @@
import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
-    get_tensor_placeholders,
)

logger = logging.getLogger(__name__)
@@ -36,34 +35,13 @@ def efficient_attention_replacement() -> (
):
    """Constructs the original and replacement functions for efficient attention"""

-    # Empty boilerplate function taking in three Tensors and returning one
-    def boilerplate(
-        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> torch.Tensor:
-        ...
-
-    # Trace boilerplate function and extract placeholder and output nodes
-    orig = torch.fx.symbolic_trace(boilerplate)
-    q, k, v = get_tensor_placeholders(orig)
-    output = [node for node in orig.graph.nodes if node.op == "output"][0]
-
-    # Graph types to replace are those which use the _scaled_dot_product_efficient_attention
-    # function and extract only the first element
-    with orig.graph.inserting_before(output):
-        att = orig.graph.call_function(
-            torch.ops.aten._scaled_dot_product_efficient_attention.default,
-            args=(q, k, v, None, False),
+    # Original graph
+    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+            q, k, v, None, False
        )
-        out = orig.graph.call_function(
-            operator.getitem,
-            args=(att, 0),
-        )
-
-        # Assign the output of the graph to be the single getitem output
-        output.args = (out,)
-
-    orig.graph.lint()
-    orig.recompile()
+        out = operator.getitem(outputs, 0)
+        return out

    # Replacement graph consists of the functional version of scaled_dot_product_attention
    def replacement(

py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+import logging
+from typing import Callable, Sequence, Tuple
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def lower_linear(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Replace aten.linear with an equivalent implementation which can be easily converted to TRT"""
+    orig, replacement = linear_replacement()
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after lowering linear:\n{gm.graph}")
+
+    return gm
+
+
+def linear_replacement() -> (
+    Tuple[
+        torch.fx.GraphModule,
+        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+    ]
+):
+    """Constructs the original and replacement functions for linear"""
+
+    # Original graph
+    def orig(
+        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+    ) -> torch.Tensor:
+        W_T = torch.ops.aten.permute.default(weight, [1, 0])
+        out = torch.ops.aten.addmm.default(bias, input, W_T)
+        return out
+
+    # Replacement graph
+    def replacement(
+        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+    ) -> torch.Tensor:
+        return torch.ops.aten.linear.default(input, weight, bias)
+
+    return orig, replacement
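
lower_linear hands these two callables to torch.fx.subgraph_rewriter.replace_pattern, which traces both and swaps every matching permute + addmm subgraph for a single aten.linear call. A standalone sketch of the same rewrite on a hand-written module (module and variable names here are made up for illustration; not part of the commit):

import torch
from torch.fx import symbolic_trace
from torch.fx.subgraph_rewriter import replace_pattern


class DecomposedLinear(torch.nn.Module):
    # Hand-written stand-in for the permute + addmm form aten export emits
    def forward(self, x, weight, bias):
        w_t = torch.ops.aten.permute.default(weight, [1, 0])
        return torch.ops.aten.addmm.default(bias, x, w_t)


def pattern(x, weight, bias):
    w_t = torch.ops.aten.permute.default(weight, [1, 0])
    return torch.ops.aten.addmm.default(bias, x, w_t)


def replacement(x, weight, bias):
    return torch.ops.aten.linear.default(x, weight, bias)


gm = symbolic_trace(DecomposedLinear())
matches = replace_pattern(gm, pattern, replacement)
print(len(matches))  # one match: the permute + addmm pair
print(gm.graph)      # the graph now calls torch.ops.aten.linear.default
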
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestAddmmConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((2, 2), (2, 3), (3, 2)),
+            ((4, 6), (4, 5), (5, 6)),
+            ((2, 1), (2, 3), (3, 1)),
+            ((4, 1), (4, 1), (1, 1)),
+            ((1, 2), (1, 3), (3, 2)),
+        ]
+    )
+    def test_addmm(self, input_shape, mat1_shape, mat2_shape):
+        class Addmm(nn.Module):
+            def forward(self, input, mat1, mat2):
+                return torch.ops.aten.addmm.default(input, mat1, mat2)
+
+        inputs = [
+            torch.randn(input_shape),
+            torch.randn(mat1_shape),
+            torch.randn(mat2_shape),
+        ]
+
+        self.run_test(
+            Addmm(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((2, 2), (2, 3), (3, 2), 1.0, 1.0),
+            ((4, 6), (4, 5), (5, 6), 1.2, 0.8),
+            ((2, 1), (2, 3), (3, 1), 3, 2),
+            ((4, 1), (4, 1), (1, 1), 1, 1),
+            ((1, 2), (1, 3), (3, 2), 2, 1.0),
+            ((1, 2), (1, 3), (3, 2), 1, 2.0),
+        ]
+    )
+    def test_addmm_scale(self, input_shape, mat1_shape, mat2_shape, beta, alpha):
+        class Addmm(nn.Module):
+            def forward(self, input, mat1, mat2):
+                return torch.ops.aten.addmm.default(
+                    input, mat1, mat2, beta=beta, alpha=alpha
+                )
+
+        inputs = [
+            torch.randn(input_shape),
+            torch.randn(mat1_shape),
+            torch.randn(mat2_shape),
+        ]
+
+        self.run_test(
+            Addmm(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 108 additions & 0 deletions
@@ -267,5 +267,113 @@ def forward(self, q, k, v):
        torch._dynamo.reset()


+class TestLowerLinear(TestCase):
+    def test_lower_linear(self):
+        class Linear(torch.nn.Module):
+            def forward(self, input, weight, bias):
+                out = torch.ops.aten.linear.default(input, weight, bias)
+                return out
+
+        inputs = [
+            torch.rand((3, 32)).cuda(),
+            torch.rand((64, 32)).cuda(),
+            torch.rand((64,)).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(Linear())
+        expected_ops = {torch.ops.aten.linear.default}
+        unexpected_ops = {
+            torch.ops.aten.permute.default,
+            torch.ops.aten.addmm.default,
+        }
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"Linear TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_lower_linear_batch(self):
+        class Linear(torch.nn.Module):
+            def forward(self, input, weight, bias):
+                out = torch.ops.aten.linear.default(input, weight, bias)
+                return out
+
+        inputs = [
+            torch.rand((2, 2, 32)).cuda(),
+            torch.rand((64, 32)).cuda(),
+            torch.rand((64,)).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(Linear())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"Linear TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
if __name__ == "__main__":
    run_tests()
