Skip to content

Commit 884a5ac

Browse files
committed
test: Add class test for torch_tensorrt.Input
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent dc572cb commit 884a5ac

File tree

5 files changed

+139
-5
lines changed

5 files changed

+139
-5
lines changed

py/torch_tensorrt/_Input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def _parse_format(format: Any) -> _enums.TensorFormat:
214214
if format == torch.contiguous_format:
215215
return _enums.TensorFormat.contiguous
216216
elif format == torch.channels_last:
217-
return _enums.TensorFormat.channel_last
217+
return _enums.TensorFormat.channels_last
218218
else:
219219
raise ValueError(
220220
"Provided an unsupported tensor format (support: NHCW/contiguous_format, NHWC/channel_last)")

py/torch_tensorrt/csrc/tensorrt_classes.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ Device::Device(const core::runtime::CudaDevice& internal_dev) {
4949

5050
nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {
5151
switch (value) {
52-
case TensorFormat::kChannelLast:
52+
case TensorFormat::kChannelsLast:
5353
return nvinfer1::TensorFormat::kHWC;
5454
case TensorFormat::kContiguous:
5555
default:
@@ -61,7 +61,7 @@ std::string to_str(TensorFormat value) {
6161
switch (value) {
6262
case TensorFormat::kContiguous:
6363
return "Contiguous/Linear/NCHW";
64-
case TensorFormat::kChannelLast:
64+
case TensorFormat::kChannelsLast:
6565
return "Channel Last/NHWC";
6666
default:
6767
return "UNKNOWN";

py/torch_tensorrt/csrc/tensorrt_classes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool, kUnknown };
3131
std::string to_str(DataType value);
3232
nvinfer1::DataType toTRTDataType(DataType value);
3333

34-
enum class TensorFormat : int8_t { kContiguous, kChannelLast };
34+
enum class TensorFormat : int8_t { kContiguous, kChannelsLast };
3535
std::string to_str(TensorFormat value);
3636
nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value);
3737

py/torch_tensorrt/csrc/torch_tensorrt_py.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ PYBIND11_MODULE(_C, m) {
204204

205205
py::enum_<TensorFormat>(m, "TensorFormat", "Enum to specifiy the memory layout of tensors")
206206
.value("contiguous", TensorFormat::kContiguous, "Contiguous memory layout (NCHW / Linear)")
207-
.value("channel_last", TensorFormat::kChannelLast, "Channel last memory layout (NHWC)")
207+
.value("channels_last", TensorFormat::kChannelsLast, "Channels last memory layout (NHWC)")
208208
.export_values();
209209

210210
py::enum_<nvinfer1::CalibrationAlgoType>(m, "CalibrationAlgo", py::module_local(), "Type of calibration algorithm")

tests/py/test_api.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
import torchvision.models as models
55
import copy
6+
from typing import Dict
67

78
from model_test_case import ModelTestCase
89

@@ -351,6 +352,138 @@ def test_from_torch(self):
351352
self.assertEqual(device.device_type, torchtrt.DeviceType.GPU)
352353
self.assertEqual(device.gpu_id, 0)
353354

355+
class TestInput(unittest.TestCase):
356+
357+
def _verify_correctness(self, struct: torchtrt.Input, target: Dict) -> bool:
358+
internal = struct._to_internal()
359+
360+
list_eq = lambda al, bl: all([a == b for (a, b) in zip (al, bl)])
361+
362+
eq = lambda a, b : a == b
363+
364+
def field_is_correct(field, equal_fn, a1, a2):
365+
equal = equal_fn(a1, a2)
366+
if not equal:
367+
print("\nField {} is incorrect: {} != {}".format(field, a1, a2))
368+
return equal
369+
370+
min_ = field_is_correct("min", list_eq, internal.min, target["min"])
371+
opt_ = field_is_correct("opt", list_eq, internal.opt, target["opt"])
372+
max_ = field_is_correct("max", list_eq, internal.max, target["max"])
373+
is_dynamic_ = field_is_correct("is_dynamic", eq, internal.input_is_dynamic, target["input_is_dynamic"])
374+
explicit_set_dtype_ = field_is_correct("explicit_dtype", eq, internal._explicit_set_dtype, target["explicit_set_dtype"])
375+
dtype_ = field_is_correct("dtype", eq, int(internal.dtype), int(target["dtype"]))
376+
format_ = field_is_correct("format", eq, int(internal.format), int(target["format"]))
377+
378+
return all([min_,opt_,max_,is_dynamic_,explicit_set_dtype_,dtype_,format_])
379+
380+
381+
def test_infer_from_example_tensor(self):
382+
shape = [1, 3, 255, 255]
383+
target = {
384+
"min": shape,
385+
"opt": shape,
386+
"max": shape,
387+
"input_is_dynamic": False,
388+
"dtype": torchtrt.dtype.half,
389+
"format": torchtrt.TensorFormat.contiguous,
390+
"explicit_set_dtype": True
391+
}
392+
393+
example_tensor = torch.randn(shape).half()
394+
i = torchtrt.Input._from_tensor(example_tensor)
395+
self.assertTrue(self._verify_correctness(i, target))
396+
397+
398+
def test_static_shape(self):
399+
shape = [1, 3, 255, 255]
400+
target = {
401+
"min": shape,
402+
"opt": shape,
403+
"max": shape,
404+
"input_is_dynamic": False,
405+
"dtype": torchtrt.dtype.unknown,
406+
"format": torchtrt.TensorFormat.contiguous,
407+
"explicit_set_dtype": False
408+
}
409+
410+
i = torchtrt.Input(shape)
411+
self.assertTrue(self._verify_correctness(i, target))
412+
413+
i = torchtrt.Input(tuple(shape))
414+
self.assertTrue(self._verify_correctness(i, target))
415+
416+
i = torchtrt.Input(torch.randn(shape).shape)
417+
self.assertTrue(self._verify_correctness(i, target))
418+
419+
i = torchtrt.Input(shape=shape)
420+
self.assertTrue(self._verify_correctness(i, target))
421+
422+
i = torchtrt.Input(shape=tuple(shape))
423+
self.assertTrue(self._verify_correctness(i, target))
424+
425+
i = torchtrt.Input(shape=torch.randn(shape).shape)
426+
self.assertTrue(self._verify_correctness(i, target))
427+
428+
def test_data_type(self):
429+
shape = [1, 3, 255, 255]
430+
target = {
431+
"min": shape,
432+
"opt": shape,
433+
"max": shape,
434+
"input_is_dynamic": False,
435+
"dtype": torchtrt.dtype.half,
436+
"format": torchtrt.TensorFormat.contiguous,
437+
"explicit_set_dtype": True
438+
}
439+
440+
i = torchtrt.Input(shape, dtype=torchtrt.dtype.half)
441+
self.assertTrue(self._verify_correctness(i, target))
442+
443+
i = torchtrt.Input(shape, dtype=torch.half)
444+
self.assertTrue(self._verify_correctness(i, target))
445+
446+
def test_tensor_format(self):
447+
shape = [1, 3, 255, 255]
448+
target = {
449+
"min": shape,
450+
"opt": shape,
451+
"max": shape,
452+
"input_is_dynamic": False,
453+
"dtype": torchtrt.dtype.unknown,
454+
"format": torchtrt.TensorFormat.channels_last,
455+
"explicit_set_dtype": False
456+
}
457+
458+
i = torchtrt.Input(shape, format=torchtrt.TensorFormat.channels_last)
459+
self.assertTrue(self._verify_correctness(i, target))
460+
461+
i = torchtrt.Input(shape, format=torch.channels_last)
462+
self.assertTrue(self._verify_correctness(i, target))
463+
464+
def test_dynamic_shape(self):
465+
min_shape = [1, 3, 128, 128]
466+
opt_shape = [1, 3, 256, 256]
467+
max_shape = [1, 3, 512, 512]
468+
target = {
469+
"min": min_shape,
470+
"opt": opt_shape,
471+
"max": max_shape,
472+
"input_is_dynamic": True,
473+
"dtype": torchtrt.dtype.unknown,
474+
"format": torchtrt.TensorFormat.contiguous,
475+
"explicit_set_dtype": False
476+
}
477+
478+
i = torchtrt.Input(min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape)
479+
self.assertTrue(self._verify_correctness(i, target))
480+
481+
i = torchtrt.Input(min_shape=tuple(min_shape), opt_shape=tuple(opt_shape), max_shape=tuple(max_shape))
482+
self.assertTrue(self._verify_correctness(i, target))
483+
484+
tensor_shape = lambda shape: torch.randn(shape).shape
485+
i = torchtrt.Input(min_shape=tensor_shape(min_shape), opt_shape=tensor_shape(opt_shape), max_shape=tensor_shape(max_shape))
486+
self.assertTrue(self._verify_correctness(i, target))
354487

355488
def test_suite():
356489
suite = unittest.TestSuite()
@@ -371,6 +504,7 @@ def test_suite():
371504
TestModuleFallbackToTorch.parametrize(TestModuleFallbackToTorch, model=models.resnet18(pretrained=True)))
372505
suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
373506
suite.addTest(unittest.makeSuite(TestDevice))
507+
suite.addTest(unittest.makeSuite(TestInput))
374508

375509
return suite
376510

0 commit comments

Comments
 (0)