Skip to content

Commit a8184d6

Browse files
committed
casting Int32 convolution inputs to float
1 parent b2a5da6 commit a8184d6

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@ namespace {
1313
bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
1414
// Input to conv/deconv
1515
auto in = args[0].ITensor();
16-
16+
// if (in->getType() == nvinfer1::DataType::kINT32) {
17+
// LOG_DEBUG(
18+
// "Found type " << in->getType() << " in aten::convolution, casting to "
19+
// << nvinfer1::DataType::kFLOAT << " for compatibility.");
20+
// in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT);
21+
// }
1722
// Conv /deconv parameters
1823
auto stride = util::toDims(args[3].unwrapToIntList());
1924
auto padding = util::toDims(args[4].unwrapToIntList());

0 commit comments

Comments
 (0)