Skip to content

Commit f466c71

Browse files
committed
add version utils to enable valid version str comparison
1 parent d192192 commit f466c71

File tree

3 files changed

+56
-8
lines changed

3 files changed

+56
-8
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
3+
import torch2trt.version_utils
4+
5+
def test_version_utils():
6+
7+
a = torch2trt.version_utils.Version("10.1")
8+
9+
assert a >= "10.1"
10+
assert a >= "10.0"
11+
assert a > "7.0"
12+
assert a < "11.0"
13+
assert a == "10.1"
14+
assert a <= "10.1"
15+
assert a <= "10.2"

torch2trt/torch2trt.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,10 @@
1919

2020
from .flattener import Flattener
2121
from .flatten_module import Flatten, Unflatten
22+
from .version_utils import trt_version, torch_version
2223
# UTILITY FUNCTIONS
2324

2425

25-
def trt_version():
26-
return trt.__version__
27-
28-
29-
def torch_version():
30-
return torch.__version__
31-
32-
3326
def torch_dtype_to_trt(dtype):
3427
if trt_version() >= '7.0' and dtype == torch.bool:
3528
return trt.bool

torch2trt/version_utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import packaging.version
2+
import tensorrt as trt
3+
import torch
4+
5+
6+
def trt_version():
7+
return trt.__version__
8+
9+
10+
def torch_version():
11+
return torch.__version__
12+
13+
14+
class Version(packaging.version.Version):
15+
16+
def __ge__(self, other):
17+
if isinstance(other, str):
18+
other = Version(other)
19+
return super().__ge__(other)
20+
21+
def __le__(self, other):
22+
if isinstance(other, str):
23+
other = Version(other)
24+
return super().__le__(other)
25+
26+
def __eq__(self, other):
27+
if isinstance(other, str):
28+
other = Version(other)
29+
return super().__eq__(other)
30+
31+
def __gt__(self, other):
32+
if isinstance(other, str):
33+
other = Version(other)
34+
return super().__gt__(other)
35+
36+
def __lt__(self, other):
37+
if isinstance(other, str):
38+
other = Version(other)
39+
return super().__lt__(other)
40+

0 commit comments

Comments
 (0)