
Commit ccfdfd1

ljfitz authored and ramiro050 committed
Refine static shapes for conv2d and maxpool2d
1 parent 4486de5 commit ccfdfd1

File tree

3 files changed: +162 -4 lines changed

e2e_testing/torchscript/conv.py

Lines changed: 29 additions & 0 deletions

@@ -104,3 +104,32 @@ def forward(self, x):
 def Conv2dWithPaddingDilationStrideModule_basic(module, tu: TestUtils):
     t = tu.rand(5, 2, 10, 20)
     module.forward(t)
+
+
+class Conv2dWithPaddingDilationStrideStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        torch.manual_seed(0)
+        self.conv = torch.nn.Conv2d(in_channels=2,
+                                    out_channels=10,
+                                    kernel_size=3,
+                                    padding=3,
+                                    stride=2,
+                                    dilation=3,
+                                    bias=False)
+        self.train(False)
+
+    @export
+    @annotate_args([
+        None,
+        ([5, 2, 10, 20], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.conv(x)
+
+
+@register_test_case(
+    module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule())
+def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
+    t = tu.rand(5, 2, 10, 20)
+    module.forward(t)
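
For reference, the static result shape this new test implies can be cross-checked against PyTorch's own convolution arithmetic. The snippet below is a minimal sanity check, not part of the commit, using the standard output-size formula out = (in + 2*padding - dilation*(kernel - 1) - 1) // stride + 1:

import torch

# Same configuration as Conv2dWithPaddingDilationStrideStaticModule above.
conv = torch.nn.Conv2d(in_channels=2, out_channels=10, kernel_size=3,
                       padding=3, stride=2, dilation=3, bias=False)
out = conv(torch.rand(5, 2, 10, 20))
# Height: (10 + 2*3 - 3*(3 - 1) - 1) // 2 + 1 = 5
# Width:  (20 + 2*3 - 3*(3 - 1) - 1) // 2 + 1 = 10
print(out.shape)  # torch.Size([5, 10, 5, 10])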

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 57 additions & 4 deletions

@@ -824,16 +824,61 @@ ChangeResult TypeAnalyzer::visitAtenLinearOp(
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
 
+static int64_t getOutputDimForOpWithKernel(int64_t dimIn, int64_t padding,
+                                           int64_t dilation, int64_t kernelSize,
+                                           int64_t stride) {
+  return ((dimIn + 2 * padding - dilation * (kernelSize - 1) - 1) / stride) + 1;
+}
+
+template <class Op>
+std::vector<int64_t>
+computeOpWithKernelOutputShape(Op op, const ValueKnowledge &ifm,
+                               int64_t features, int64_t kernelHeight,
+                               int64_t kernelWidth) {
+  std::vector<int64_t> result = {ifm.sizes[0], // N
+                                 features,     // F
+                                 kUnknownSize, kUnknownSize};
+
+  SmallVector<int64_t> padding;
+  if (!matchPattern(op.padding(), m_TorchConstantIntList(padding)))
+    return result;
+  SmallVector<int64_t, 2> stride;
+  if (!matchPattern(op.stride(), m_TorchConstantIntList(stride)))
+    return result;
+  SmallVector<int64_t, 2> dilation;
+  if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilation)))
+    return result;
+
+  int64_t ifmHeight = ifm.sizes[2];
+  if (ifmHeight != kUnknownSize && kernelHeight != kUnknownSize)
+    result[2] = getOutputDimForOpWithKernel(ifmHeight, padding[0], dilation[0],
+                                            kernelHeight, stride[0]);
+
+  int64_t ifmWidth = ifm.sizes[3];
+  if (ifmWidth != kUnknownSize && kernelWidth != kUnknownSize)
+    result[3] = getOutputDimForOpWithKernel(ifmWidth, padding[1], dilation[1],
+                                            kernelWidth, stride[1]);
+
+  return result;
+}
+
 ChangeResult TypeAnalyzer::visitAtenConv2dOp(
     AtenConv2dOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
       ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
   knowledge.hasSizes = true;
-  knowledge.sizes.resize(4, kUnknownSize);
+  auto &ifm = operands[0]->getValue();
+  auto &weights = operands[1]->getValue();
+  if (weights.hasSizes && ifm.hasSizes)
+    knowledge.sizes = computeOpWithKernelOutputShape(
+        op, ifm, weights.sizes[0], weights.sizes[2], weights.sizes[3]);
+  else
+    knowledge.sizes.resize(4, kUnknownSize);
+
   // Running some experiments in PyTorch, the bias doesn't seem to
   // contribute to the final element type.
-  knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
-      op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()});
+  knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(op->getContext(),
+                                                             {&ifm, &weights});
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
 

@@ -842,7 +887,15 @@ ChangeResult TypeAnalyzer::visitAtenMaxPool2dOp(
   auto knowledge =
       ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
   knowledge.hasSizes = true;
-  knowledge.sizes.resize(4, kUnknownSize);
+  auto &ifm = operands[0]->getValue();
+  SmallVector<int64_t, 2> kernelSize;
+  if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))
+    kernelSize = SmallVector<int64_t, 2>{kUnknownSize, kUnknownSize};
+  if (ifm.hasSizes)
+    knowledge.sizes = computeOpWithKernelOutputShape(
+        op, ifm, ifm.sizes[1], kernelSize[0], kernelSize[1]);
+  else
+    knowledge.sizes.resize(4, kUnknownSize);
   knowledge.dtype = operands[0]->getValue().dtype;
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
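
The new helpers only refine a spatial dimension when the corresponding input size, kernel size, padding, stride, and dilation are all statically known; otherwise that dimension stays kUnknownSize. Below is a minimal Python sketch of the same arithmetic, assuming padding, stride, and dilation are already constants; the names output_dim and output_shape are illustrative stand-ins mirroring getOutputDimForOpWithKernel and computeOpWithKernelOutputShape, not code from this commit:

kUnknownSize = -1  # illustrative stand-in for the unknown-dimension sentinel

def output_dim(dim_in, padding, dilation, kernel, stride):
    # Same formula as getOutputDimForOpWithKernel.
    return (dim_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

def output_shape(ifm_sizes, features, kernel_hw, padding, dilation, stride):
    # Mirrors computeOpWithKernelOutputShape: N and F are always propagated;
    # H and W are refined only when both the input dim and kernel dim are known.
    n, _, h, w = ifm_sizes
    result = [n, features, kUnknownSize, kUnknownSize]
    if h != kUnknownSize and kernel_hw[0] != kUnknownSize:
        result[2] = output_dim(h, padding[0], dilation[0], kernel_hw[0], stride[0])
    if w != kUnknownSize and kernel_hw[1] != kUnknownSize:
        result[3] = output_dim(w, padding[1], dilation[1], kernel_hw[1], stride[1])
    return result

# conv2d case from the refine-types.mlir test @h below:
# ifm [1,8,64,64], weights [16,8,3,3], stride 1, padding 0, dilation 1.
print(output_shape([1, 8, 64, 64], 16, (3, 3), (0, 0), (1, 1), (1, 1)))
# -> [1, 16, 62, 62]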

test/Dialect/Torch/refine-types.mlir

Lines changed: 76 additions & 0 deletions

@@ -93,6 +93,18 @@ builtin.func @g(%arg0:!torch.vtensor<*,f32>, %arg1:!torch.vtensor<*,f32>, %arg2:
   return %3 :!torch.vtensor
 }
 
+// CHECK-LABEL: func @h
+// CHECK: torch.aten.conv2d{{.*}} -> !torch.vtensor<[1,16,62,62],f32>
+builtin.func @h(%arg0:!torch.vtensor<[1,8,64,64],f32>, %arg1:!torch.vtensor<[16,8,3,3],f32>, %arg2:!torch.vtensor<*,f32>) ->!torch.vtensor {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %3 = torch.aten.conv2d %arg0, %arg1, %arg2, %stride, %padding, %dilation, %int1 : !torch.vtensor<[1,8,64,64],f32>, !torch.vtensor<[16,8,3,3],f32>, !torch.vtensor<*,f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.int ->!torch.vtensor
+  return %3 :!torch.vtensor
+}
+
 // -----
 
 // CHECK-LABEL: func @f

@@ -110,6 +122,70 @@ builtin.func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
   return %27 : !torch.vtensor
 }
 
+// CHECK-LABEL: func @g
+builtin.func @g(%arg0: !torch.vtensor<[1,8,64,64],f32>) -> !torch.vtensor {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %bool_false = torch.constant.bool false
+  %krnl = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  // CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,32,32],f32>
+  %27 = torch.aten.max_pool2d %arg0, %krnl, %stride, %padding, %dilation, %bool_false : !torch.vtensor<[1,8,64,64],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor
+  return %27 : !torch.vtensor
+}
+
+// CHECK-LABEL: func @h
+builtin.func @h(%arg0: !torch.vtensor<[1,8,64,64],f32>) -> !torch.vtensor {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %bool_false = torch.constant.bool false
+  %krnl = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  // CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,62,62],f32>
+  %27 = torch.aten.max_pool2d %arg0, %krnl, %stride, %padding, %dilation, %bool_false : !torch.vtensor<[1,8,64,64],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor
+  return %27 : !torch.vtensor
+}
+
+// CHECK-LABEL: func @i
+builtin.func @i(%arg0: !torch.vtensor<[1,8,64,64],f32>) -> !torch.vtensor {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %bool_false = torch.constant.bool false
+  %krnl = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %padding = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  // CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,66,66],f32>
+  %27 = torch.aten.max_pool2d %arg0, %krnl, %stride, %padding, %dilation, %bool_false : !torch.vtensor<[1,8,64,64],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor
+  return %27 : !torch.vtensor
+}
+
+// CHECK-LABEL: func @j
+builtin.func @j(%arg0: !torch.vtensor<[1,8,64,64],f32>) -> !torch.vtensor {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %bool_false = torch.constant.bool false
+  %krnl = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  // CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,32,32],f32>
+  %27 = torch.aten.max_pool2d %arg0, %krnl, %stride, %padding, %dilation, %bool_false : !torch.vtensor<[1,8,64,64],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor
+  return %27 : !torch.vtensor
+}
+
 // -----
 
 // CHECK-LABEL: func @f
