Skip to content

Commit 0360198

Browse files
authored
Merge pull request #405 from guoruoqian/batchnorm_fix_bug
Fix bug in batch norm conversion when the input shape must be unpacked (expanded to 4D)
2 parents 9439059 + aa8ffc1 commit 0360198

File tree

4 files changed

+84
-3
lines changed

4 files changed

+84
-3
lines changed

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
4343
auto should_unpack = util::toVec(orig_shape).size() < 4;
4444
if (should_unpack) {
4545
// expand spatial dims from 1D to 2D
46-
auto new_shape = util::toDimsPad(util::toVec(orig_shape), 4);
46+
auto new_shape = util::toDimsTailPad(util::toVec(orig_shape), 4);
4747
LOG_DEBUG(
4848
"Input shape is less than 4D got: "
4949
<< orig_shape << ", inserting shuffle layer to reshape to 4D tensor shape: " << new_shape);

core/util/trt_util.cpp

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,30 @@ nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to) {
8686
return dims;
8787
}
8888

89+
// Converts a shape list to nvinfer1::Dims, padding with 1s at the TAIL to
// reach pad_to dimensions (e.g. {3, 5} with pad_to = 4 -> {3, 5, 1, 1}).
// Contrast with toDimsPad, which pads at the head. If the list already has
// more than pad_to dimensions, it is converted as-is without padding.
nvinfer1::Dims toDimsTailPad(c10::IntArrayRef l, uint64_t pad_to) {
  if (l.size() > pad_to) {
    LOG_DEBUG(
        "Requested padding of dimensions to " << pad_to << " but found " << l.size()
                                              << " dimensions, not going to pad");
    return toDims(l);
  }

  TRTORCH_CHECK(
      pad_to <= nvinfer1::Dims::MAX_DIMS,
      "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");

  nvinfer1::Dims dims;
  dims.nbDims = pad_to;
  // Copy the existing dimensions into the head of the Dims struct.
  for (size_t i = 0; i < l.size(); i++) {
    dims.d[i] = l[i];
  }

  // Fill every remaining tail slot with 1. The loop must start at l.size(),
  // not pad_to - l.size(): the old start index left d[l.size()..pad_to-l.size())
  // uninitialized when l.size() < pad_to / 2 (e.g. 1D -> 4D) and clobbered
  // copied dimensions with 1 when l.size() > pad_to / 2 (e.g. 3D -> 4D).
  for (size_t i = l.size(); i < pad_to; i++) {
    dims.d[i] = 1;
  }
  return dims;
}
112+
89113
nvinfer1::Dims toDims(c10::IntArrayRef l) {
90114
TRTORCH_CHECK(
91115
l.size() <= nvinfer1::Dims::MAX_DIMS,
@@ -136,6 +160,30 @@ nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to) {
136160
return dims;
137161
}
138162

163+
// c10::List overload of toDimsTailPad: converts a shape list to
// nvinfer1::Dims, padding with 1s at the TAIL to reach pad_to dimensions
// (e.g. {3, 5} with pad_to = 4 -> {3, 5, 1, 1}). Lists longer than pad_to
// are converted unchanged.
nvinfer1::Dims toDimsTailPad(c10::List<int64_t> l, uint64_t pad_to) {
  if (l.size() > pad_to) {
    LOG_DEBUG(
        "Requested padding of dimensions to " << pad_to << " but found " << l.size()
                                              << " dimensions, not going to pad");
    return toDims(l);
  }

  TRTORCH_CHECK(
      pad_to <= nvinfer1::Dims::MAX_DIMS,
      "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");

  nvinfer1::Dims dims;
  dims.nbDims = pad_to;
  // Copy the existing dimensions into the head of the Dims struct.
  for (size_t i = 0; i < l.size(); i++) {
    dims.d[i] = l[i];
  }

  // Fill every remaining tail slot with 1. Starting at l.size() (rather than
  // pad_to - l.size()) guarantees all of d[l.size()..pad_to) is initialized
  // and no copied dimension is overwritten — the old start index broke for
  // any l.size() != pad_to / 2 (uninitialized slots or clobbered dims).
  for (size_t i = l.size(); i < pad_to; i++) {
    dims.d[i] = 1;
  }
  return dims;
}
186+
139187
nvinfer1::Dims unpadDims(const nvinfer1::Dims& d) {
140188
nvinfer1::Dims dims;
141189

@@ -304,4 +352,4 @@ c10::optional<nvinfer1::DataType> toTRTDataType(caffe2::TypeMeta dtype) {
304352

305353
} // namespace util
306354
} // namespace core
307-
} // namespace trtorch
355+
} // namespace trtorch

core/util/trt_util.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ int64_t volume(const nvinfer1::Dims& d);
9292
bool broadcastable(nvinfer1::Dims a, nvinfer1::Dims b, bool multidirectional = true);
9393
nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to);
9494
nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to);
95+
nvinfer1::Dims toDimsTailPad(c10::IntArrayRef l, uint64_t pad_to);
96+
nvinfer1::Dims toDimsTailPad(c10::List<int64_t> l, uint64_t pad_to);
9597
nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);
9698
nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos);
9799
nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos);
@@ -110,4 +112,4 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_aten_trt_type_
110112

111113
} // namespace util
112114
} // namespace core
113-
} // namespace trtorch
115+
} // namespace trtorch

tests/core/conversion/converters/test_batch_norm.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,34 @@ TEST(Converters, ATenBatchNormConvertsCorrectly) {
3434

3535
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
3636
}
37+
38+
// Regression test for the batch_norm converter's "unpack" path: the input is
// 2D ({3, 5}), i.e. fewer than 4 dims, so the converter reshapes it to a
// tail-padded 4D tensor before applying batch norm (see should_unpack in
// batch_norm.cpp). The graph result from the JIT interpreter is compared
// against the result produced through the TensorRT engine.
TEST(Converters, ATenBatchNormShouldUnpackConvertsCorrectly) {
39+
const auto graph = R"IR(
40+
graph(%0 : Tensor,
41+
%1: Float(5:1),
42+
%2: Float(5:1),
43+
%3: Float(5:1),
44+
%4: Float(5:1)):
45+
%5 : bool = prim::Constant[value=0]()
46+
%6 : float = prim::Constant[value=1.0000000000000001e-05]()
47+
%7 : float = prim::Constant[value=0.10000000000000001]()
48+
%8 : Tensor = aten::batch_norm(%0, %1, %2, %3, %4, %5, %6, %7, %5)
49+
return (%8))IR";
50+
51+
// Parse the textual IR above into a JIT graph.
auto g = std::make_shared<torch::jit::Graph>();
52+
torch::jit::parseIR(graph, &*g);
53+
54+
// 2D input deliberately triggers the < 4D unpack/reshape branch.
auto in = at::randint(1, 10, {3, 5}, {at::kCUDA});
55+
auto gamma = at::randint(1, 10, {5}, {at::kCUDA});
56+
auto beta = at::randint(1, 10, {5}, {at::kCUDA});
57+
auto mean = at::randint(1, 10, {5}, {at::kCUDA});
58+
auto var = at::randint(1, 10, {5}, {at::kCUDA});
59+
60+
// Bind gamma/beta/mean/var to graph inputs %1-%4 (presumably by position —
// NOTE(review): confirm get_named_params pairing against its definition).
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta, mean, var});
61+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
62+
63+
// Rebuild params before the TRT run in case the first run consumed them.
params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta, mean, var});
64+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
65+
66+
// TRT output is reshaped to match the JIT result's shape before comparing,
// since the engine may return the padded 4D shape.
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
67+
}

0 commit comments

Comments
 (0)