Skip to content

Commit 79cdf80

Browse files
authored
Merge pull request #1609 from pytorch/convolution_cast
Convolution cast
2 parents 18ba2cb + ab2ed5e commit 79cdf80

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@ namespace {
1313
bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
1414
// Input to conv/deconv
1515
auto in = args[0].ITensor();
16-
16+
if (in->getType() == nvinfer1::DataType::kINT32) {
17+
LOG_WARNING(
18+
"Found type " << in->getType() << "in aten::convolution, casting to" << nvinfer1::DataType::kFLOAT
19+
<< " for compatibility.");
20+
in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT);
21+
}
1722
// Conv /deconv parameters
1823
auto stride = util::toDims(args[3].unwrapToIntList());
1924
auto padding = util::toDims(args[4].unwrapToIntList());

0 commit comments

Comments
 (0)