3
3
import torch
4
4
import torchvision .models as models
5
5
import copy
6
+ from typing import Dict
6
7
7
8
from model_test_case import ModelTestCase
8
9
@@ -351,6 +352,138 @@ def test_from_torch(self):
351
352
self .assertEqual (device .device_type , torchtrt .DeviceType .GPU )
352
353
self .assertEqual (device .gpu_id , 0 )
353
354
355
+ class TestInput (unittest .TestCase ):
356
+
357
+ def _verify_correctness (self , struct : torchtrt .Input , target : Dict ) -> bool :
358
+ internal = struct ._to_internal ()
359
+
360
+ list_eq = lambda al , bl : all ([a == b for (a , b ) in zip (al , bl )])
361
+
362
+ eq = lambda a , b : a == b
363
+
364
+ def field_is_correct (field , equal_fn , a1 , a2 ):
365
+ equal = equal_fn (a1 , a2 )
366
+ if not equal :
367
+ print ("\n Field {} is incorrect: {} != {}" .format (field , a1 , a2 ))
368
+ return equal
369
+
370
+ min_ = field_is_correct ("min" , list_eq , internal .min , target ["min" ])
371
+ opt_ = field_is_correct ("opt" , list_eq , internal .opt , target ["opt" ])
372
+ max_ = field_is_correct ("max" , list_eq , internal .max , target ["max" ])
373
+ is_dynamic_ = field_is_correct ("is_dynamic" , eq , internal .input_is_dynamic , target ["input_is_dynamic" ])
374
+ explicit_set_dtype_ = field_is_correct ("explicit_dtype" , eq , internal ._explicit_set_dtype , target ["explicit_set_dtype" ])
375
+ dtype_ = field_is_correct ("dtype" , eq , int (internal .dtype ), int (target ["dtype" ]))
376
+ format_ = field_is_correct ("format" , eq , int (internal .format ), int (target ["format" ]))
377
+
378
+ return all ([min_ ,opt_ ,max_ ,is_dynamic_ ,explicit_set_dtype_ ,dtype_ ,format_ ])
379
+
380
+
381
+ def test_infer_from_example_tensor (self ):
382
+ shape = [1 , 3 , 255 , 255 ]
383
+ target = {
384
+ "min" : shape ,
385
+ "opt" : shape ,
386
+ "max" : shape ,
387
+ "input_is_dynamic" : False ,
388
+ "dtype" : torchtrt .dtype .half ,
389
+ "format" : torchtrt .TensorFormat .contiguous ,
390
+ "explicit_set_dtype" : True
391
+ }
392
+
393
+ example_tensor = torch .randn (shape ).half ()
394
+ i = torchtrt .Input ._from_tensor (example_tensor )
395
+ self .assertTrue (self ._verify_correctness (i , target ))
396
+
397
+
398
+ def test_static_shape (self ):
399
+ shape = [1 , 3 , 255 , 255 ]
400
+ target = {
401
+ "min" : shape ,
402
+ "opt" : shape ,
403
+ "max" : shape ,
404
+ "input_is_dynamic" : False ,
405
+ "dtype" : torchtrt .dtype .unknown ,
406
+ "format" : torchtrt .TensorFormat .contiguous ,
407
+ "explicit_set_dtype" : False
408
+ }
409
+
410
+ i = torchtrt .Input (shape )
411
+ self .assertTrue (self ._verify_correctness (i , target ))
412
+
413
+ i = torchtrt .Input (tuple (shape ))
414
+ self .assertTrue (self ._verify_correctness (i , target ))
415
+
416
+ i = torchtrt .Input (torch .randn (shape ).shape )
417
+ self .assertTrue (self ._verify_correctness (i , target ))
418
+
419
+ i = torchtrt .Input (shape = shape )
420
+ self .assertTrue (self ._verify_correctness (i , target ))
421
+
422
+ i = torchtrt .Input (shape = tuple (shape ))
423
+ self .assertTrue (self ._verify_correctness (i , target ))
424
+
425
+ i = torchtrt .Input (shape = torch .randn (shape ).shape )
426
+ self .assertTrue (self ._verify_correctness (i , target ))
427
+
428
+ def test_data_type (self ):
429
+ shape = [1 , 3 , 255 , 255 ]
430
+ target = {
431
+ "min" : shape ,
432
+ "opt" : shape ,
433
+ "max" : shape ,
434
+ "input_is_dynamic" : False ,
435
+ "dtype" : torchtrt .dtype .half ,
436
+ "format" : torchtrt .TensorFormat .contiguous ,
437
+ "explicit_set_dtype" : True
438
+ }
439
+
440
+ i = torchtrt .Input (shape , dtype = torchtrt .dtype .half )
441
+ self .assertTrue (self ._verify_correctness (i , target ))
442
+
443
+ i = torchtrt .Input (shape , dtype = torch .half )
444
+ self .assertTrue (self ._verify_correctness (i , target ))
445
+
446
+ def test_tensor_format (self ):
447
+ shape = [1 , 3 , 255 , 255 ]
448
+ target = {
449
+ "min" : shape ,
450
+ "opt" : shape ,
451
+ "max" : shape ,
452
+ "input_is_dynamic" : False ,
453
+ "dtype" : torchtrt .dtype .unknown ,
454
+ "format" : torchtrt .TensorFormat .channels_last ,
455
+ "explicit_set_dtype" : False
456
+ }
457
+
458
+ i = torchtrt .Input (shape , format = torchtrt .TensorFormat .channels_last )
459
+ self .assertTrue (self ._verify_correctness (i , target ))
460
+
461
+ i = torchtrt .Input (shape , format = torch .channels_last )
462
+ self .assertTrue (self ._verify_correctness (i , target ))
463
+
464
+ def test_dynamic_shape (self ):
465
+ min_shape = [1 , 3 , 128 , 128 ]
466
+ opt_shape = [1 , 3 , 256 , 256 ]
467
+ max_shape = [1 , 3 , 512 , 512 ]
468
+ target = {
469
+ "min" : min_shape ,
470
+ "opt" : opt_shape ,
471
+ "max" : max_shape ,
472
+ "input_is_dynamic" : True ,
473
+ "dtype" : torchtrt .dtype .unknown ,
474
+ "format" : torchtrt .TensorFormat .contiguous ,
475
+ "explicit_set_dtype" : False
476
+ }
477
+
478
+ i = torchtrt .Input (min_shape = min_shape , opt_shape = opt_shape , max_shape = max_shape )
479
+ self .assertTrue (self ._verify_correctness (i , target ))
480
+
481
+ i = torchtrt .Input (min_shape = tuple (min_shape ), opt_shape = tuple (opt_shape ), max_shape = tuple (max_shape ))
482
+ self .assertTrue (self ._verify_correctness (i , target ))
483
+
484
+ tensor_shape = lambda shape : torch .randn (shape ).shape
485
+ i = torchtrt .Input (min_shape = tensor_shape (min_shape ), opt_shape = tensor_shape (opt_shape ), max_shape = tensor_shape (max_shape ))
486
+ self .assertTrue (self ._verify_correctness (i , target ))
354
487
355
488
def test_suite ():
356
489
suite = unittest .TestSuite ()
@@ -371,6 +504,7 @@ def test_suite():
371
504
TestModuleFallbackToTorch .parametrize (TestModuleFallbackToTorch , model = models .resnet18 (pretrained = True )))
372
505
suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
373
506
suite .addTest (unittest .makeSuite (TestDevice ))
507
+ suite .addTest (unittest .makeSuite (TestInput ))
374
508
375
509
return suite
376
510
0 commit comments