@@ -229,6 +229,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
229
229
{at::kHalf , nvinfer1::DataType::kHALF },
230
230
{at::kInt , nvinfer1::DataType::kINT32 },
231
231
{at::kChar , nvinfer1::DataType::kINT8 },
232
+ {at::kBool , nvinfer1::DataType::kBOOL },
232
233
};
233
234
return at_trt_type_map;
234
235
}
@@ -239,6 +240,7 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_ma
239
240
{nvinfer1::DataType::kHALF , at::kHalf },
240
241
{nvinfer1::DataType::kINT32 , at::kInt },
241
242
{nvinfer1::DataType::kINT8 , at::kChar },
243
+ {nvinfer1::DataType::kBOOL , at::kBool },
242
244
};
243
245
return trt_at_type_map;
244
246
}
@@ -249,15 +251,19 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_aten_type_
249
251
}
250
252
251
253
at::ScalarType toATenDType (nvinfer1::DataType t) {
252
- return get_trt_aten_type_map ().at (t);
254
+ auto trt_aten_type_map = get_trt_aten_type_map ();
255
+ TRTORCH_CHECK (trt_aten_type_map.find (t) != trt_aten_type_map.end (), " Unsupported TensorRT datatype" );
256
+ return trt_aten_type_map.at (t);
253
257
}
254
258
255
259
const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_aten_trt_type_map () {
256
260
return get_at_trt_type_map ();
257
261
}
258
262
259
263
nvinfer1::DataType toTRTDataType (at::ScalarType t) {
260
- return get_aten_trt_type_map ().at (t);
264
+ auto aten_trt_type_map = get_aten_trt_type_map ();
265
+ TRTORCH_CHECK (aten_trt_type_map.find (t) != aten_trt_type_map.end (), " Unsupported Aten datatype" );
266
+ return aten_trt_type_map.at (t);
261
267
}
262
268
263
269
c10::optional<nvinfer1::DataType> toTRTDataType (caffe2::TypeMeta dtype) {
0 commit comments