@@ -12,35 +12,63 @@ def test_torch_tensorrt(model, inputs):
     # fp32 test
     with torch.inference_mode():
         ref_fp32 = model_ts(*inputs_ts)
-        trt_ts_module = torch_tensorrt.compile(model_ts, inputs=inputs_ts, enabled_precisions={torch.float32})
+        trt_ts_module = torch_tensorrt.compile(
+            model_ts, inputs=inputs_ts, enabled_precisions={torch.float32}
+        )
         result_fp32 = trt_ts_module(*inputs_ts)
-        assert torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0) > 0.9999
+        assert (
+            torch.nn.functional.cosine_similarity(
+                ref_fp32.flatten(), result_fp32.flatten(), dim=0
+            )
+            > 0.9999
+        )
     # fp16 test
     model_ts = model_ts.half()
     inputs_ts = [i.cuda().half() for i in inputs_ts]
     with torch.inference_mode():
         ref_fp16 = model_ts(*inputs_ts)
-        trt_ts_module = torch_tensorrt.compile(model_ts, inputs=inputs_ts, enabled_precisions={torch.float16})
+        trt_ts_module = torch_tensorrt.compile(
+            model_ts, inputs=inputs_ts, enabled_precisions={torch.float16}
+        )
         result_fp16 = trt_ts_module(*inputs_ts)
-        assert torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0) > 0.99
+        assert (
+            torch.nn.functional.cosine_similarity(
+                ref_fp16.flatten(), result_fp16.flatten(), dim=0
+            )
+            > 0.99
+        )
 
     # FX path
     model_fx = copy.deepcopy(model)
     inputs_fx = copy.deepcopy(inputs)
     # fp32 test
     with torch.inference_mode():
         ref_fp32 = model_fx(*inputs_fx)
-        trt_fx_module = torch_tensorrt.compile(model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float32})
+        trt_fx_module = torch_tensorrt.compile(
+            model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float32}
+        )
         result_fp32 = trt_fx_module(*inputs_fx)
-        assert torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0) > 0.9999
+        assert (
+            torch.nn.functional.cosine_similarity(
+                ref_fp32.flatten(), result_fp32.flatten(), dim=0
+            )
+            > 0.9999
+        )
     # fp16 test
     model_fx = model_fx.cuda().half()
     inputs_fx = [i.cuda().half() for i in inputs_fx]
     with torch.inference_mode():
         ref_fp16 = model_fx(*inputs_fx)
-        trt_fx_module = torch_tensorrt.compile(model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float16})
+        trt_fx_module = torch_tensorrt.compile(
+            model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float16}
+        )
         result_fp16 = trt_fx_module(*inputs_fx)
-        assert torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0) > 0.99
+        assert (
+            torch.nn.functional.cosine_similarity(
+                ref_fp16.flatten(), result_fp16.flatten(), dim=0
+            )
+            > 0.99
+        )
 
 
 if __name__ == "__main__":
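For context, a minimal sketch of how a helper like this test_torch_tensorrt(model, inputs) might be driven, assuming a CUDA machine with torch, torchvision, and torch_tensorrt installed; the resnet18 model and the 1x3x224x224 input shape are illustrative assumptions, not part of this diff:

import torch
import torchvision.models as models

# Hypothetical driver for the helper above; resnet18 and the input
# shape are assumptions chosen only to make the sketch runnable.
model = models.resnet18().eval().cuda()
inputs = [torch.rand(1, 3, 224, 224).cuda()]
test_torch_tensorrt(model, inputs)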