33import torch
44import torchvision .models as models
55import copy
6+ from typing import Dict
67
78from model_test_case import ModelTestCase
89
@@ -351,6 +352,138 @@ def test_from_torch(self):
351352 self .assertEqual (device .device_type , torchtrt .DeviceType .GPU )
352353 self .assertEqual (device .gpu_id , 0 )
353354
355+ class TestInput (unittest .TestCase ):
356+
357+ def _verify_correctness (self , struct : torchtrt .Input , target : Dict ) -> bool :
358+ internal = struct ._to_internal ()
359+
360+ list_eq = lambda al , bl : all ([a == b for (a , b ) in zip (al , bl )])
361+
362+ eq = lambda a , b : a == b
363+
364+ def field_is_correct (field , equal_fn , a1 , a2 ):
365+ equal = equal_fn (a1 , a2 )
366+ if not equal :
367+ print ("\n Field {} is incorrect: {} != {}" .format (field , a1 , a2 ))
368+ return equal
369+
370+ min_ = field_is_correct ("min" , list_eq , internal .min , target ["min" ])
371+ opt_ = field_is_correct ("opt" , list_eq , internal .opt , target ["opt" ])
372+ max_ = field_is_correct ("max" , list_eq , internal .max , target ["max" ])
373+ is_dynamic_ = field_is_correct ("is_dynamic" , eq , internal .input_is_dynamic , target ["input_is_dynamic" ])
374+ explicit_set_dtype_ = field_is_correct ("explicit_dtype" , eq , internal ._explicit_set_dtype , target ["explicit_set_dtype" ])
375+ dtype_ = field_is_correct ("dtype" , eq , int (internal .dtype ), int (target ["dtype" ]))
376+ format_ = field_is_correct ("format" , eq , int (internal .format ), int (target ["format" ]))
377+
378+ return all ([min_ ,opt_ ,max_ ,is_dynamic_ ,explicit_set_dtype_ ,dtype_ ,format_ ])
379+
380+
381+ def test_infer_from_example_tensor (self ):
382+ shape = [1 , 3 , 255 , 255 ]
383+ target = {
384+ "min" : shape ,
385+ "opt" : shape ,
386+ "max" : shape ,
387+ "input_is_dynamic" : False ,
388+ "dtype" : torchtrt .dtype .half ,
389+ "format" : torchtrt .TensorFormat .contiguous ,
390+ "explicit_set_dtype" : True
391+ }
392+
393+ example_tensor = torch .randn (shape ).half ()
394+ i = torchtrt .Input ._from_tensor (example_tensor )
395+ self .assertTrue (self ._verify_correctness (i , target ))
396+
397+
398+ def test_static_shape (self ):
399+ shape = [1 , 3 , 255 , 255 ]
400+ target = {
401+ "min" : shape ,
402+ "opt" : shape ,
403+ "max" : shape ,
404+ "input_is_dynamic" : False ,
405+ "dtype" : torchtrt .dtype .unknown ,
406+ "format" : torchtrt .TensorFormat .contiguous ,
407+ "explicit_set_dtype" : False
408+ }
409+
410+ i = torchtrt .Input (shape )
411+ self .assertTrue (self ._verify_correctness (i , target ))
412+
413+ i = torchtrt .Input (tuple (shape ))
414+ self .assertTrue (self ._verify_correctness (i , target ))
415+
416+ i = torchtrt .Input (torch .randn (shape ).shape )
417+ self .assertTrue (self ._verify_correctness (i , target ))
418+
419+ i = torchtrt .Input (shape = shape )
420+ self .assertTrue (self ._verify_correctness (i , target ))
421+
422+ i = torchtrt .Input (shape = tuple (shape ))
423+ self .assertTrue (self ._verify_correctness (i , target ))
424+
425+ i = torchtrt .Input (shape = torch .randn (shape ).shape )
426+ self .assertTrue (self ._verify_correctness (i , target ))
427+
428+ def test_data_type (self ):
429+ shape = [1 , 3 , 255 , 255 ]
430+ target = {
431+ "min" : shape ,
432+ "opt" : shape ,
433+ "max" : shape ,
434+ "input_is_dynamic" : False ,
435+ "dtype" : torchtrt .dtype .half ,
436+ "format" : torchtrt .TensorFormat .contiguous ,
437+ "explicit_set_dtype" : True
438+ }
439+
440+ i = torchtrt .Input (shape , dtype = torchtrt .dtype .half )
441+ self .assertTrue (self ._verify_correctness (i , target ))
442+
443+ i = torchtrt .Input (shape , dtype = torch .half )
444+ self .assertTrue (self ._verify_correctness (i , target ))
445+
446+ def test_tensor_format (self ):
447+ shape = [1 , 3 , 255 , 255 ]
448+ target = {
449+ "min" : shape ,
450+ "opt" : shape ,
451+ "max" : shape ,
452+ "input_is_dynamic" : False ,
453+ "dtype" : torchtrt .dtype .unknown ,
454+ "format" : torchtrt .TensorFormat .channels_last ,
455+ "explicit_set_dtype" : False
456+ }
457+
458+ i = torchtrt .Input (shape , format = torchtrt .TensorFormat .channels_last )
459+ self .assertTrue (self ._verify_correctness (i , target ))
460+
461+ i = torchtrt .Input (shape , format = torch .channels_last )
462+ self .assertTrue (self ._verify_correctness (i , target ))
463+
464+ def test_dynamic_shape (self ):
465+ min_shape = [1 , 3 , 128 , 128 ]
466+ opt_shape = [1 , 3 , 256 , 256 ]
467+ max_shape = [1 , 3 , 512 , 512 ]
468+ target = {
469+ "min" : min_shape ,
470+ "opt" : opt_shape ,
471+ "max" : max_shape ,
472+ "input_is_dynamic" : True ,
473+ "dtype" : torchtrt .dtype .unknown ,
474+ "format" : torchtrt .TensorFormat .contiguous ,
475+ "explicit_set_dtype" : False
476+ }
477+
478+ i = torchtrt .Input (min_shape = min_shape , opt_shape = opt_shape , max_shape = max_shape )
479+ self .assertTrue (self ._verify_correctness (i , target ))
480+
481+ i = torchtrt .Input (min_shape = tuple (min_shape ), opt_shape = tuple (opt_shape ), max_shape = tuple (max_shape ))
482+ self .assertTrue (self ._verify_correctness (i , target ))
483+
484+ tensor_shape = lambda shape : torch .randn (shape ).shape
485+ i = torchtrt .Input (min_shape = tensor_shape (min_shape ), opt_shape = tensor_shape (opt_shape ), max_shape = tensor_shape (max_shape ))
486+ self .assertTrue (self ._verify_correctness (i , target ))
354487
355488def test_suite ():
356489 suite = unittest .TestSuite ()
@@ -371,6 +504,7 @@ def test_suite():
371504 TestModuleFallbackToTorch .parametrize (TestModuleFallbackToTorch , model = models .resnet18 (pretrained = True )))
372505 suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
373506 suite .addTest (unittest .makeSuite (TestDevice ))
507+ suite .addTest (unittest .makeSuite (TestInput ))
374508
375509 return suite
376510
0 commit comments