Skip to content

Commit 7fb0d3a

Browse files
authored
Cherrypick #3703 for release/2.8 (#3735)
1 parent b65e445 commit 7fb0d3a

File tree

5 files changed

+131
-0
lines changed

5 files changed

+131
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3579,3 +3579,22 @@ def aten_ops_nonzero(
35793579
name,
35803580
args[0],
35813581
)
3582+
3583+
3584+
@dynamo_tensorrt_converter(torch.ops.aten.linear.default, supports_dynamic_shapes=True)
3585+
def aten_ops_linear(
3586+
ctx: ConversionContext,
3587+
target: Target,
3588+
args: Tuple[Argument, ...],
3589+
kwargs: Dict[str, Argument],
3590+
name: str,
3591+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
3592+
return impl.linear.linear(
3593+
ctx,
3594+
target,
3595+
SourceIR.ATEN,
3596+
name,
3597+
input=args[0],
3598+
weight=args[1],
3599+
bias=args_bounds_check(args, 2, None),
3600+
)

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
embedding,
1313
full,
1414
grid,
15+
linear,
1516
matmul,
1617
nccl_ops,
1718
normalization,
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from typing import Optional, Union
2+
3+
import numpy as np
4+
import tensorrt as trt
5+
import torch
6+
from torch.fx.node import Target
7+
from torch_tensorrt.dynamo.conversion import impl
8+
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
9+
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
10+
from torch_tensorrt.dynamo.types import TRTTensor
11+
12+
13+
def linear(
14+
ctx: ConversionContext,
15+
target: Union[Target, str],
16+
source_ir: Optional[SourceIR],
17+
name: str,
18+
input: TRTTensor,
19+
weight: Union[TRTTensor, torch.Tensor, np.ndarray],
20+
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
21+
) -> TRTTensor:
22+
# Process weight terms
23+
if not isinstance(weight, (TRTTensor, torch.Tensor, np.ndarray)):
24+
raise RuntimeError(
25+
f"Linear layer {name} has weight of type {type(weight)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray],"
26+
)
27+
elif isinstance(weight, (torch.Tensor, np.ndarray)):
28+
weight = get_trt_tensor(ctx, weight, f"{name}_weight")
29+
30+
# Process bias terms
31+
if bias is not None and not isinstance(bias, (TRTTensor, torch.Tensor, np.ndarray)):
32+
raise RuntimeError(
33+
f"Linear layer {name} has bias of type {type(bias)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray],"
34+
)
35+
elif isinstance(bias, (torch.Tensor, np.ndarray)):
36+
bias = get_trt_tensor(ctx, bias, f"{name}_bias")
37+
38+
# add IMatrixMultiplyLayer
39+
out = impl.matmul.matrix_multiply(
40+
ctx,
41+
target,
42+
source_ir,
43+
f"{name}_matrix_multiply",
44+
input,
45+
weight,
46+
input_matrix_op=trt.MatrixOperation.NONE,
47+
other_matrix_op=trt.MatrixOperation.TRANSPOSE,
48+
)
49+
50+
if bias is not None:
51+
# add bias
52+
out = impl.elementwise.add(
53+
ctx, target, source_ir, f"{name}_add_bias", out, bias
54+
)
55+
56+
return out

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@
171171
aten.upsample_bilinear2d.vec,
172172
aten.upsample_trilinear3d.vec,
173173
aten.upsample_bicubic2d.vec,
174+
aten.linear.default,
174175
}
175176

176177

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
6+
7+
from .harness import DispatchTestCase
8+
9+
10+
class TestLinearConverter(DispatchTestCase):
11+
@parameterized.expand(
12+
[
13+
(10, 10),
14+
(10, 100),
15+
(100, 10),
16+
(100, 100),
17+
]
18+
)
19+
def test_linear_converter(self, in_features, out_features):
20+
class LinearModel(nn.Module):
21+
def __init__(self, in_features, out_features):
22+
super(LinearModel, self).__init__()
23+
self.linear = nn.Linear(in_features, out_features)
24+
25+
def forward(self, x):
26+
return self.linear(x)
27+
28+
model = LinearModel(in_features, out_features).eval().cuda()
29+
inputs = [torch.randn(int(torch.randint(1, 20, (1,))), in_features).cuda()]
30+
self.run_test(model, inputs, use_dynamo_tracer=True, enable_passes=True)
31+
32+
def test_linear_with_dynamic_shape(self):
33+
class LinearModel(torch.nn.Module):
34+
def forward(self, x, weight, bias):
35+
return torch.ops.aten.linear.default(x, weight, bias)
36+
37+
input_specs = [
38+
Input(
39+
dtype=torch.float32,
40+
min_shape=(1, 10),
41+
opt_shape=(10, 10),
42+
max_shape=(100, 10),
43+
),
44+
Input(dtype=torch.float32, shape=(20, 10)),
45+
Input(dtype=torch.float32, shape=(20,)),
46+
]
47+
48+
self.run_test_with_dynamic_shape(
49+
LinearModel(), input_specs, use_dynamo_tracer=True, enable_passes=True
50+
)
51+
52+
53+
if __name__ == "__main__":
54+
run_tests()

0 commit comments

Comments
 (0)