1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import operator
1415
1516from pytorch_lightning .utilities import _module_available
17+ from pytorch_lightning .utilities .imports import _compare_version
1618
1719
1820def test_module_exists ():
@@ -22,3 +24,24 @@ def test_module_exists():
2224 assert not _module_available ("torch.nn.asdf" )
2325 assert not _module_available ("asdf" )
2426 assert not _module_available ("asdf.bla.asdf" )
27+
28+
29+ def test_compare_version (monkeypatch ):
30+ from pytorch_lightning .utilities .imports import torch
31+
32+ monkeypatch .setattr (torch , "__version__" , "1.8.9" )
33+ assert not _compare_version ("torch" , operator .ge , "1.10.0" )
34+ assert _compare_version ("torch" , operator .lt , "1.10.0" )
35+
36+ monkeypatch .setattr (torch , "__version__" , "1.10.0.dev123" )
37+ assert _compare_version ("torch" , operator .ge , "1.10.0.dev123" )
38+ assert not _compare_version ("torch" , operator .ge , "1.10.0.dev124" )
39+
40+ assert _compare_version ("torch" , operator .ge , "1.10.0.dev123" , use_base_version = True )
41+ assert _compare_version ("torch" , operator .ge , "1.10.0.dev124" , use_base_version = True )
42+
43+ monkeypatch .setattr (torch , "__version__" , "1.10.0a0+0aef44c" ) # dev version before rc
44+ assert _compare_version ("torch" , operator .ge , "1.10.0.rc0" , use_base_version = True )
45+ assert not _compare_version ("torch" , operator .ge , "1.10.0.rc0" )
46+ assert _compare_version ("torch" , operator .ge , "1.10.0" , use_base_version = True )
47+ assert not _compare_version ("torch" , operator .ge , "1.10.0" )
0 commit comments