Skip to content

Commit 52c6eee

Browse files
authored
Merge pull request #255 from NVIDIA/fix_dtype
Support bool datatype of input/output tensors
2 parents 0b5d6e9 + 4b86558 commit 52c6eee

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

core/util/trt_util.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
229229
{at::kHalf, nvinfer1::DataType::kHALF},
230230
{at::kInt, nvinfer1::DataType::kINT32},
231231
{at::kChar, nvinfer1::DataType::kINT8},
232+
{at::kBool, nvinfer1::DataType::kBOOL},
232233
};
233234
return at_trt_type_map;
234235
}
@@ -239,6 +240,7 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_ma
239240
{nvinfer1::DataType::kHALF, at::kHalf},
240241
{nvinfer1::DataType::kINT32, at::kInt},
241242
{nvinfer1::DataType::kINT8, at::kChar},
243+
{nvinfer1::DataType::kBOOL, at::kBool},
242244
};
243245
return trt_at_type_map;
244246
}
@@ -249,15 +251,19 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_aten_type_
249251
}
250252

251253
at::ScalarType toATenDType(nvinfer1::DataType t) {
252-
return get_trt_aten_type_map().at(t);
254+
auto trt_aten_type_map = get_trt_aten_type_map();
255+
TRTORCH_CHECK(trt_aten_type_map.find(t) != trt_aten_type_map.end(), "Unsupported TensorRT datatype");
256+
return trt_aten_type_map.at(t);
253257
}
254258

255259
const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_aten_trt_type_map() {
256260
return get_at_trt_type_map();
257261
}
258262

259263
nvinfer1::DataType toTRTDataType(at::ScalarType t) {
260-
return get_aten_trt_type_map().at(t);
264+
auto aten_trt_type_map = get_aten_trt_type_map();
265+
TRTORCH_CHECK(aten_trt_type_map.find(t) != aten_trt_type_map.end(), "Unsupported Aten datatype");
266+
return aten_trt_type_map.at(t);
261267
}
262268

263269
c10::optional<nvinfer1::DataType> toTRTDataType(caffe2::TypeMeta dtype) {

tests/util/util.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, flo
1717

1818
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold) {
1919
LOG_DEBUG(a << std::endl << b << std::endl);
20-
return checkRtol(a - b, {a, b}, threshold);
20+
auto a_float = a.toType(at::kFloat);
21+
auto b_float = b.toType(at::kFloat);
22+
return checkRtol(a_float - b_float, {a_float, b_float}, threshold);
2123
}
2224

2325
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) {

0 commit comments

Comments
 (0)