Skip to content

Commit 333df1a

Browse files
committed
changes to cast the convolution input layers to float
1 parent 3cb96ed commit 333df1a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ namespace {
1313
bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
1414
// Input to conv/deconv
1515
auto in = args[0].ITensor();
16-
// if (in->getType() == nvinfer1::DataType::kINT32) {
17-
// LOG_DEBUG(
18-
// "Found type " << in->getType() << " in aten::convolution, casting to "
19-
// << nvinfer1::DataType::kFLOAT << " for compatibility.");
20-
// in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT);
21-
// }
16+
if (in->getType() == nvinfer1::DataType::kINT32) {
17+
LOG_DEBUG(
18+
"Found type " << in->getType() << " in aten::convolution, casting to "
19+
<< nvinfer1::DataType::kFLOAT << " for compatibility.");
20+
in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT);
21+
}
2222
// Conv /deconv parameters
2323
auto stride = util::toDims(args[3].unwrapToIntList());
2424
auto padding = util::toDims(args[4].unwrapToIntList());

0 commit comments

Comments
 (0)