Skip to content

Commit 7767594

Browse files
chohk88Hoonkyung Cho
andauthored
Add test case for ITensor weight in convolution and fix related bug (#3327)
Co-authored-by: Hoonkyung Cho <[email protected]>
1 parent 544c545 commit 7767594

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed

py/torch_tensorrt/dynamo/conversion/impl/conv.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import tensorrt as trt
77
import torch
88
from torch.fx.node import Target
9+
910
from torch_tensorrt.dynamo.conversion import impl
1011
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1112
from torch_tensorrt.dynamo.conversion.converter_utils import (
@@ -68,10 +69,9 @@ def convNd(
6869
weight = get_trt_tensor(ctx, weight, f"{name}_weight")
6970
# Append new dimension (unsqueeze) if the convolution is 1d
7071
if is_conv1d:
71-
input = impl.unsqueeze.unsqueeze(
72-
ctx, target, source_ir, name + "_unsqueeze_weight", weight, -1
72+
weight = impl.unsqueeze.unsqueeze(
73+
ctx, target, source_ir, weight.name + "_unsqueeze_conv1d", weight, -1
7374
)
74-
7575
elif isinstance(weight, (torch.Tensor, np.ndarray)):
7676
# Transform the weight constant into a Numpy array
7777
weight = to_numpy(weight, dtype=input.dtype)

tests/py/dynamo/conversion/test_convolution_aten.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from parameterized import param, parameterized
33
from torch.testing._internal.common_utils import run_tests
4+
45
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
@@ -45,6 +46,54 @@ def forward(self, x):
4546
enable_passes=True,
4647
)
4748

49+
@parameterized.expand(
50+
[
51+
("default", 1),
52+
param("no_bias", 1, bias=False),
53+
("tuple_parameters", 1, (1), (1)),
54+
param("non_zero_padding", 1, padding=1),
55+
param("dilation", 1, dilation=2),
56+
]
57+
)
58+
def test_conv1d_TRTTensor_weight(
59+
self,
60+
_,
61+
kernel_size,
62+
stride=1,
63+
padding=0,
64+
dilation=1,
65+
groups=1,
66+
bias=True,
67+
):
68+
class TestModule(torch.nn.Module):
69+
def __init__(self):
70+
super().__init__()
71+
72+
def forward(self, x, w):
73+
return torch.ops.aten.convolution.default(
74+
x,
75+
w,
76+
None,
77+
(stride,) if isinstance(stride, int) else stride,
78+
(padding,) if isinstance(padding, int) else padding,
79+
(dilation,) if isinstance(dilation, int) else dilation,
80+
False,
81+
(0,),
82+
groups,
83+
)
84+
85+
inputs = [
86+
torch.randn(1, 3, 32),
87+
torch.randn(
88+
6, 3, 1
89+
), # Conv1d weight shape: (out_channels, in_channels, kernel_size)
90+
]
91+
self.run_test(
92+
TestModule(),
93+
inputs,
94+
use_dynamo_tracer=True,
95+
)
96+
4897
def test_conv1d_with_dynamic_shape(
4998
self,
5099
kernel_size=1,

0 commit comments

Comments
 (0)